diff --git a/.vscode/launch.json b/.vscode/launch.json index 6d2140a..dcccfdf 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,9 +11,9 @@ "mode": "auto", "program": "${workspaceFolder}", "args": [ - "gen", - "route", - "/projects/atom_starter", + "swag", + "init", + "/projects/tt", ] } ] diff --git a/cmd/gen_route.go b/cmd/gen_route.go index d0df58a..b54d438 100644 --- a/cmd/gen_route.go +++ b/cmd/gen_route.go @@ -24,7 +24,7 @@ func CommandGenRoute(root *cobra.Command) { root.AddCommand(cmd) } -// https://github.com/swaggo/swag?tab=readme-ov-file#api-operation +// https://git.ipao.vip/rogeecn/atomctl/pkg/swag?tab=readme-ov-file#api-operation func commandGenRouteE(cmd *cobra.Command, args []string) error { var err error var path string diff --git a/cmd/swag_fmt.go b/cmd/swag_fmt.go index 694e5a0..8cdaaac 100644 --- a/cmd/swag_fmt.go +++ b/cmd/swag_fmt.go @@ -1,8 +1,8 @@ package cmd import ( + "git.ipao.vip/rogeecn/atomctl/pkg/swag/format" "github.com/spf13/cobra" - "github.com/swaggo/swag/format" ) func CommandSwagFmt(root *cobra.Command) { diff --git a/cmd/swag_init.go b/cmd/swag_init.go index 0af0d96..278baa4 100644 --- a/cmd/swag_init.go +++ b/cmd/swag_init.go @@ -1,10 +1,13 @@ package cmd import ( + "os" + "path/filepath" + + "git.ipao.vip/rogeecn/atomctl/pkg/swag" + "git.ipao.vip/rogeecn/atomctl/pkg/swag/gen" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/swaggo/swag" - "github.com/swaggo/swag/gen" ) func CommandSwagInit(root *cobra.Command) { @@ -19,15 +22,23 @@ func CommandSwagInit(root *cobra.Command) { } func commandSwagInitE(cmd *cobra.Command, args []string) error { + pwd, err := os.Getwd() + if err != nil { + return err + } + if len(args) > 0 { + pwd = args[0] + } + leftDelim, rightDelim := "{{", "}}" return gen.New().Build(&gen.Config{ - SearchDir: "./", + SearchDir: pwd, Excludes: "", ParseExtension: "", MainAPIFile: "main.go", PropNamingStrategy: swag.CamelCase, - OutputDir: "./docs", + OutputDir: filepath.Join(pwd, "docs"), OutputTypes: []string{"go", "json", "yaml"}, ParseVendor: false, ParseDependency: 0, diff --git a/go.mod b/go.mod index b1647c5..87425a0 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 - github.com/swaggo/swag v1.16.4 golang.org/x/mod v0.17.0 golang.org/x/text v0.21.0 golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d diff --git a/go.sum b/go.sum index 14501ba..2946e71 100644 --- a/go.sum +++ b/go.sum @@ -237,8 +237,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/swaggo/swag v1.16.4 h1:clWJtd9LStiG3VeijiCfOVODP6VpHtKdQy9ELFG3s1A= -github.com/swaggo/swag v1.16.4/go.mod h1:VBsHJRsDvfYvqoiMKnsdwhNV9LEMHgEDZcyVYX0sxPg= +git.ipao.vip/rogeecn/atomctl/pkg/swag v1.16.4 h1:clWJtd9LStiG3VeijiCfOVODP6VpHtKdQy9ELFG3s1A= +git.ipao.vip/rogeecn/atomctl/pkg/swag v1.16.4/go.mod h1:VBsHJRsDvfYvqoiMKnsdwhNV9LEMHgEDZcyVYX0sxPg= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= diff --git a/main_test.go b/main_test.go index cd98cb9..3fc58b6 100644 --- a/main_test.go +++ b/main_test.go @@ -1,64 +1,44 @@ package main import ( - "strings" + "regexp" "testing" + + . "github.com/smartystreets/goconvey/convey" ) -type ParamDefinition struct { - Name string - Type string - Key string - Table string - Model string - Position string -} +func Test_router(t *testing.T) { + routerPattern := regexp.MustCompile(`^(/[\w./\-{}\(\)+:$]*)[[:blank:]]+\[(\w+)]`) -func parseBind(bind string) ParamDefinition { - var param ParamDefinition - parts := strings.FieldsFunc(bind, func(r rune) bool { - return r == ' ' || r == '(' || r == ')' + Convey("Test routerPattern", t, func() { + Convey("Pattern 1", func() { + commentLine := "/api/v1/health [GET] # Check health status" + matches := routerPattern.FindStringSubmatch(commentLine) + t.Logf("matches: %v", matches) + }) + + Convey("Pattern 2", func() { + commentLine := "/api/v1/:health [get] " + matches := routerPattern.FindStringSubmatch(commentLine) + t.Logf("matches: %v", matches) + }) + + Convey("Pattern 3", func() { + commentLine := "/api/v1/get_users-:id [get] " + pattern := regexp.MustCompile(`<.*?>`) + commentLine = pattern.ReplaceAllString(commentLine, "") + + matches := routerPattern.FindStringSubmatch(commentLine) + t.Logf("matches: %v", matches) + }) + + Convey("Pattern 4", func() { + commentLine := "/api/v1/get_users-:id/name/:name [get] " + pattern := regexp.MustCompile(`:(\w+)(<.*?>)?`) + commentLine = pattern.ReplaceAllString(commentLine, "{$1}") + + matches := routerPattern.FindStringSubmatch(commentLine) + t.Logf("matches: %v", matches) + }) }) - - // 过滤掉空的元素 - var newParts []string - for _, part := range parts { - part = strings.TrimSpace(part) - if part != "" { - newParts = append(newParts, part) - } - } - - for i, part := range parts { - switch part { - case "@Bind": - param.Name = parts[i+1] - param.Position = parts[i+2] - case "key": - param.Key = parts[i+1] - case "table": - param.Table = parts[i+1] - case "model": - param.Model = parts[i+1] - } - } - return param -} - -func Test_T(t *testing.T) { - // @Bind [Name] [Type] [Key] [Table] [Model] - suites := []string{ - `@Bind name query key("a") table(b) model("c")`, - `@Bind id uri key(a)`, - `@Bind id uri table(b)`, - `@Bind id uri key(b) model(c)`, - `@Bind id uri key(b) table(c)`, - `@Bind id uri table(b) key(c)`, - `@Bind id uri`, - } - - for _, suite := range suites { - param := parseBind(suite) - t.Logf("Parsed Param: %+v", param) - } } diff --git a/pkg/swag/Dockerfile b/pkg/swag/Dockerfile new file mode 100644 index 0000000..5ea9134 --- /dev/null +++ b/pkg/swag/Dockerfile @@ -0,0 +1,36 @@ +# Dockerfile References: https://docs.docker.com/engine/reference/builder/ + +# Start from the latest golang base image +FROM --platform=$BUILDPLATFORM golang:1.21-alpine as builder + +# Set the Current Working Directory inside the container +WORKDIR /app + +# Copy go mod and sum files +COPY go.mod go.sum ./ + +# Download all dependencies. Dependencies will be cached if the go.mod and go.sum files are not changed +RUN go mod download + +# Copy the source from the current directory to the Working Directory inside the container +COPY . . + +# Configure go compiler target platform +ARG TARGETOS +ARG TARGETARCH +ENV GOARCH=$TARGETARCH \ + GOOS=$TARGETOS + +# Build the Go app +RUN CGO_ENABLED=0 GOOS=linux go build -v -a -installsuffix cgo -o swag cmd/swag/main.go + + +######## Start a new stage from scratch ####### +FROM --platform=$TARGETPLATFORM scratch + +WORKDIR /code/ + +# Copy the Pre-built binary file from the previous stage +COPY --from=builder /app/swag /bin/swag + +ENTRYPOINT ["/bin/swag"] diff --git a/pkg/swag/const.go b/pkg/swag/const.go new file mode 100644 index 0000000..8375510 --- /dev/null +++ b/pkg/swag/const.go @@ -0,0 +1,567 @@ +package swag + +import ( + "go/ast" + "go/token" + "reflect" + "strconv" + "strings" + "unicode/utf8" +) + +// ConstVariable a model to record a const variable +type ConstVariable struct { + Name *ast.Ident + Type ast.Expr + Value interface{} + Comment *ast.CommentGroup + File *ast.File + Pkg *PackageDefinitions +} + +var escapedChars = map[uint8]uint8{ + 'n': '\n', + 'r': '\r', + 't': '\t', + 'v': '\v', + '\\': '\\', + '"': '"', +} + +// EvaluateEscapedChar parse escaped character +func EvaluateEscapedChar(text string) rune { + if len(text) == 1 { + return rune(text[0]) + } + + if len(text) == 2 && text[0] == '\\' { + return rune(escapedChars[text[1]]) + } + + if len(text) == 6 && text[0:2] == "\\u" { + n, err := strconv.ParseInt(text[2:], 16, 32) + if err == nil { + return rune(n) + } + } + + return 0 +} + +// EvaluateEscapedString parse escaped characters in string +func EvaluateEscapedString(text string) string { + if !strings.ContainsRune(text, '\\') { + return text + } + result := make([]byte, 0, len(text)) + for i := 0; i < len(text); i++ { + if text[i] == '\\' { + i++ + if text[i] == 'u' { + i++ + char, err := strconv.ParseInt(text[i:i+4], 16, 32) + if err == nil { + result = utf8.AppendRune(result, rune(char)) + } + i += 3 + } else if c, ok := escapedChars[text[i]]; ok { + result = append(result, c) + } + } else { + result = append(result, text[i]) + } + } + return string(result) +} + +// EvaluateDataConversion evaluate the type a explicit type conversion +func EvaluateDataConversion(x interface{}, typeName string) interface{} { + switch value := x.(type) { + case int: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case uint: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case int8: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case uint8: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case int16: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case uint16: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case int32: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + case "string": + return string(value) + } + case uint32: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case int64: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case uint64: + switch typeName { + case "int": + return int(value) + case "byte": + return byte(value) + case "int8": + return int8(value) + case "int16": + return int16(value) + case "int32": + return int32(value) + case "int64": + return int64(value) + case "uint": + return uint(value) + case "uint8": + return uint8(value) + case "uint16": + return uint16(value) + case "uint32": + return uint32(value) + case "uint64": + return uint64(value) + case "rune": + return rune(value) + } + case string: + switch typeName { + case "string": + return value + } + } + return nil +} + +// EvaluateUnary evaluate the type and value of a unary expression +func EvaluateUnary(x interface{}, operator token.Token, xtype ast.Expr) (interface{}, ast.Expr) { + switch operator { + case token.SUB: + switch value := x.(type) { + case int: + return -value, xtype + case int8: + return -value, xtype + case int16: + return -value, xtype + case int32: + return -value, xtype + case int64: + return -value, xtype + } + case token.XOR: + switch value := x.(type) { + case int: + return ^value, xtype + case int8: + return ^value, xtype + case int16: + return ^value, xtype + case int32: + return ^value, xtype + case int64: + return ^value, xtype + case uint: + return ^value, xtype + case uint8: + return ^value, xtype + case uint16: + return ^value, xtype + case uint32: + return ^value, xtype + case uint64: + return ^value, xtype + } + } + return nil, nil +} + +// EvaluateBinary evaluate the type and value of a binary expression +func EvaluateBinary(x, y interface{}, operator token.Token, xtype, ytype ast.Expr) (interface{}, ast.Expr) { + if operator == token.SHR || operator == token.SHL { + var rightOperand uint64 + yValue := reflect.ValueOf(y) + if yValue.CanUint() { + rightOperand = yValue.Uint() + } else if yValue.CanInt() { + rightOperand = uint64(yValue.Int()) + } + + switch operator { + case token.SHL: + switch xValue := x.(type) { + case int: + return xValue << rightOperand, xtype + case int8: + return xValue << rightOperand, xtype + case int16: + return xValue << rightOperand, xtype + case int32: + return xValue << rightOperand, xtype + case int64: + return xValue << rightOperand, xtype + case uint: + return xValue << rightOperand, xtype + case uint8: + return xValue << rightOperand, xtype + case uint16: + return xValue << rightOperand, xtype + case uint32: + return xValue << rightOperand, xtype + case uint64: + return xValue << rightOperand, xtype + } + case token.SHR: + switch xValue := x.(type) { + case int: + return xValue >> rightOperand, xtype + case int8: + return xValue >> rightOperand, xtype + case int16: + return xValue >> rightOperand, xtype + case int32: + return xValue >> rightOperand, xtype + case int64: + return xValue >> rightOperand, xtype + case uint: + return xValue >> rightOperand, xtype + case uint8: + return xValue >> rightOperand, xtype + case uint16: + return xValue >> rightOperand, xtype + case uint32: + return xValue >> rightOperand, xtype + case uint64: + return xValue >> rightOperand, xtype + } + } + return nil, nil + } + + evalType := xtype + if evalType == nil { + evalType = ytype + } + + xValue := reflect.ValueOf(x) + yValue := reflect.ValueOf(y) + if xValue.Kind() == reflect.String && yValue.Kind() == reflect.String { + return xValue.String() + yValue.String(), evalType + } + + var targetValue reflect.Value + if xValue.Kind() != reflect.Int { + targetValue = reflect.New(xValue.Type()).Elem() + } else { + targetValue = reflect.New(yValue.Type()).Elem() + } + + switch operator { + case token.ADD: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() + yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() + yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) + yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() + uint64(yValue.Int())) + } + case token.SUB: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() - yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() - yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) - yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() - uint64(yValue.Int())) + } + case token.MUL: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() * yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() * yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) * yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() * uint64(yValue.Int())) + } + case token.QUO: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() / yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() / yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) / yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() / uint64(yValue.Int())) + } + case token.REM: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() % yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() % yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) % yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() % uint64(yValue.Int())) + } + case token.AND: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() & yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() & yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) & yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() & uint64(yValue.Int())) + } + case token.OR: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() | yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() | yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) | yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() | uint64(yValue.Int())) + } + case token.XOR: + if xValue.CanInt() && yValue.CanInt() { + targetValue.SetInt(xValue.Int() ^ yValue.Int()) + } else if xValue.CanUint() && yValue.CanUint() { + targetValue.SetUint(xValue.Uint() ^ yValue.Uint()) + } else if xValue.CanInt() && yValue.CanUint() { + targetValue.SetUint(uint64(xValue.Int()) ^ yValue.Uint()) + } else if xValue.CanUint() && yValue.CanInt() { + targetValue.SetUint(xValue.Uint() ^ uint64(yValue.Int())) + } + } + return targetValue.Interface(), evalType +} diff --git a/pkg/swag/doc.go b/pkg/swag/doc.go new file mode 100644 index 0000000..aa82f7b --- /dev/null +++ b/pkg/swag/doc.go @@ -0,0 +1,5 @@ +/* +Package swag converts Go annotations to Swagger Documentation 2.0. +See https://git.ipao.vip/rogeecn/atomctl/pkg/swag for more information about swag. +*/ +package swag // import "git.ipao.vip/rogeecn/atomctl/pkg/swag" diff --git a/pkg/swag/enums.go b/pkg/swag/enums.go new file mode 100644 index 0000000..300787b --- /dev/null +++ b/pkg/swag/enums.go @@ -0,0 +1,14 @@ +package swag + +const ( + enumVarNamesExtension = "x-enum-varnames" + enumCommentsExtension = "x-enum-comments" + enumDescriptionsExtension = "x-enum-descriptions" +) + +// EnumValue a model to record an enum consts variable +type EnumValue struct { + key string + Value interface{} + Comment string +} diff --git a/pkg/swag/enums_test.go b/pkg/swag/enums_test.go new file mode 100644 index 0000000..d6a258e --- /dev/null +++ b/pkg/swag/enums_test.go @@ -0,0 +1,34 @@ +package swag + +import ( + "encoding/json" + "math/bits" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseGlobalEnums(t *testing.T) { + searchDir := "testdata/enums" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) + constsPath := "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/enums/consts" + assert.Equal(t, bits.UintSize, p.packages.packages[constsPath].ConstTable["uintSize"].Value) + assert.Equal(t, int32(62), p.packages.packages[constsPath].ConstTable["maxBase"].Value) + assert.Equal(t, 8, p.packages.packages[constsPath].ConstTable["shlByLen"].Value) + assert.Equal(t, 255, p.packages.packages[constsPath].ConstTable["hexnum"].Value) + assert.Equal(t, 15, p.packages.packages[constsPath].ConstTable["octnum"].Value) + assert.Equal(t, `aa\nbb\u8888cc`, p.packages.packages[constsPath].ConstTable["nonescapestr"].Value) + assert.Equal(t, "aa\nbb\u8888cc", p.packages.packages[constsPath].ConstTable["escapestr"].Value) + assert.Equal(t, 1_000_000, p.packages.packages[constsPath].ConstTable["underscored"].Value) + assert.Equal(t, 0b10001000, p.packages.packages[constsPath].ConstTable["binaryInteger"].Value) +} diff --git a/pkg/swag/field_parser.go b/pkg/swag/field_parser.go new file mode 100644 index 0000000..de8945d --- /dev/null +++ b/pkg/swag/field_parser.go @@ -0,0 +1,686 @@ +package swag + +import ( + "fmt" + "go/ast" + "reflect" + "regexp" + "strconv" + "strings" + "sync" + "unicode" + + "github.com/go-openapi/spec" +) + +var _ FieldParser = &tagBaseFieldParser{p: nil, field: nil, tag: ""} + +const ( + requiredLabel = "required" + optionalLabel = "optional" + swaggerTypeTag = "swaggertype" + swaggerIgnoreTag = "swaggerignore" +) + +type tagBaseFieldParser struct { + p *Parser + field *ast.Field + tag reflect.StructTag +} + +func newTagBaseFieldParser(p *Parser, field *ast.Field) FieldParser { + fieldParser := tagBaseFieldParser{ + p: p, + field: field, + tag: "", + } + if fieldParser.field.Tag != nil { + fieldParser.tag = reflect.StructTag(strings.ReplaceAll(field.Tag.Value, "`", "")) + } + + return &fieldParser +} + +func (ps *tagBaseFieldParser) ShouldSkip() bool { + // Skip non-exported fields. + if ps.field.Names != nil && !ast.IsExported(ps.field.Names[0].Name) { + return true + } + + if ps.field.Tag == nil { + return false + } + + ignoreTag := ps.tag.Get(swaggerIgnoreTag) + if strings.EqualFold(ignoreTag, "true") { + return true + } + + // json:"tag,hoge" + name := strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0]) + if name == "-" { + return true + } + + return false +} + +func (ps *tagBaseFieldParser) FieldNames() ([]string, error) { + if len(ps.field.Names) <= 1 { + // if embedded but with a json/form name ?? + if ps.field.Tag != nil { + // json:"tag,hoge" + name := strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0]) + if name != "" { + return []string{name}, nil + } + + // use "form" tag over json tag + name = ps.FormName() + if name != "" { + return []string{name}, nil + } + } + if len(ps.field.Names) == 0 { + return nil, nil + } + } + var names = make([]string, 0, len(ps.field.Names)) + for _, name := range ps.field.Names { + switch ps.p.PropNamingStrategy { + case SnakeCase: + names = append(names, toSnakeCase(name.Name)) + case PascalCase: + names = append(names, name.Name) + default: + names = append(names, toLowerCamelCase(name.Name)) + } + } + return names, nil +} + +func (ps *tagBaseFieldParser) firstTagValue(tag string) string { + if ps.field.Tag != nil { + return strings.TrimRight(strings.TrimSpace(strings.Split(ps.tag.Get(tag), ",")[0]), "[]") + } + return "" +} + +func (ps *tagBaseFieldParser) FormName() string { + return ps.firstTagValue(formTag) +} + +func (ps *tagBaseFieldParser) HeaderName() string { + return ps.firstTagValue(headerTag) +} + +func (ps *tagBaseFieldParser) PathName() string { + return ps.firstTagValue(uriTag) +} + +func toSnakeCase(in string) string { + var ( + runes = []rune(in) + length = len(runes) + out []rune + ) + + for idx := 0; idx < length; idx++ { + if idx > 0 && unicode.IsUpper(runes[idx]) && + ((idx+1 < length && unicode.IsLower(runes[idx+1])) || unicode.IsLower(runes[idx-1])) { + out = append(out, '_') + } + + out = append(out, unicode.ToLower(runes[idx])) + } + + return string(out) +} + +func toLowerCamelCase(in string) string { + var flag bool + + out := make([]rune, len(in)) + + runes := []rune(in) + for i, curr := range runes { + if (i == 0 && unicode.IsUpper(curr)) || (flag && unicode.IsUpper(curr)) { + out[i] = unicode.ToLower(curr) + flag = true + + continue + } + + out[i] = curr + flag = false + } + + return string(out) +} + +func (ps *tagBaseFieldParser) CustomSchema() (*spec.Schema, error) { + if ps.field.Tag == nil { + return nil, nil + } + + typeTag := ps.tag.Get(swaggerTypeTag) + if typeTag != "" { + return BuildCustomSchema(strings.Split(typeTag, ",")) + } + + return nil, nil +} + +type structField struct { + title string + schemaType string + arrayType string + formatType string + maximum *float64 + minimum *float64 + multipleOf *float64 + maxLength *int64 + minLength *int64 + maxItems *int64 + minItems *int64 + exampleValue interface{} + enums []interface{} + enumVarNames []interface{} + unique bool +} + +// splitNotWrapped slices s into all substrings separated by sep if sep is not +// wrapped by brackets and returns a slice of the substrings between those separators. +func splitNotWrapped(s string, sep rune) []string { + openCloseMap := map[rune]rune{ + '(': ')', + '[': ']', + '{': '}', + } + + var ( + result = make([]string, 0) + current = strings.Builder{} + openCount = 0 + openChar rune + ) + + for _, char := range s { + switch { + case openChar == 0 && openCloseMap[char] != 0: + openChar = char + + openCount++ + + current.WriteRune(char) + case char == openChar: + openCount++ + + current.WriteRune(char) + case openCount > 0 && char == openCloseMap[openChar]: + openCount-- + + current.WriteRune(char) + case openCount == 0 && char == sep: + result = append(result, current.String()) + + openChar = 0 + + current = strings.Builder{} + default: + current.WriteRune(char) + } + } + + if current.String() != "" { + result = append(result, current.String()) + } + + return result +} + +// ComplementSchema complement schema with field properties +func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error { + types := ps.p.GetSchemaTypePath(schema, 2) + if len(types) == 0 { + return fmt.Errorf("invalid type for field: %s", ps.field.Names[0]) + } + + if IsRefSchema(schema) { + var newSchema = spec.Schema{} + err := ps.complementSchema(&newSchema, types) + if err != nil { + return err + } + if !reflect.ValueOf(newSchema).IsZero() { + *schema = *(newSchema.WithAllOf(*schema)) + } + return nil + } + + return ps.complementSchema(schema, types) +} + +// complementSchema complement schema with field properties +func (ps *tagBaseFieldParser) complementSchema(schema *spec.Schema, types []string) error { + if ps.field.Tag == nil { + if ps.field.Doc != nil { + schema.Description = strings.TrimSpace(ps.field.Doc.Text()) + } + + if schema.Description == "" && ps.field.Comment != nil { + schema.Description = strings.TrimSpace(ps.field.Comment.Text()) + } + + return nil + } + + field := &structField{ + schemaType: types[0], + formatType: ps.tag.Get(formatTag), + title: ps.tag.Get(titleTag), + } + + if len(types) > 1 && (types[0] == ARRAY || types[0] == OBJECT) { + field.arrayType = types[1] + } + + jsonTagValue := ps.tag.Get(jsonTag) + + bindingTagValue := ps.tag.Get(bindingTag) + if bindingTagValue != "" { + parseValidTags(bindingTagValue, field) + } + + validateTagValue := ps.tag.Get(validateTag) + if validateTagValue != "" { + parseValidTags(validateTagValue, field) + } + + enumsTagValue := ps.tag.Get(enumsTag) + if enumsTagValue != "" { + err := parseEnumTags(enumsTagValue, field) + if err != nil { + return err + } + } + + if IsNumericType(field.schemaType) || IsNumericType(field.arrayType) { + maximum, err := getFloatTag(ps.tag, maximumTag) + if err != nil { + return err + } + + if maximum != nil { + field.maximum = maximum + } + + minimum, err := getFloatTag(ps.tag, minimumTag) + if err != nil { + return err + } + + if minimum != nil { + field.minimum = minimum + } + + multipleOf, err := getFloatTag(ps.tag, multipleOfTag) + if err != nil { + return err + } + + if multipleOf != nil { + field.multipleOf = multipleOf + } + } + + if field.schemaType == STRING || field.arrayType == STRING { + maxLength, err := getIntTag(ps.tag, maxLengthTag) + if err != nil { + return err + } + + if maxLength != nil { + field.maxLength = maxLength + } + + minLength, err := getIntTag(ps.tag, minLengthTag) + if err != nil { + return err + } + + if minLength != nil { + field.minLength = minLength + } + } + + // json:"name,string" or json:",string" + exampleTagValue, ok := ps.tag.Lookup(exampleTag) + if ok { + field.exampleValue = exampleTagValue + + if !strings.Contains(jsonTagValue, ",string") { + example, err := defineTypeOfExample(field.schemaType, field.arrayType, exampleTagValue) + if err != nil { + return err + } + + field.exampleValue = example + } + } + + // perform this after setting everything else (min, max, etc...) + if strings.Contains(jsonTagValue, ",string") { + // @encoding/json: "It applies only to fields of string, floating point, integer, or boolean types." + defaultValues := map[string]string{ + // Zero Values as string + STRING: "", + INTEGER: "0", + BOOLEAN: "false", + NUMBER: "0", + } + + defaultValue, ok := defaultValues[field.schemaType] + if ok { + field.schemaType = STRING + *schema = *PrimitiveSchema(field.schemaType) + + if field.exampleValue == nil { + // if exampleValue is not defined by the user, + // we will force an example with a correct value + // (eg: int->"0", bool:"false") + field.exampleValue = defaultValue + } + } + } + + if ps.field.Doc != nil { + schema.Description = strings.TrimSpace(ps.field.Doc.Text()) + } + + if schema.Description == "" && ps.field.Comment != nil { + schema.Description = strings.TrimSpace(ps.field.Comment.Text()) + } + + schema.ReadOnly = ps.tag.Get(readOnlyTag) == "true" + + defaultTagValue := ps.tag.Get(defaultTag) + if defaultTagValue != "" { + value, err := defineType(field.schemaType, defaultTagValue) + if err != nil { + return err + } + + schema.Default = value + } + + schema.Example = field.exampleValue + + if field.schemaType != ARRAY { + schema.Format = field.formatType + } + schema.Title = field.title + + extensionsTagValue := ps.tag.Get(extensionsTag) + if extensionsTagValue != "" { + schema.Extensions = setExtensionParam(extensionsTagValue) + } + + varNamesTag := ps.tag.Get("x-enum-varnames") + if varNamesTag != "" { + varNames := strings.Split(varNamesTag, ",") + if len(varNames) != len(field.enums) { + return fmt.Errorf("invalid count of x-enum-varnames. expected %d, got %d", len(field.enums), len(varNames)) + } + + field.enumVarNames = nil + + for _, v := range varNames { + field.enumVarNames = append(field.enumVarNames, v) + } + + if field.schemaType == ARRAY { + // Add the var names in the items schema + if schema.Items.Schema.Extensions == nil { + schema.Items.Schema.Extensions = map[string]interface{}{} + } + schema.Items.Schema.Extensions[enumVarNamesExtension] = field.enumVarNames + } else { + // Add to top level schema + if schema.Extensions == nil { + schema.Extensions = map[string]interface{}{} + } + schema.Extensions[enumVarNamesExtension] = field.enumVarNames + } + } + + eleSchema := schema + + if field.schemaType == ARRAY { + // For Array only + schema.MaxItems = field.maxItems + schema.MinItems = field.minItems + schema.UniqueItems = field.unique + + eleSchema = schema.Items.Schema + eleSchema.Format = field.formatType + } + + eleSchema.Maximum = field.maximum + eleSchema.Minimum = field.minimum + eleSchema.MultipleOf = field.multipleOf + eleSchema.MaxLength = field.maxLength + eleSchema.MinLength = field.minLength + eleSchema.Enum = field.enums + + return nil +} + +func getFloatTag(structTag reflect.StructTag, tagName string) (*float64, error) { + strValue := structTag.Get(tagName) + if strValue == "" { + return nil, nil + } + + value, err := strconv.ParseFloat(strValue, 64) + if err != nil { + return nil, fmt.Errorf("can't parse numeric value of %q tag: %v", tagName, err) + } + + return &value, nil +} + +func getIntTag(structTag reflect.StructTag, tagName string) (*int64, error) { + strValue := structTag.Get(tagName) + if strValue == "" { + return nil, nil + } + + value, err := strconv.ParseInt(strValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("can't parse numeric value of %q tag: %v", tagName, err) + } + + return &value, nil +} + +func (ps *tagBaseFieldParser) IsRequired() (bool, error) { + if ps.field.Tag == nil { + return false, nil + } + + bindingTag := ps.tag.Get(bindingTag) + if bindingTag != "" { + for _, val := range strings.Split(bindingTag, ",") { + switch val { + case requiredLabel: + return true, nil + case optionalLabel: + return false, nil + } + } + } + + validateTag := ps.tag.Get(validateTag) + if validateTag != "" { + for _, val := range strings.Split(validateTag, ",") { + switch val { + case requiredLabel: + return true, nil + case optionalLabel: + return false, nil + } + } + } + + return ps.p.RequiredByDefault, nil +} + +func parseValidTags(validTag string, sf *structField) { + // `validate:"required,max=10,min=1"` + // ps. required checked by IsRequired(). + for _, val := range strings.Split(validTag, ",") { + var ( + valValue string + keyVal = strings.Split(val, "=") + ) + + switch len(keyVal) { + case 1: + case 2: + valValue = strings.ReplaceAll(strings.ReplaceAll(keyVal[1], utf8HexComma, ","), utf8Pipe, "|") + default: + continue + } + + switch keyVal[0] { + case "max", "lte": + sf.setMax(valValue) + case "min", "gte": + sf.setMin(valValue) + case "oneof": + sf.setOneOf(valValue) + case "unique": + if sf.schemaType == ARRAY { + sf.unique = true + } + case "dive": + // ignore dive + return + default: + continue + } + } +} + +func parseEnumTags(enumTag string, field *structField) error { + enumType := field.schemaType + if field.schemaType == ARRAY { + enumType = field.arrayType + } + + field.enums = nil + + for _, e := range strings.Split(enumTag, ",") { + value, err := defineType(enumType, e) + if err != nil { + return err + } + + field.enums = append(field.enums, value) + } + + return nil +} + +func (sf *structField) setOneOf(valValue string) { + if len(sf.enums) != 0 { + return + } + + enumType := sf.schemaType + if sf.schemaType == ARRAY { + enumType = sf.arrayType + } + + valValues := parseOneOfParam2(valValue) + for i := range valValues { + value, err := defineType(enumType, valValues[i]) + if err != nil { + continue + } + + sf.enums = append(sf.enums, value) + } +} + +func (sf *structField) setMin(valValue string) { + value, err := strconv.ParseFloat(valValue, 64) + if err != nil { + return + } + + switch sf.schemaType { + case INTEGER, NUMBER: + sf.minimum = &value + case STRING: + intValue := int64(value) + sf.minLength = &intValue + case ARRAY: + intValue := int64(value) + sf.minItems = &intValue + } +} + +func (sf *structField) setMax(valValue string) { + value, err := strconv.ParseFloat(valValue, 64) + if err != nil { + return + } + + switch sf.schemaType { + case INTEGER, NUMBER: + sf.maximum = &value + case STRING: + intValue := int64(value) + sf.maxLength = &intValue + case ARRAY: + intValue := int64(value) + sf.maxItems = &intValue + } +} + +const ( + utf8HexComma = "0x2C" + utf8Pipe = "0x7C" +) + +// These code copy from +// https://github.com/go-playground/validator/blob/d4271985b44b735c6f76abc7a06532ee997f9476/baked_in.go#L207 +// ---. +var oneofValsCache = map[string][]string{} +var oneofValsCacheRWLock = sync.RWMutex{} +var splitParamsRegex = regexp.MustCompile(`'[^']*'|\S+`) + +func parseOneOfParam2(param string) []string { + oneofValsCacheRWLock.RLock() + values, ok := oneofValsCache[param] + oneofValsCacheRWLock.RUnlock() + + if !ok { + oneofValsCacheRWLock.Lock() + values = splitParamsRegex.FindAllString(param, -1) + + for i := 0; i < len(values); i++ { + values[i] = strings.ReplaceAll(values[i], "'", "") + } + + oneofValsCache[param] = values + + oneofValsCacheRWLock.Unlock() + } + + return values +} + +// ---. diff --git a/pkg/swag/field_parser_test.go b/pkg/swag/field_parser_test.go new file mode 100644 index 0000000..28f8cdf --- /dev/null +++ b/pkg/swag/field_parser_test.go @@ -0,0 +1,729 @@ +package swag + +import ( + "go/ast" + "testing" + + "github.com/go-openapi/spec" + "github.com/stretchr/testify/assert" +) + +func TestDefaultFieldParser(t *testing.T) { + t.Run("Example tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" example:"one"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, "one", schema.Example) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" example:""`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, "", schema.Example) + + schema = spec.Schema{} + schema.Type = []string{"float"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" example:"one"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + }) + + t.Run("Format tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" format:"csv"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, "csv", schema.Format) + }) + + t.Run("Title tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" title:"myfield"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, "myfield", schema.Title) + }) + + t.Run("Required tag", func(t *testing.T) { + t.Parallel() + + got, err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" binding:"required"`, + }}, + ).IsRequired() + assert.NoError(t, err) + assert.Equal(t, true, got) + + got, err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required"`, + }}, + ).IsRequired() + assert.NoError(t, err) + assert.Equal(t, true, got) + }) + + t.Run("Default required tag", func(t *testing.T) { + t.Parallel() + + got, err := newTagBaseFieldParser( + &Parser{ + RequiredByDefault: true, + }, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test"`, + }}, + ).IsRequired() + assert.NoError(t, err) + assert.True(t, got) + }) + + t.Run("Optional tag", func(t *testing.T) { + t.Parallel() + + got, err := newTagBaseFieldParser( + &Parser{ + RequiredByDefault: true, + }, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" binding:"optional"`, + }}, + ).IsRequired() + assert.NoError(t, err) + assert.False(t, got) + + got, err = newTagBaseFieldParser( + &Parser{ + RequiredByDefault: true, + }, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"optional"`, + }}, + ).IsRequired() + assert.NoError(t, err) + assert.False(t, got) + }) + + t.Run("Extensions tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"int"} + schema.Extensions = map[string]interface{}{} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" extensions:"x-nullable,x-abc=def,!x-omitempty,x-example=[0, 9],x-example2={çãíœ, (bar=(abc, def)), [0,9]}"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, true, schema.Extensions["x-nullable"]) + assert.Equal(t, "def", schema.Extensions["x-abc"]) + assert.Equal(t, false, schema.Extensions["x-omitempty"]) + assert.Equal(t, "[0, 9]", schema.Extensions["x-example"]) + assert.Equal(t, "{çãíœ, (bar=(abc, def)), [0,9]}", schema.Extensions["x-example2"]) + }) + + t.Run("Enums tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" enums:"a,b,c"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"a", "b", "c"}, schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"float"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" enums:"a,b,c"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + }) + + t.Run("EnumVarNames tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"int"} + schema.Extensions = map[string]interface{}{} + schema.Enum = []interface{}{} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" enums:"0,1,2" x-enum-varnames:"Daily,Weekly,Monthly"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"Daily", "Weekly", "Monthly"}, schema.Extensions["x-enum-varnames"]) + + schema = spec.Schema{} + schema.Type = []string{"int"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" enums:"0,1,2,3" x-enum-varnames:"Daily,Weekly,Monthly"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + + // Test for an array of enums + schema = spec.Schema{} + schema.Type = []string{"array"} + schema.Items = &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"int"}, + }, + }, + } + schema.Extensions = map[string]interface{}{} + schema.Enum = []interface{}{} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" enums:"0,1,2" x-enum-varnames:"Daily,Weekly,Monthly"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"Daily", "Weekly", "Monthly"}, schema.Items.Schema.Extensions["x-enum-varnames"]) + assert.Equal(t, spec.Extensions{}, schema.Extensions) + }) + + t.Run("Default tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" default:"pass"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, "pass", schema.Default) + + schema = spec.Schema{} + schema.Type = []string{"float"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" default:"pass"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + }) + + t.Run("Numeric value", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"integer"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" maximum:"1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + max := float64(1) + assert.Equal(t, &max, schema.Maximum) + + schema = spec.Schema{} + schema.Type = []string{"integer"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" maximum:"one"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + + schema = spec.Schema{} + schema.Type = []string{"number"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" maximum:"1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + max = float64(1) + assert.Equal(t, &max, schema.Maximum) + + schema = spec.Schema{} + schema.Type = []string{"number"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" maximum:"one"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + + schema = spec.Schema{} + schema.Type = []string{"number"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" multipleOf:"1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + multipleOf := float64(1) + assert.Equal(t, &multipleOf, schema.MultipleOf) + + schema = spec.Schema{} + schema.Type = []string{"number"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" multipleOf:"one"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + + schema = spec.Schema{} + schema.Type = []string{"integer"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" minimum:"1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + min := float64(1) + assert.Equal(t, &min, schema.Minimum) + + schema = spec.Schema{} + schema.Type = []string{"integer"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" minimum:"one"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + }) + + t.Run("String value", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" maxLength:"1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + max := int64(1) + assert.Equal(t, &max, schema.MaxLength) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" maxLength:"one"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" minLength:"1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + min := int64(1) + assert.Equal(t, &min, schema.MinLength) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" minLength:"one"`, + }}, + ).ComplementSchema(&schema) + assert.Error(t, err) + }) + + t.Run("Readonly tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" readonly:"true"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, true, schema.ReadOnly) + }) + + t.Run("Invalid tag", func(t *testing.T) { + t.Parallel() + + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Names: []*ast.Ident{{Name: "BasicStruct"}}}, + ).ComplementSchema(nil) + assert.Error(t, err) + }) +} + +func TestValidTags(t *testing.T) { + t.Run("Required with max/min tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,max=10,min=1"`, + }}, + ).ComplementSchema(&schema) + max := int64(10) + min := int64(1) + assert.NoError(t, err) + assert.Equal(t, &max, schema.MaxLength) + assert.Equal(t, &min, schema.MinLength) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,max=10,gte=1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, &max, schema.MaxLength) + assert.Equal(t, &min, schema.MinLength) + + schema = spec.Schema{} + schema.Type = []string{"integer"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,max=10,min=1"`, + }}, + ).ComplementSchema(&schema) + maxFloat64 := float64(10) + minFloat64 := float64(1) + assert.NoError(t, err) + assert.Equal(t, &maxFloat64, schema.Maximum) + assert.Equal(t, &minFloat64, schema.Minimum) + + schema = spec.Schema{} + schema.Type = []string{"array"} + schema.Items = &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + }, + }, + } + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,max=10,min=1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, &max, schema.MaxItems) + assert.Equal(t, &min, schema.MinItems) + + // wrong validate tag will be ignored. + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,max=ten,min=1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Empty(t, schema.MaxItems) + assert.Equal(t, &min, schema.MinItems) + }) + t.Run("Required with oneof tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"string"} + + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,oneof='red book' 'green book'"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"red book", "green book"}, schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"integer"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,oneof=1 2 3"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{1, 2, 3}, schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"array"} + schema.Items = &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + }, + }, + } + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,oneof=red green yellow"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"red", "green", "yellow"}, schema.Items.Schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,oneof='red green' blue 'c0x2Cc' 'd0x7Cd'"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"red green", "blue", "c,c", "d|d"}, schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,oneof='c0x9Ab' book"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"c0x9Ab", "book"}, schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" binding:"oneof=foo bar" validate:"required,oneof=foo bar" enums:"a,b,c"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"a", "b", "c"}, schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"string"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" binding:"oneof=aa bb" validate:"required,oneof=foo bar"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, []interface{}{"aa", "bb"}, schema.Enum) + }) + t.Run("Required with unique tag", func(t *testing.T) { + t.Parallel() + + schema := spec.Schema{} + schema.Type = []string{"array"} + schema.Items = &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + }, + }, + } + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,unique"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, true, schema.UniqueItems) + }) + + t.Run("All tag", func(t *testing.T) { + t.Parallel() + schema := spec.Schema{} + schema.Type = []string{"array"} + schema.Items = &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + }, + }, + } + err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,unique,max=10,min=1,oneof=a0x2Cc 'c0x7Cd book',omitempty,dive,max=1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, true, schema.UniqueItems) + + max := int64(10) + min := int64(1) + assert.Equal(t, &max, schema.MaxItems) + assert.Equal(t, &min, schema.MinItems) + assert.Equal(t, []interface{}{"a,c", "c|d book"}, schema.Items.Schema.Enum) + + schema = spec.Schema{} + schema.Type = []string{"array"} + schema.Items = &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + }, + }, + } + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,oneof=,max=10=90,min=1"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Empty(t, schema.UniqueItems) + assert.Empty(t, schema.MaxItems) + assert.Equal(t, &min, schema.MinItems) + + schema = spec.Schema{} + schema.Type = []string{"array"} + schema.Items = &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + }, + }, + } + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,max=10,min=one"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Equal(t, &max, schema.MaxItems) + assert.Empty(t, schema.MinItems) + + schema = spec.Schema{} + schema.Type = []string{"integer"} + err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{ + Names: []*ast.Ident{{Name: "Test"}}, + Tag: &ast.BasicLit{ + Value: `json:"test" validate:"required,oneof=one two"`, + }}, + ).ComplementSchema(&schema) + assert.NoError(t, err) + assert.Empty(t, schema.Enum) + }) + + t.Run("Form Filed Name", func(t *testing.T) { + t.Parallel() + + filednames, err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{ + Names: []*ast.Ident{{Name: "Test"}}, + Tag: &ast.BasicLit{ + Value: `form:"test[]"`, + }}, + ).FieldNames() + assert.NoError(t, err) + assert.Equal(t, "test", filednames[0]) + + filednames, err = newTagBaseFieldParser( + &Parser{}, + &ast.Field{ + Names: []*ast.Ident{{Name: "Test"}}, + Tag: &ast.BasicLit{ + Value: `form:"test"`, + }}, + ).FieldNames() + assert.NoError(t, err) + assert.Equal(t, "test", filednames[0]) + }) + + t.Run("Two Names", func(t *testing.T) { + t.Parallel() + + fieldnames, err := newTagBaseFieldParser( + &Parser{}, + &ast.Field{ + Names: []*ast.Ident{{Name: "X"}, {Name: "Y"}}, + }, + ).FieldNames() + assert.NoError(t, err) + assert.Equal(t, 2, len(fieldnames)) + assert.Equal(t, "x", fieldnames[0]) + assert.Equal(t, "y", fieldnames[1]) + }) +} diff --git a/pkg/swag/format/format.go b/pkg/swag/format/format.go new file mode 100644 index 0000000..63cf9c1 --- /dev/null +++ b/pkg/swag/format/format.go @@ -0,0 +1,150 @@ +package format + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "git.ipao.vip/rogeecn/atomctl/pkg/swag" +) + +// Format implements `fmt` command for formatting swag comments in Go source +// files. +type Format struct { + formatter *swag.Formatter + + // exclude exclude dirs and files in SearchDir + exclude map[string]bool +} + +// New creates a new Format instance +func New() *Format { + return &Format{ + exclude: map[string]bool{}, + formatter: swag.NewFormatter(), + } +} + +// Config specifies configuration for a format run +type Config struct { + // SearchDir the swag would be parse + SearchDir string + + // excludes dirs and files in SearchDir,comma separated + Excludes string + + // MainFile (DEPRECATED) + MainFile string +} + +var defaultExcludes = []string{"docs", "vendor"} + +// Build runs formatter according to configuration in config +func (f *Format) Build(config *Config) error { + searchDirs := strings.Split(config.SearchDir, ",") + for _, searchDir := range searchDirs { + if _, err := os.Stat(searchDir); os.IsNotExist(err) { + return fmt.Errorf("fmt: %w", err) + } + for _, d := range defaultExcludes { + f.exclude[filepath.Join(searchDir, d)] = true + } + } + for _, fi := range strings.Split(config.Excludes, ",") { + if fi = strings.TrimSpace(fi); fi != "" { + f.exclude[filepath.Clean(fi)] = true + } + } + for _, searchDir := range searchDirs { + err := filepath.Walk(searchDir, f.visit) + if err != nil { + return err + } + } + return nil +} + +func (f *Format) visit(path string, fileInfo os.FileInfo, err error) error { + if fileInfo.IsDir() && f.excludeDir(path) { + return filepath.SkipDir + } + if f.excludeFile(path) { + return nil + } + if err := f.format(path); err != nil { + return fmt.Errorf("fmt: %w", err) + } + return nil +} + +func (f *Format) excludeDir(path string) bool { + return f.exclude[path] || + filepath.Base(path)[0] == '.' && + len(filepath.Base(path)) > 1 // exclude hidden folders +} + +func (f *Format) excludeFile(path string) bool { + return f.exclude[path] || + strings.HasSuffix(strings.ToLower(path), "_test.go") || + filepath.Ext(path) != ".go" +} + +func (f *Format) format(path string) error { + original, err := os.ReadFile(path) + if err != nil { + return err + } + contents := make([]byte, len(original)) + copy(contents, original) + formatted, err := f.formatter.Format(path, contents) + if err != nil { + return err + } + if bytes.Equal(original, formatted) { + // Skip write if no change + return nil + } + return write(path, formatted) +} + +func write(path string, contents []byte) error { + originalFileInfo, err := os.Stat(path) + if err != nil { + return err + } + f, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)) + if err != nil { + return err + } + defer os.Remove(f.Name()) + if _, err := f.Write(contents); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + if err := os.Chmod(f.Name(), originalFileInfo.Mode()); err != nil { + return err + } + return os.Rename(f.Name(), path) +} + +// Run the format on src and write the result to dst. +func (f *Format) Run(src io.Reader, dst io.Writer) error { + contents, err := io.ReadAll(src) + if err != nil { + return err + } + result, err := f.formatter.Format("", contents) + if err != nil { + return err + } + r := bytes.NewReader(result) + if _, err := io.Copy(dst, r); err != nil { + return err + } + return nil +} diff --git a/pkg/swag/format/format_test.go b/pkg/swag/format/format_test.go new file mode 100644 index 0000000..94b102f --- /dev/null +++ b/pkg/swag/format/format_test.go @@ -0,0 +1,151 @@ +package format + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormat_Format(t *testing.T) { + fx := setup(t) + assert.NoError(t, New().Build(&Config{SearchDir: fx.basedir})) + assert.True(t, fx.isFormatted("main.go")) + assert.True(t, fx.isFormatted("api/api.go")) +} + +func TestFormat_PermissionsPreserved(t *testing.T) { + fx := setup(t) + + originalFileInfo, err := os.Stat(filepath.Join(fx.basedir, "main.go")) + if err != nil { + t.Fatal(err) + } + + assert.NoError(t, New().Build(&Config{SearchDir: fx.basedir})) + assert.True(t, permissionsEqual(t, filepath.Join(fx.basedir, "main.go"), originalFileInfo.Mode())) + assert.True(t, permissionsEqual(t, filepath.Join(fx.basedir, "api/api.go"), originalFileInfo.Mode())) +} + +func TestFormat_ExcludeDir(t *testing.T) { + fx := setup(t) + assert.NoError(t, New().Build(&Config{ + SearchDir: fx.basedir, + Excludes: filepath.Join(fx.basedir, "api"), + })) + assert.False(t, fx.isFormatted("api/api.go")) +} + +func TestFormat_ExcludeFile(t *testing.T) { + fx := setup(t) + assert.NoError(t, New().Build(&Config{ + SearchDir: fx.basedir, + Excludes: filepath.Join(fx.basedir, "main.go"), + })) + assert.False(t, fx.isFormatted("main.go")) +} + +func TestFormat_DefaultExcludes(t *testing.T) { + fx := setup(t) + assert.NoError(t, New().Build(&Config{SearchDir: fx.basedir})) + assert.False(t, fx.isFormatted("api/api_test.go")) + assert.False(t, fx.isFormatted("docs/docs.go")) +} + +func TestFormat_ParseError(t *testing.T) { + fx := setup(t) + os.WriteFile(filepath.Join(fx.basedir, "parse_error.go"), []byte(`package main + func invalid() {`), 0644) + assert.Error(t, New().Build(&Config{SearchDir: fx.basedir})) +} + +func TestFormat_ReadError(t *testing.T) { + fx := setup(t) + os.Chmod(filepath.Join(fx.basedir, "main.go"), 0) + assert.Error(t, New().Build(&Config{SearchDir: fx.basedir})) +} + +func TestFormat_WriteError(t *testing.T) { + fx := setup(t) + os.Chmod(fx.basedir, 0555) + assert.Error(t, New().Build(&Config{SearchDir: fx.basedir})) + os.Chmod(fx.basedir, 0755) +} + +func TestFormat_InvalidSearchDir(t *testing.T) { + formatter := New() + assert.Error(t, formatter.Build(&Config{SearchDir: "no_such_dir"})) +} + +type fixture struct { + t *testing.T + basedir string +} + +func setup(t *testing.T) *fixture { + fx := &fixture{ + t: t, + basedir: t.TempDir(), + } + for filename, contents := range testFiles { + fullpath := filepath.Join(fx.basedir, filepath.Clean(filename)) + if err := os.MkdirAll(filepath.Dir(fullpath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(fullpath, contents, 0644); err != nil { + t.Fatal(err) + } + } + return fx +} + +func (fx *fixture) isFormatted(file string) bool { + contents, err := os.ReadFile(filepath.Join(fx.basedir, file)) + if err != nil { + fx.t.Fatal(err) + } + return !bytes.Equal(testFiles[file], contents) +} + +func permissionsEqual(t *testing.T, path string, expectedMode os.FileMode) bool { + fileInfo, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + return expectedMode == fileInfo.Mode() +} + +var testFiles = map[string][]byte{ + "api/api.go": []byte(`package api + + import "net/http" + + // @Summary Add a new pet to the store + // @Description get string by ID + func GetStringByInt(w http.ResponseWriter, r *http.Request) { + //write your code + }`), + "api/api_test.go": []byte(`package api + // @Summary API Test + // @Description Should not be formatted + func TestApi(t *testing.T) {}`), + "docs/docs.go": []byte(`package docs + // @Summary Documentation package + // @Description Should not be formatted`), + "main.go": []byte(`package main + + import ( + "net/http" + + "git.ipao.vip/rogeecn/atomctl/pkg/swag/format/testdata/api" + ) + + // @title Swagger Example API + // @version 1.0 + func main() { + http.HandleFunc("/testapi/get-string-by-int/", api.GetStringByInt) + }`), + "README.md": []byte(`# Format test`), +} diff --git a/pkg/swag/formatter.go b/pkg/swag/formatter.go new file mode 100644 index 0000000..1074a3b --- /dev/null +++ b/pkg/swag/formatter.go @@ -0,0 +1,187 @@ +package swag + +import ( + "bytes" + "fmt" + "go/ast" + goparser "go/parser" + "go/token" + "log" + "os" + "regexp" + "sort" + "strings" + "text/tabwriter" + + "golang.org/x/tools/imports" +) + +// Check of @Param @Success @Failure @Response @Header +var specialTagForSplit = map[string]bool{ + paramAttr: true, + successAttr: true, + failureAttr: true, + responseAttr: true, + headerAttr: true, +} + +var skipChar = map[byte]byte{ + '"': '"', + '(': ')', + '{': '}', + '[': ']', +} + +// Formatter implements a formatter for Go source files. +type Formatter struct { + // debugging output goes here + debug Debugger +} + +// NewFormatter create a new formatter instance. +func NewFormatter() *Formatter { + formatter := &Formatter{ + debug: log.New(os.Stdout, "", log.LstdFlags), + } + return formatter +} + +// Format formats swag comments in contents. It uses fileName to report errors +// that happen during parsing of contents. +func (f *Formatter) Format(fileName string, contents []byte) ([]byte, error) { + fileSet := token.NewFileSet() + ast, err := goparser.ParseFile(fileSet, fileName, contents, goparser.ParseComments) + if err != nil { + return nil, err + } + + // Formatting changes are described as an edit list of byte range + // replacements. We make these content-level edits directly rather than + // changing the AST nodes and writing those out (via [go/printer] or + // [go/format]) so that we only change the formatting of Swag attribute + // comments. This won't touch the formatting of any other comments, or of + // functions, etc. + maxEdits := 0 + for _, comment := range ast.Comments { + maxEdits += len(comment.List) + } + edits := make(edits, 0, maxEdits) + + for _, comment := range ast.Comments { + formatFuncDoc(fileSet, comment.List, &edits) + } + formatted, err := imports.Process(fileName, edits.apply(contents), nil) + if err != nil { + return nil, err + } + return formatted, nil +} + +type edit struct { + begin int + end int + replacement []byte +} + +type edits []edit + +func (edits edits) apply(contents []byte) []byte { + // Apply the edits with the highest offset first, so that earlier edits + // don't affect the offsets of later edits. + sort.Slice(edits, func(i, j int) bool { + return edits[i].begin > edits[j].begin + }) + + for _, edit := range edits { + prefix := contents[:edit.begin] + suffix := contents[edit.end:] + contents = append(prefix, append(edit.replacement, suffix...)...) + } + + return contents +} + +// formatFuncDoc reformats the comment lines in commentList, and appends any +// changes to the edit list. +func formatFuncDoc(fileSet *token.FileSet, commentList []*ast.Comment, edits *edits) { + // Building the edit list to format a comment block is a two-step process. + // First, we iterate over each comment line looking for Swag attributes. In + // each one we find, we replace alignment whitespace with a tab character, + // then write the result into a tab writer. + + linesToComments := make(map[int]int, len(commentList)) + + buffer := &bytes.Buffer{} + w := tabwriter.NewWriter(buffer, 1, 4, 1, '\t', 0) + + for commentIndex, comment := range commentList { + text := comment.Text + if attr, body, found := swagComment(text); found { + formatted := "//\t" + attr + if body != "" { + formatted += "\t" + splitComment2(attr, body) + } + _, _ = fmt.Fprintln(w, formatted) + linesToComments[len(linesToComments)] = commentIndex + } + } + + // Once we've loaded all of the comment lines to be aligned into the tab + // writer, flushing it causes the aligned text to be written out to the + // backing buffer. + _ = w.Flush() + + // Now the second step: we iterate over the aligned comment lines that were + // written into the backing buffer, pair each one up to its original + // comment line, and use the combination to describe the edit that needs to + // be made to the original input. + formattedComments := bytes.Split(buffer.Bytes(), []byte("\n")) + for lineIndex, commentIndex := range linesToComments { + comment := commentList[commentIndex] + *edits = append(*edits, edit{ + begin: fileSet.Position(comment.Pos()).Offset, + end: fileSet.Position(comment.End()).Offset, + replacement: formattedComments[lineIndex], + }) + } +} + +func splitComment2(attr, body string) string { + if specialTagForSplit[strings.ToLower(attr)] { + for i := 0; i < len(body); i++ { + if skipEnd, ok := skipChar[body[i]]; ok { + skipStart, n := body[i], 1 + for i++; i < len(body); i++ { + if skipStart != skipEnd && body[i] == skipStart { + n++ + } else if body[i] == skipEnd { + n-- + if n == 0 { + break + } + } + } + } else if body[i] == ' ' || body[i] == '\t' { + j := i + for ; j < len(body) && (body[j] == ' ' || body[j] == '\t'); j++ { + } + body = replaceRange(body, i, j, "\t") + } + } + } + return body +} + +func replaceRange(s string, start, end int, new string) string { + return s[:start] + new + s[end:] +} + +var swagCommentLineExpression = regexp.MustCompile(`^\/\/\s+(@[\S.]+)\s*(.*)`) + +func swagComment(comment string) (string, string, bool) { + matches := swagCommentLineExpression.FindStringSubmatch(comment) + if matches == nil { + return "", "", false + } + return matches[1], matches[2], true +} diff --git a/pkg/swag/formatter_test.go b/pkg/swag/formatter_test.go new file mode 100644 index 0000000..0e02fab --- /dev/null +++ b/pkg/swag/formatter_test.go @@ -0,0 +1,281 @@ +package swag + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + SearchDir = "./testdata/format_test" + Excludes = "./testdata/format_test/web" + MainFile = "main.go" +) + +func testFormat(t *testing.T, filename, contents, want string) { + got, err := NewFormatter().Format(filename, []byte(contents)) + assert.NoError(t, err) + assert.Equal(t, want, string(got)) +} + +func Test_FormatMain(t *testing.T) { + contents := `package main + // @title Swagger Example API + // @version 1.0 + // @description This is a sample server Petstore server. + // @termsOfService http://swagger.io/terms/ + + // @contact.name API Support + // @contact.url http://www.swagger.io/support + // @contact.email support@swagger.io + + // @license.name Apache 2.0 + // @license.url http://www.apache.org/licenses/LICENSE-2.0.html + + // @host petstore.swagger.io + // @BasePath /v2 + + // @securityDefinitions.basic BasicAuth + + // @securityDefinitions.apikey ApiKeyAuth + // @in header + // @name Authorization + + // @securitydefinitions.oauth2.application OAuth2Application + // @tokenUrl https://example.com/oauth/token + // @scope.write Grants write access + // @scope.admin Grants read and write access to administrative information + + // @securitydefinitions.oauth2.implicit OAuth2Implicit + // @authorizationurl https://example.com/oauth/authorize + // @scope.write Grants write access + // @scope.admin Grants read and write access to administrative information + + // @securitydefinitions.oauth2.password OAuth2Password + // @tokenUrl https://example.com/oauth/token + // @scope.read Grants read access + // @scope.write Grants write access + // @scope.admin Grants read and write access to administrative information + + // @securitydefinitions.oauth2.accessCode OAuth2AccessCode + // @tokenUrl https://example.com/oauth/token + // @authorizationurl https://example.com/oauth/authorize + // @scope.admin Grants read and write access to administrative information + func main() {}` + + want := `package main + +// @title Swagger Example API +// @version 1.0 +// @description This is a sample server Petstore server. +// @termsOfService http://swagger.io/terms/ + +// @contact.name API Support +// @contact.url http://www.swagger.io/support +// @contact.email support@swagger.io + +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html + +// @host petstore.swagger.io +// @BasePath /v2 + +// @securityDefinitions.basic BasicAuth + +// @securityDefinitions.apikey ApiKeyAuth +// @in header +// @name Authorization + +// @securitydefinitions.oauth2.application OAuth2Application +// @tokenUrl https://example.com/oauth/token +// @scope.write Grants write access +// @scope.admin Grants read and write access to administrative information + +// @securitydefinitions.oauth2.implicit OAuth2Implicit +// @authorizationurl https://example.com/oauth/authorize +// @scope.write Grants write access +// @scope.admin Grants read and write access to administrative information + +// @securitydefinitions.oauth2.password OAuth2Password +// @tokenUrl https://example.com/oauth/token +// @scope.read Grants read access +// @scope.write Grants write access +// @scope.admin Grants read and write access to administrative information + +// @securitydefinitions.oauth2.accessCode OAuth2AccessCode +// @tokenUrl https://example.com/oauth/token +// @authorizationurl https://example.com/oauth/authorize +// @scope.admin Grants read and write access to administrative information +func main() {} +` + testFormat(t, "main.go", contents, want) +} + +func Test_FormatMultipleFunctions(t *testing.T) { + contents := `package main + +// @Produce json +// @Success 200 {object} string +// @Failure 400 {object} string + func A() {} + +// @Description Description of B. +// @Produce json +// @Success 200 {array} string +// @Failure 400 {object} string + func B() {}` + + want := `package main + +// @Produce json +// @Success 200 {object} string +// @Failure 400 {object} string +func A() {} + +// @Description Description of B. +// @Produce json +// @Success 200 {array} string +// @Failure 400 {object} string +func B() {} +` + + testFormat(t, "main.go", contents, want) +} + +func Test_FormatApi(t *testing.T) { + contents := `package api + +import "net/http" + +// @Summary Add a new pet to the store +// @Description get string by ID +// @ID get-string-by-int +// @Accept json +// @Produce json +// @Param some_id path int true "Some ID" Format(int64) +// @Param some_id body web.Pet true "Some ID" +// @Success 200 {string} string "ok" +// @Failure 400 {object} web.APIError "We need ID!!" +// @Failure 404 {object} web.APIError "Can not find ID" +// @Router /testapi/get-string-by-int/{some_id} [get] + func GetStringByInt(w http.ResponseWriter, r *http.Request) {}` + + want := `package api + +import "net/http" + +// @Summary Add a new pet to the store +// @Description get string by ID +// @ID get-string-by-int +// @Accept json +// @Produce json +// @Param some_id path int true "Some ID" Format(int64) +// @Param some_id body web.Pet true "Some ID" +// @Success 200 {string} string "ok" +// @Failure 400 {object} web.APIError "We need ID!!" +// @Failure 404 {object} web.APIError "Can not find ID" +// @Router /testapi/get-string-by-int/{some_id} [get] +func GetStringByInt(w http.ResponseWriter, r *http.Request) {} +` + + testFormat(t, "api.go", contents, want) +} + +func Test_NonSwagComment(t *testing.T) { + contents := `package api + +// @Summary Add a new pet to the store +// @Description get string by ID +// @ID get-string-by-int +// @ Accept json +// This is not a @swag comment` + want := `package api + +// @Summary Add a new pet to the store +// @Description get string by ID +// @ID get-string-by-int +// @ Accept json +// This is not a @swag comment +` + + testFormat(t, "non_swag.go", contents, want) +} + +func Test_EmptyComment(t *testing.T) { + contents := `package empty + +// @Summary Add a new pet to the store +// @Description ` + want := `package empty + +// @Summary Add a new pet to the store +// @Description +` + + testFormat(t, "empty.go", contents, want) +} + +func Test_AlignAttribute(t *testing.T) { + contents := `package align + +// @Summary Add a new pet to the store +// @Description Description` + want := `package align + +// @Summary Add a new pet to the store +// @Description Description +` + + testFormat(t, "align.go", contents, want) + +} + +func Test_SyntaxError(t *testing.T) { + contents := []byte(`package invalid + func invalid() {`) + + _, err := NewFormatter().Format("invalid.go", contents) + assert.Error(t, err) +} + +func Test_splitComment2(t *testing.T) { + type args struct { + attr string + body string + } + tests := []struct { + name string + args args + want string + }{ + { + "test_splitComment2_1", + args{ + attr: "@param", + body: " data body web.GenericBodyMulti[[]types.Post, [][]types.Post]", + }, + "\tdata\tbody\tweb.GenericBodyMulti[[]types.Post, [][]types.Post]", + }, + { + "test_splitComment2_2", + args{ + attr: "@param", + body: ` some_id path int true "Some ID" Format(int64)`, + }, + "\tsome_id\tpath\tint\ttrue\t\"Some ID\"\tFormat(int64)", + }, + { + "test_splitComment2_3", + args{ + attr: "@param", + body: ` @Param some_id body web.Pet true "Some ID"`, + }, + "\t@Param\tsome_id\tbody\tweb.Pet\ttrue\t\"Some ID\"", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, splitComment2(tt.args.attr, tt.args.body), "splitComment2(%v, %v)", tt.args.attr, tt.args.body) + }) + } +} diff --git a/pkg/swag/gen/gen.go b/pkg/swag/gen/gen.go new file mode 100644 index 0000000..ae86354 --- /dev/null +++ b/pkg/swag/gen/gen.go @@ -0,0 +1,540 @@ +package gen + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "go/format" + "io" + "log" + "os" + "path" + "path/filepath" + "strings" + "text/template" + "time" + + "git.ipao.vip/rogeecn/atomctl/pkg/swag" + "github.com/go-openapi/spec" + "golang.org/x/text/cases" + "golang.org/x/text/language" + "sigs.k8s.io/yaml" +) + +var open = os.Open + +// DefaultOverridesFile is the location swagger will look for type overrides. +const DefaultOverridesFile = ".swaggo" + +type genTypeWriter func(*Config, *spec.Swagger) error + +// Gen presents a generate tool for swag. +type Gen struct { + json func(data interface{}) ([]byte, error) + jsonIndent func(data interface{}) ([]byte, error) + jsonToYAML func(data []byte) ([]byte, error) + outputTypeMap map[string]genTypeWriter + debug Debugger +} + +// Debugger is the interface that wraps the basic Printf method. +type Debugger interface { + Printf(format string, v ...interface{}) +} + +// New creates a new Gen. +func New() *Gen { + gen := Gen{ + json: json.Marshal, + jsonIndent: func(data interface{}) ([]byte, error) { + return json.MarshalIndent(data, "", " ") + }, + jsonToYAML: yaml.JSONToYAML, + debug: log.New(os.Stdout, "", log.LstdFlags), + } + + gen.outputTypeMap = map[string]genTypeWriter{ + "go": gen.writeDocSwagger, + "json": gen.writeJSONSwagger, + "yaml": gen.writeYAMLSwagger, + "yml": gen.writeYAMLSwagger, + } + + return &gen +} + +// Config presents Gen configurations. +type Config struct { + Debugger swag.Debugger + + // SearchDir the swag would parse,comma separated if multiple + SearchDir string + + // excludes dirs and files in SearchDir,comma separated + Excludes string + + // outputs only specific extension + ParseExtension string + + // OutputDir represents the output directory for all the generated files + OutputDir string + + // OutputTypes define types of files which should be generated + OutputTypes []string + + // MainAPIFile the Go file path in which 'swagger general API Info' is written + MainAPIFile string + + // PropNamingStrategy represents property naming strategy like snake case,camel case,pascal case + PropNamingStrategy string + + // MarkdownFilesDir used to find markdown files, which can be used for tag descriptions + MarkdownFilesDir string + + // CodeExampleFilesDir used to find code example files, which can be used for x-codeSamples + CodeExampleFilesDir string + + // InstanceName is used to get distinct names for different swagger documents in the + // same project. The default value is "swagger". + InstanceName string + + // ParseDepth dependency parse depth + ParseDepth int + + // ParseVendor whether swag should be parse vendor folder + ParseVendor bool + + // ParseDependencies whether swag should be parse outside dependency folder: 0 none, 1 models, 2 operations, 3 all + ParseDependency int + + // ParseInternal whether swag should parse internal packages + ParseInternal bool + + // Strict whether swag should error or warn when it detects cases which are most likely user errors + Strict bool + + // GeneratedTime whether swag should generate the timestamp at the top of docs.go + GeneratedTime bool + + // RequiredByDefault set validation required for all fields by default + RequiredByDefault bool + + // OverridesFile defines global type overrides. + OverridesFile string + + // ParseGoList whether swag use go list to parse dependency + ParseGoList bool + + // include only tags mentioned when searching, comma separated + Tags string + + // LeftTemplateDelim defines the left delimiter for the template generation + LeftTemplateDelim string + + // RightTemplateDelim defines the right delimiter for the template generation + RightTemplateDelim string + + // PackageName defines package name of generated `docs.go` + PackageName string + + // CollectionFormat set default collection format + CollectionFormat string + + // Parse only packages whose import path match the given prefix, comma separated + PackagePrefix string + + // State set host state + State string + + // ParseFuncBody whether swag should parse api info inside of funcs + ParseFuncBody bool +} + +// Build builds swagger json file for given searchDir and mainAPIFile. Returns json. +func (g *Gen) Build(config *Config) error { + if config.Debugger != nil { + g.debug = config.Debugger + } + if config.InstanceName == "" { + config.InstanceName = swag.Name + } + + searchDirs := strings.Split(config.SearchDir, ",") + for _, searchDir := range searchDirs { + if _, err := os.Stat(searchDir); os.IsNotExist(err) { + return fmt.Errorf("dir: %s does not exist", searchDir) + } + } + + if config.LeftTemplateDelim == "" { + config.LeftTemplateDelim = "{{" + } + + if config.RightTemplateDelim == "" { + config.RightTemplateDelim = "}}" + } + + var overrides map[string]string + + if config.OverridesFile != "" { + overridesFile, err := open(config.OverridesFile) + if err != nil { + // Don't bother reporting if the default file is missing; assume there are no overrides + if !(config.OverridesFile == DefaultOverridesFile && os.IsNotExist(err)) { + return fmt.Errorf("could not open overrides file: %w", err) + } + } else { + g.debug.Printf("Using overrides from %s", config.OverridesFile) + + overrides, err = parseOverrides(overridesFile) + if err != nil { + return err + } + } + } + + g.debug.Printf("Generate swagger docs....") + + p := swag.New( + swag.SetParseDependency(config.ParseDependency), + swag.SetMarkdownFileDirectory(config.MarkdownFilesDir), + swag.SetDebugger(config.Debugger), + swag.SetExcludedDirsAndFiles(config.Excludes), + swag.SetParseExtension(config.ParseExtension), + swag.SetCodeExamplesDirectory(config.CodeExampleFilesDir), + swag.SetStrict(config.Strict), + swag.SetOverrides(overrides), + swag.ParseUsingGoList(config.ParseGoList), + swag.SetTags(config.Tags), + swag.SetCollectionFormat(config.CollectionFormat), + swag.SetPackagePrefix(config.PackagePrefix), + ) + + p.PropNamingStrategy = config.PropNamingStrategy + p.ParseVendor = config.ParseVendor + p.ParseInternal = config.ParseInternal + p.RequiredByDefault = config.RequiredByDefault + p.HostState = config.State + p.ParseFuncBody = config.ParseFuncBody + + if err := p.ParseAPIMultiSearchDir(searchDirs, config.MainAPIFile, config.ParseDepth); err != nil { + return err + } + + swagger := p.GetSwagger() + + if err := os.MkdirAll(config.OutputDir, os.ModePerm); err != nil { + return err + } + + for _, outputType := range config.OutputTypes { + outputType = strings.ToLower(strings.TrimSpace(outputType)) + if typeWriter, ok := g.outputTypeMap[outputType]; ok { + if err := typeWriter(config, swagger); err != nil { + return err + } + } else { + log.Printf("output type '%s' not supported", outputType) + } + } + + return nil +} + +func (g *Gen) writeDocSwagger(config *Config, swagger *spec.Swagger) error { + filename := "docs.go" + + if config.State != "" { + filename = config.State + "_" + filename + } + + if config.InstanceName != swag.Name { + filename = config.InstanceName + "_" + filename + } + + docFileName := path.Join(config.OutputDir, filename) + + absOutputDir, err := filepath.Abs(config.OutputDir) + if err != nil { + return err + } + + var packageName string + if len(config.PackageName) > 0 { + packageName = config.PackageName + } else { + packageName = filepath.Base(absOutputDir) + packageName = strings.ReplaceAll(packageName, "-", "_") + } + + docs, err := os.Create(docFileName) + if err != nil { + return err + } + defer docs.Close() + + // Write doc + err = g.writeGoDoc(packageName, docs, swagger, config) + if err != nil { + return err + } + + g.debug.Printf("create docs.go at %+v", docFileName) + + return nil +} + +func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error { + filename := "swagger.json" + + if config.State != "" { + filename = config.State + "_" + filename + } + + if config.InstanceName != swag.Name { + filename = config.InstanceName + "_" + filename + } + + jsonFileName := path.Join(config.OutputDir, filename) + + b, err := g.jsonIndent(swagger) + if err != nil { + return err + } + + err = g.writeFile(b, jsonFileName) + if err != nil { + return err + } + + g.debug.Printf("create swagger.json at %+v", jsonFileName) + + return nil +} + +func (g *Gen) writeYAMLSwagger(config *Config, swagger *spec.Swagger) error { + filename := "swagger.yaml" + + if config.State != "" { + filename = config.State + "_" + filename + } + + if config.InstanceName != swag.Name { + filename = config.InstanceName + "_" + filename + } + + yamlFileName := path.Join(config.OutputDir, filename) + + b, err := g.json(swagger) + if err != nil { + return err + } + + y, err := g.jsonToYAML(b) + if err != nil { + return fmt.Errorf("cannot covert json to yaml error: %s", err) + } + + err = g.writeFile(y, yamlFileName) + if err != nil { + return err + } + + g.debug.Printf("create swagger.yaml at %+v", yamlFileName) + + return nil +} + +func (g *Gen) writeFile(b []byte, file string) error { + f, err := os.Create(file) + if err != nil { + return err + } + + defer f.Close() + + _, err = f.Write(b) + + return err +} + +func (g *Gen) formatSource(src []byte) []byte { + code, err := format.Source(src) + if err != nil { + code = src // Formatter failed, return original code. + } + + return code +} + +// Read and parse the overrides file. +func parseOverrides(r io.Reader) (map[string]string, error) { + overrides := make(map[string]string) + scanner := bufio.NewScanner(r) + + for scanner.Scan() { + line := scanner.Text() + + // Skip comments + if len(line) > 1 && line[0:2] == "//" { + continue + } + + parts := strings.Fields(line) + + switch len(parts) { + case 0: + // only whitespace + continue + case 2: + // either a skip or malformed + if parts[0] != "skip" { + return nil, fmt.Errorf("could not parse override: '%s'", line) + } + + overrides[parts[1]] = "" + case 3: + // either a replace or malformed + if parts[0] != "replace" { + return nil, fmt.Errorf("could not parse override: '%s'", line) + } + + overrides[parts[1]] = parts[2] + default: + return nil, fmt.Errorf("could not parse override: '%s'", line) + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading overrides file: %w", err) + } + + return overrides, nil +} + +func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swagger, config *Config) error { + generator, err := template.New("swagger_info").Funcs(template.FuncMap{ + "printDoc": func(v string) string { + // Add schemes + v = "{\n \"schemes\": " + config.LeftTemplateDelim + " marshal .Schemes " + config.RightTemplateDelim + "," + v[1:] + // Sanitize backticks + return strings.Replace(v, "`", "`+\"`\"+`", -1) + }, + }).Parse(packageTemplate) + if err != nil { + return err + } + + swaggerSpec := &spec.Swagger{ + VendorExtensible: swagger.VendorExtensible, + SwaggerProps: spec.SwaggerProps{ + ID: swagger.ID, + Consumes: swagger.Consumes, + Produces: swagger.Produces, + Swagger: swagger.Swagger, + Info: &spec.Info{ + VendorExtensible: swagger.Info.VendorExtensible, + InfoProps: spec.InfoProps{ + Description: config.LeftTemplateDelim + "escape .Description" + config.RightTemplateDelim, + Title: config.LeftTemplateDelim + ".Title" + config.RightTemplateDelim, + TermsOfService: swagger.Info.TermsOfService, + Contact: swagger.Info.Contact, + License: swagger.Info.License, + Version: config.LeftTemplateDelim + ".Version" + config.RightTemplateDelim, + }, + }, + Host: config.LeftTemplateDelim + ".Host" + config.RightTemplateDelim, + BasePath: config.LeftTemplateDelim + ".BasePath" + config.RightTemplateDelim, + Paths: swagger.Paths, + Definitions: swagger.Definitions, + Parameters: swagger.Parameters, + Responses: swagger.Responses, + SecurityDefinitions: swagger.SecurityDefinitions, + Security: swagger.Security, + Tags: swagger.Tags, + ExternalDocs: swagger.ExternalDocs, + }, + } + + // crafted docs.json + buf, err := g.jsonIndent(swaggerSpec) + if err != nil { + return err + } + + state := "" + if len(config.State) > 0 { + state = cases.Title(language.English).String(strings.ToLower(config.State)) + } + + buffer := &bytes.Buffer{} + + err = generator.Execute(buffer, struct { + Timestamp time.Time + Doc string + Host string + PackageName string + BasePath string + Title string + Description string + Version string + State string + InstanceName string + Schemes []string + GeneratedTime bool + LeftTemplateDelim string + RightTemplateDelim string + }{ + Timestamp: time.Now(), + GeneratedTime: config.GeneratedTime, + Doc: string(buf), + Host: swagger.Host, + PackageName: packageName, + BasePath: swagger.BasePath, + Schemes: swagger.Schemes, + Title: swagger.Info.Title, + Description: swagger.Info.Description, + Version: swagger.Info.Version, + State: state, + InstanceName: config.InstanceName, + LeftTemplateDelim: config.LeftTemplateDelim, + RightTemplateDelim: config.RightTemplateDelim, + }) + if err != nil { + return err + } + + code := g.formatSource(buffer.Bytes()) + + // write + _, err = output.Write(code) + + return err +} + +var packageTemplate = `// Package {{.PackageName}} Code generated by swaggo/swag{{ if .GeneratedTime }} at {{ .Timestamp }}{{ end }}. DO NOT EDIT +package {{.PackageName}} + +import "github.com/swaggo/swag" + +const docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}{{ .State }} = ` + "`{{ printDoc .Doc}}`" + ` + +// Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} holds exported Swagger Info so clients can modify it +var Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = &swag.Spec{ + Version: {{ printf "%q" .Version}}, + Host: {{ printf "%q" .Host}}, + BasePath: {{ printf "%q" .BasePath}}, + Schemes: []string{ {{ range $index, $schema := .Schemes}}{{if gt $index 0}},{{end}}{{printf "%q" $schema}}{{end}} }, + Title: {{ printf "%q" .Title}}, + Description: {{ printf "%q" .Description}}, + InfoInstanceName: {{ printf "%q" .InstanceName }}, + SwaggerTemplate: docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}{{ .State }}, + LeftDelim: {{ printf "%q" .LeftTemplateDelim}}, + RightDelim: {{ printf "%q" .RightTemplateDelim}}, +} + +func init() { + swag.Register(Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}.InstanceName(), Swagger{{ .State }}Info{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}) +} +` diff --git a/pkg/swag/gen/gen_test.go b/pkg/swag/gen/gen_test.go new file mode 100644 index 0000000..02a914b --- /dev/null +++ b/pkg/swag/gen/gen_test.go @@ -0,0 +1,976 @@ +package gen + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "os/exec" + "path" + "path/filepath" + "plugin" + "strings" + "testing" + + "git.ipao.vip/rogeecn/atomctl/pkg/swag" + "github.com/go-openapi/spec" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const searchDir = "../testdata/simple" + +var outputTypes = []string{"go", "json", "yaml"} + +func TestGen_Build(t *testing.T) { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_SpecificOutputTypes(t *testing.T) { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: []string{"go", "unknownType"}, + PropNamingStrategy: "", + } + assert.NoError(t, New().Build(config)) + + tt := []struct { + expectedFile string + shouldExist bool + }{ + {filepath.Join(config.OutputDir, "docs.go"), true}, + {filepath.Join(config.OutputDir, "swagger.json"), false}, + {filepath.Join(config.OutputDir, "swagger.yaml"), false}, + } + for _, tc := range tt { + _, err := os.Stat(tc.expectedFile) + if tc.shouldExist { + if os.IsNotExist(err) { + require.NoError(t, err) + } + } else { + require.Error(t, err) + require.True(t, errors.Is(err, os.ErrNotExist)) + } + + _ = os.Remove(tc.expectedFile) + } +} + +func TestGen_BuildInstanceName(t *testing.T) { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + assert.NoError(t, New().Build(config)) + + goSourceFile := filepath.Join(config.OutputDir, "docs.go") + + // Validate default registration name + expectedCode, err := os.ReadFile(goSourceFile) + if err != nil { + require.NoError(t, err) + } + + if !strings.Contains( + string(expectedCode), + "swag.Register(SwaggerInfo.InstanceName(), SwaggerInfo)", + ) { + t.Fatal(errors.New("generated go code does not contain the correct default registration sequence")) + } + + if !strings.Contains( + string(expectedCode), + "var SwaggerInfo =", + ) { + t.Fatal(errors.New("generated go code does not contain the correct default variable declaration")) + } + + // Custom name + config.InstanceName = "Custom" + goSourceFile = filepath.Join(config.OutputDir, config.InstanceName+"_"+"docs.go") + assert.NoError(t, New().Build(config)) + + expectedCode, err = os.ReadFile(goSourceFile) + if err != nil { + require.NoError(t, err) + } + + if !strings.Contains( + string(expectedCode), + "swag.Register(SwaggerInfoCustom.InstanceName(), SwaggerInfoCustom)", + ) { + t.Fatal(errors.New("generated go code does not contain the correct registration sequence")) + } + + if !strings.Contains( + string(expectedCode), + "var SwaggerInfoCustom =", + ) { + t.Fatal(errors.New("generated go code does not contain the correct variable declaration")) + } + + // cleanup + expectedFiles := []string{ + filepath.Join(config.OutputDir, config.InstanceName+"_"+"docs.go"), + filepath.Join(config.OutputDir, config.InstanceName+"_"+"swagger.json"), + filepath.Join(config.OutputDir, config.InstanceName+"_"+"swagger.yaml"), + } + + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_BuildSnakeCase(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/simple2", + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple2/docs", + OutputTypes: outputTypes, + PropNamingStrategy: swag.SnakeCase, + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_BuildLowerCamelcase(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/simple3", + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple3/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_BuildDescriptionWithQuotes(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/quotes", + MainAPIFile: "./main.go", + OutputDir: "../testdata/quotes/docs", + OutputTypes: outputTypes, + MarkdownFilesDir: "../testdata/quotes", + } + + require.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + } + + cmd := exec.Command("go", "build", "-buildmode=plugin", "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/quotes") + + cmd.Dir = config.SearchDir + + output, err := cmd.CombinedOutput() + if err != nil { + require.NoError(t, err, string(output)) + } + + p, err := plugin.Open(filepath.Join(config.SearchDir, "quotes.so")) + if err != nil { + require.NoError(t, err) + } + + defer os.Remove("quotes.so") + + readDoc, err := p.Lookup("ReadDoc") + if err != nil { + require.NoError(t, err) + } + + jsonOutput := readDoc.(func() string)() + + var jsonDoc interface{} + if err := json.Unmarshal([]byte(jsonOutput), &jsonDoc); err != nil { + require.NoError(t, err) + } + + expectedJSON, err := os.ReadFile(filepath.Join(config.SearchDir, "expected.json")) + if err != nil { + require.NoError(t, err) + } + + assert.JSONEq(t, string(expectedJSON), jsonOutput) +} + +func TestGen_BuildDocCustomDelims(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/delims", + MainAPIFile: "./main.go", + OutputDir: "../testdata/delims/docs", + OutputTypes: outputTypes, + MarkdownFilesDir: "../testdata/delims", + InstanceName: "CustomDelims", + LeftTemplateDelim: "{%", + RightTemplateDelim: "%}", + } + + require.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "CustomDelims_docs.go"), + filepath.Join(config.OutputDir, "CustomDelims_swagger.json"), + filepath.Join(config.OutputDir, "CustomDelims_swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + } + + cmd := exec.Command("go", "build", "-buildmode=plugin", "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/delims") + + cmd.Dir = config.SearchDir + + output, err := cmd.CombinedOutput() + if err != nil { + require.NoError(t, err, string(output)) + } + + p, err := plugin.Open(filepath.Join(config.SearchDir, "delims.so")) + if err != nil { + require.NoError(t, err) + } + + defer os.Remove("delims.so") + + readDoc, err := p.Lookup("ReadDoc") + if err != nil { + require.NoError(t, err) + } + + jsonOutput := readDoc.(func() string)() + + var jsonDoc interface{} + if err := json.Unmarshal([]byte(jsonOutput), &jsonDoc); err != nil { + require.NoError(t, err) + } + + expectedJSON, err := os.ReadFile(filepath.Join(config.SearchDir, "expected.json")) + if err != nil { + require.NoError(t, err) + } + + assert.JSONEq(t, string(expectedJSON), jsonOutput) +} + +func TestGen_jsonIndent(t *testing.T) { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + + gen := New() + gen.jsonIndent = func(data interface{}) ([]byte, error) { + return nil, errors.New("fail") + } + + assert.Error(t, gen.Build(config)) +} + +func TestGen_jsonToYAML(t *testing.T) { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + + gen := New() + gen.jsonToYAML = func(data []byte) ([]byte, error) { + return nil, errors.New("fail") + } + assert.Error(t, gen.Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + } + + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_SearchDirIsNotExist(t *testing.T) { + var swaggerConfDir, propNamingStrategy string + + config := &Config{ + SearchDir: "../isNotExistDir", + MainAPIFile: "./main.go", + OutputDir: swaggerConfDir, + OutputTypes: outputTypes, + PropNamingStrategy: propNamingStrategy, + } + + assert.EqualError(t, New().Build(config), "dir: ../isNotExistDir does not exist") +} + +func TestGen_MainAPiNotExist(t *testing.T) { + var swaggerConfDir, propNamingStrategy string + + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./notExists.go", + OutputDir: swaggerConfDir, + OutputTypes: outputTypes, + PropNamingStrategy: propNamingStrategy, + } + + assert.Error(t, New().Build(config)) +} + +func TestGen_OutputIsNotExist(t *testing.T) { + var propNamingStrategy string + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "/dev/null", + OutputTypes: outputTypes, + PropNamingStrategy: propNamingStrategy, + } + assert.Error(t, New().Build(config)) +} + +func TestGen_FailToWrite(t *testing.T) { + outputDir := filepath.Join(os.TempDir(), "swagg", "test") + outputTypes := []string{"go", "json", "yaml"} + + var propNamingStrategy string + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: outputDir, + OutputTypes: outputTypes, + PropNamingStrategy: propNamingStrategy, + } + + err := os.MkdirAll(outputDir, 0o755) + if err != nil { + require.NoError(t, err) + } + + _ = os.RemoveAll(filepath.Join(outputDir, "swagger.yaml")) + + err = os.Mkdir(filepath.Join(outputDir, "swagger.yaml"), 0o755) + if err != nil { + require.NoError(t, err) + } + assert.Error(t, New().Build(config)) + + _ = os.RemoveAll(filepath.Join(outputDir, "swagger.json")) + + err = os.Mkdir(filepath.Join(outputDir, "swagger.json"), 0o755) + if err != nil { + require.NoError(t, err) + } + assert.Error(t, New().Build(config)) + + _ = os.RemoveAll(filepath.Join(outputDir, "docs.go")) + + err = os.Mkdir(filepath.Join(outputDir, "docs.go"), 0o755) + if err != nil { + require.NoError(t, err) + } + assert.Error(t, New().Build(config)) + + err = os.RemoveAll(outputDir) + if err != nil { + require.NoError(t, err) + } +} + +func TestGen_configWithOutputDir(t *testing.T) { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_configWithOutputTypesAll(t *testing.T) { + searchDir := "../testdata/simple" + outputTypes := []string{"go", "json", "yaml"} + + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + path.Join(config.OutputDir, "docs.go"), + path.Join(config.OutputDir, "swagger.json"), + path.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + t.Fatal(err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_configWithOutputTypesSingle(t *testing.T) { + searchDir := "../testdata/simple" + outputTypes := []string{"go", "json", "yaml"} + + for _, outputType := range outputTypes { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: []string{outputType}, + PropNamingStrategy: "", + } + + assert.NoError(t, New().Build(config)) + + outFileName := "swagger" + if outputType == "go" { + outFileName = "docs" + } + + expectedFiles := []string{ + path.Join(config.OutputDir, outFileName+"."+outputType), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + t.Fatal(err) + } + + _ = os.Remove(expectedFile) + } + } +} + +func TestGen_formatSource(t *testing.T) { + src := `package main + +import "net + +func main() {} +` + g := New() + + res := g.formatSource([]byte(src)) + assert.Equal(t, []byte(src), res, "Should return same content due to fmt fail") + + src2 := `package main + +import "fmt" + +func main() { +fmt.Print("Hello world") +} +` + res = g.formatSource([]byte(src2)) + assert.NotEqual(t, []byte(src2), res, "Should return fmt code") +} + +type mockWriter struct { + hook func([]byte) +} + +func (w *mockWriter) Write(data []byte) (int, error) { + if w.hook != nil { + w.hook(data) + } + + return len(data), nil +} + +func TestGen_writeGoDoc(t *testing.T) { + gen := New() + + swapTemplate := packageTemplate + + packageTemplate = `{{{` + err := gen.writeGoDoc("docs", nil, nil, &Config{}) + assert.Error(t, err) + + packageTemplate = `{{.Data}}` + swagger := &spec.Swagger{ + VendorExtensible: spec.VendorExtensible{}, + SwaggerProps: spec.SwaggerProps{ + Info: &spec.Info{}, + }, + } + + err = gen.writeGoDoc("docs", &mockWriter{}, swagger, &Config{}) + assert.Error(t, err) + + packageTemplate = `{{ if .GeneratedTime }}Fake Time{{ end }}` + err = gen.writeGoDoc("docs", + &mockWriter{ + hook: func(data []byte) { + assert.Equal(t, "Fake Time", string(data)) + }, + }, swagger, &Config{GeneratedTime: true}) + assert.NoError(t, err) + + err = gen.writeGoDoc("docs", + &mockWriter{ + hook: func(data []byte) { + assert.Equal(t, "", string(data)) + }, + }, swagger, &Config{GeneratedTime: false}) + assert.NoError(t, err) + + packageTemplate = swapTemplate +} + +func TestGen_GeneratedDoc(t *testing.T) { + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + + assert.NoError(t, New().Build(config)) + + goCMD, err := exec.LookPath("go") + assert.NoError(t, err) + + cmd := exec.Command(goCMD, "build", filepath.Join(config.OutputDir, "docs.go")) + + cmd.Stdout = os.Stdout + + cmd.Stderr = os.Stderr + + assert.NoError(t, cmd.Run()) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_cgoImports(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/simple_cgo", + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple_cgo/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + ParseDependency: 1, + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_parseOverrides(t *testing.T) { + testCases := []struct { + Name string + Data string + Expected map[string]string + ExpectedError error + }{ + { + Name: "replace", + Data: `replace github.com/foo/bar baz`, + Expected: map[string]string{ + "github.com/foo/bar": "baz", + }, + }, + { + Name: "skip", + Data: `skip github.com/foo/bar`, + Expected: map[string]string{ + "github.com/foo/bar": "", + }, + }, + { + Name: "generic-simple", + Data: `replace types.Field[string] string`, + Expected: map[string]string{ + "types.Field[string]": "string", + }, + }, + { + Name: "generic-double", + Data: `replace types.Field[string,string] string`, + Expected: map[string]string{ + "types.Field[string,string]": "string", + }, + }, + { + Name: "comment", + Data: `// this is a comment + replace foo bar`, + Expected: map[string]string{ + "foo": "bar", + }, + }, + { + Name: "ignore whitespace", + Data: ` + + replace foo bar`, + Expected: map[string]string{ + "foo": "bar", + }, + }, + { + Name: "unknown directive", + Data: `foo`, + ExpectedError: fmt.Errorf("could not parse override: 'foo'"), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + overrides, err := parseOverrides(strings.NewReader(tc.Data)) + assert.Equal(t, tc.Expected, overrides) + assert.Equal(t, tc.ExpectedError, err) + }) + } +} + +func TestGen_TypeOverridesFile(t *testing.T) { + customPath := "/foo/bar/baz" + + tmp, err := os.CreateTemp("", "") + require.NoError(t, err) + + defer os.Remove(tmp.Name()) + + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + PropNamingStrategy: "", + } + + t.Run("Default file is missing", func(t *testing.T) { + open = func(path string) (*os.File, error) { + assert.Equal(t, DefaultOverridesFile, path) + + return nil, os.ErrNotExist + } + defer func() { + open = os.Open + }() + + config.OverridesFile = DefaultOverridesFile + err := New().Build(config) + assert.NoError(t, err) + }) + + t.Run("Default file is present", func(t *testing.T) { + open = func(path string) (*os.File, error) { + assert.Equal(t, DefaultOverridesFile, path) + + return tmp, nil + } + defer func() { + open = os.Open + }() + + config.OverridesFile = DefaultOverridesFile + err := New().Build(config) + assert.NoError(t, err) + }) + + t.Run("Different file is missing", func(t *testing.T) { + open = func(path string) (*os.File, error) { + assert.Equal(t, customPath, path) + + return nil, os.ErrNotExist + } + defer func() { + open = os.Open + }() + + config.OverridesFile = customPath + err := New().Build(config) + assert.EqualError(t, err, "could not open overrides file: file does not exist") + }) + + t.Run("Different file is present", func(t *testing.T) { + open = func(path string) (*os.File, error) { + assert.Equal(t, customPath, path) + + return tmp, nil + } + defer func() { + open = os.Open + }() + + config.OverridesFile = customPath + err := New().Build(config) + assert.NoError(t, err) + }) +} + +func TestGen_Debugger(t *testing.T) { + var buf bytes.Buffer + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/simple/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + Debugger: log.New(&buf, "", log.LstdFlags), + } + assert.True(t, buf.Len() == 0) + assert.NoError(t, New().Build(config)) + assert.True(t, buf.Len() > 0) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + + _ = os.Remove(expectedFile) + } +} + +func TestGen_ErrorAndInterface(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/error", + MainAPIFile: "./main.go", + OutputDir: "../testdata/error/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "docs.go"), + filepath.Join(config.OutputDir, "swagger.json"), + filepath.Join(config.OutputDir, "swagger.yaml"), + } + t.Cleanup(func() { + for _, expectedFile := range expectedFiles { + _ = os.Remove(expectedFile) + } + }) + + // check files + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + } + + // check content + jsonOutput, err := os.ReadFile(filepath.Join(config.OutputDir, "swagger.json")) + if err != nil { + require.NoError(t, err) + } + expectedJSON, err := os.ReadFile(filepath.Join(config.SearchDir, "expected.json")) + if err != nil { + require.NoError(t, err) + } + + assert.JSONEq(t, string(expectedJSON), string(jsonOutput)) +} + +func TestGen_StateAdmin(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/state", + MainAPIFile: "./main.go", + OutputDir: "../testdata/state/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + State: "admin", + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "admin_docs.go"), + filepath.Join(config.OutputDir, "admin_swagger.json"), + filepath.Join(config.OutputDir, "admin_swagger.yaml"), + } + t.Cleanup(func() { + for _, expectedFile := range expectedFiles { + _ = os.Remove(expectedFile) + } + }) + + // check files + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + } + + // check content + jsonOutput, err := os.ReadFile(filepath.Join(config.OutputDir, "admin_swagger.json")) + require.NoError(t, err) + expectedJSON, err := os.ReadFile(filepath.Join(config.SearchDir, "admin_expected.json")) + require.NoError(t, err) + + assert.JSONEq(t, string(expectedJSON), string(jsonOutput)) +} + +func TestGen_StateUser(t *testing.T) { + config := &Config{ + SearchDir: "../testdata/state", + MainAPIFile: "./main.go", + OutputDir: "../testdata/state/docs", + OutputTypes: outputTypes, + PropNamingStrategy: "", + State: "user", + } + + assert.NoError(t, New().Build(config)) + + expectedFiles := []string{ + filepath.Join(config.OutputDir, "user_docs.go"), + filepath.Join(config.OutputDir, "user_swagger.json"), + filepath.Join(config.OutputDir, "user_swagger.yaml"), + } + t.Cleanup(func() { + for _, expectedFile := range expectedFiles { + _ = os.Remove(expectedFile) + } + }) + + // check files + for _, expectedFile := range expectedFiles { + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + require.NoError(t, err) + } + } + + // check content + jsonOutput, err := os.ReadFile(filepath.Join(config.OutputDir, "user_swagger.json")) + require.NoError(t, err) + expectedJSON, err := os.ReadFile(filepath.Join(config.SearchDir, "user_expected.json")) + require.NoError(t, err) + + assert.JSONEq(t, string(expectedJSON), string(jsonOutput)) +} diff --git a/pkg/swag/generics.go b/pkg/swag/generics.go new file mode 100644 index 0000000..80e93a9 --- /dev/null +++ b/pkg/swag/generics.go @@ -0,0 +1,448 @@ +//go:build go1.18 +// +build go1.18 + +package swag + +import ( + "errors" + "fmt" + "go/ast" + "strings" + "unicode" + + "github.com/go-openapi/spec" +) + +type genericTypeSpec struct { + TypeSpec *TypeSpecDef + Name string +} + +type formalParamType struct { + Name string + Type string +} + +func (t *genericTypeSpec) TypeName() string { + if t.TypeSpec != nil { + return t.TypeSpec.TypeName() + } + return t.Name +} + +func normalizeGenericTypeName(name string) string { + return strings.Replace(name, ".", "_", -1) +} + +func (pkgDefs *PackagesDefinitions) getTypeFromGenericParam(genericParam string, file *ast.File) (typeSpecDef *TypeSpecDef) { + if strings.HasPrefix(genericParam, "[]") { + typeSpecDef = pkgDefs.getTypeFromGenericParam(genericParam[2:], file) + if typeSpecDef == nil { + return nil + } + var expr ast.Expr + switch typeSpecDef.TypeSpec.Type.(type) { + case *ast.ArrayType, *ast.MapType: + expr = typeSpecDef.TypeSpec.Type + default: + name := typeSpecDef.TypeName() + expr = ast.NewIdent(name) + if _, ok := pkgDefs.uniqueDefinitions[name]; !ok { + pkgDefs.uniqueDefinitions[name] = typeSpecDef + } + } + return &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: ast.NewIdent(string(IgnoreNameOverridePrefix) + "array_" + typeSpecDef.TypeName()), + Type: &ast.ArrayType{ + Elt: expr, + }, + }, + Enums: typeSpecDef.Enums, + PkgPath: typeSpecDef.PkgPath, + ParentSpec: typeSpecDef.ParentSpec, + SchemaName: "array_" + typeSpecDef.SchemaName, + NotUnique: false, + } + } + + if strings.HasPrefix(genericParam, "map[") { + parts := strings.SplitN(genericParam[4:], "]", 2) + if len(parts) != 2 { + return nil + } + typeSpecDef = pkgDefs.getTypeFromGenericParam(parts[1], file) + if typeSpecDef == nil { + return nil + } + var expr ast.Expr + switch typeSpecDef.TypeSpec.Type.(type) { + case *ast.ArrayType, *ast.MapType: + expr = typeSpecDef.TypeSpec.Type + default: + name := typeSpecDef.TypeName() + expr = ast.NewIdent(name) + if _, ok := pkgDefs.uniqueDefinitions[name]; !ok { + pkgDefs.uniqueDefinitions[name] = typeSpecDef + } + } + return &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: ast.NewIdent(string(IgnoreNameOverridePrefix) + "map_" + parts[0] + "_" + typeSpecDef.TypeName()), + Type: &ast.MapType{ + Key: ast.NewIdent(parts[0]), //assume key is string or integer + Value: expr, + }, + }, + Enums: typeSpecDef.Enums, + PkgPath: typeSpecDef.PkgPath, + ParentSpec: typeSpecDef.ParentSpec, + SchemaName: "map_" + parts[0] + "_" + typeSpecDef.SchemaName, + NotUnique: false, + } + } + if IsGolangPrimitiveType(genericParam) { + return &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: ast.NewIdent(genericParam), + Type: ast.NewIdent(genericParam), + }, + SchemaName: genericParam, + } + } + return pkgDefs.FindTypeSpec(genericParam, file) +} + +func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string) *TypeSpecDef { + if original == nil || original.TypeSpec.TypeParams == nil || len(original.TypeSpec.TypeParams.List) == 0 { + return original + } + + name, genericParams := splitGenericsTypeName(fullGenericForm) + if genericParams == nil { + return nil + } + + //generic[x,y any,z any] considered, TODO what if the type is not `any`, but a concrete one, such as `int32|int64` or an certain interface{} + var formals []formalParamType + for _, field := range original.TypeSpec.TypeParams.List { + for _, ident := range field.Names { + formal := formalParamType{Name: ident.Name} + if ident, ok := field.Type.(*ast.Ident); ok { + formal.Type = ident.Name + } + formals = append(formals, formal) + } + } + if len(genericParams) != len(formals) { + return nil + } + genericParamTypeDefs := map[string]*genericTypeSpec{} + + for i, genericParam := range genericParams { + var typeDef *TypeSpecDef + if !IsGolangPrimitiveType(genericParam) { + typeDef = pkgDefs.getTypeFromGenericParam(genericParam, file) + if typeDef != nil { + genericParam = typeDef.TypeName() + if _, ok := pkgDefs.uniqueDefinitions[genericParam]; !ok { + pkgDefs.uniqueDefinitions[genericParam] = typeDef + } + } + } + genericParamTypeDefs[formals[i].Name] = &genericTypeSpec{ + TypeSpec: typeDef, + Name: genericParam, + } + } + + name = fmt.Sprintf("%s%s-", string(IgnoreNameOverridePrefix), original.TypeName()) + schemaName := fmt.Sprintf("%s-", original.SchemaName) + + var nameParts []string + var schemaNameParts []string + + for _, def := range formals { + if specDef, ok := genericParamTypeDefs[def.Name]; ok { + nameParts = append(nameParts, specDef.Name) + + schemaNamePart := specDef.Name + + if specDef.TypeSpec != nil { + schemaNamePart = specDef.TypeSpec.SchemaName + } + + schemaNameParts = append(schemaNameParts, schemaNamePart) + } + } + + name += normalizeGenericTypeName(strings.Join(nameParts, "-")) + schemaName += normalizeGenericTypeName(strings.Join(schemaNameParts, "-")) + + if typeSpec, ok := pkgDefs.uniqueDefinitions[name]; ok { + return typeSpec + } + + parametrizedTypeSpec := &TypeSpecDef{ + File: original.File, + PkgPath: original.PkgPath, + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{ + Name: name, + NamePos: original.TypeSpec.Name.NamePos, + Obj: original.TypeSpec.Name.Obj, + }, + Doc: original.TypeSpec.Doc, + Assign: original.TypeSpec.Assign, + }, + SchemaName: schemaName, + } + pkgDefs.uniqueDefinitions[name] = parametrizedTypeSpec + + parametrizedTypeSpec.TypeSpec.Type = pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs) + + return parametrizedTypeSpec +} + +// splitGenericsTypeName splits a generic struct name in his parts +func splitGenericsTypeName(fullGenericForm string) (string, []string) { + //remove all spaces character + fullGenericForm = strings.Map(func(r rune) rune { + if unicode.IsSpace(r) { + return -1 + } + return r + }, fullGenericForm) + + // split only at the first '[' and remove the last ']' + if fullGenericForm[len(fullGenericForm)-1] != ']' { + return "", nil + } + + genericParams := strings.SplitN(fullGenericForm[:len(fullGenericForm)-1], "[", 2) + if len(genericParams) == 1 { + return "", nil + } + + // generic type name + genericTypeName := genericParams[0] + + depth := 0 + genericParams = strings.FieldsFunc(genericParams[1], func(r rune) bool { + if r == '[' { + depth++ + } else if r == ']' { + depth-- + } else if r == ',' && depth == 0 { + return true + } + return false + }) + if depth != 0 { + return "", nil + } + + return genericTypeName, genericParams +} + +func (pkgDefs *PackagesDefinitions) getParametrizedType(genTypeSpec *genericTypeSpec) ast.Expr { + if genTypeSpec.TypeSpec != nil && strings.Contains(genTypeSpec.Name, ".") { + parts := strings.SplitN(genTypeSpec.Name, ".", 2) + return &ast.SelectorExpr{ + X: &ast.Ident{Name: parts[0]}, + Sel: &ast.Ident{Name: parts[1]}, + } + } + + //a primitive type name or a type name in current package + return &ast.Ident{Name: genTypeSpec.Name} +} + +func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) ast.Expr { + switch astExpr := expr.(type) { + case *ast.Ident: + if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok { + return pkgDefs.getParametrizedType(genTypeSpec) + } + case *ast.ArrayType: + return &ast.ArrayType{ + Elt: pkgDefs.resolveGenericType(file, astExpr.Elt, genericParamTypeDefs), + Len: astExpr.Len, + Lbrack: astExpr.Lbrack, + } + case *ast.MapType: + return &ast.MapType{ + Map: astExpr.Map, + Key: pkgDefs.resolveGenericType(file, astExpr.Key, genericParamTypeDefs), + Value: pkgDefs.resolveGenericType(file, astExpr.Value, genericParamTypeDefs), + } + case *ast.StarExpr: + return &ast.StarExpr{ + Star: astExpr.Star, + X: pkgDefs.resolveGenericType(file, astExpr.X, genericParamTypeDefs), + } + case *ast.IndexExpr, *ast.IndexListExpr: + fullGenericName, _ := getGenericFieldType(file, expr, genericParamTypeDefs) + typeDef := pkgDefs.FindTypeSpec(fullGenericName, file) + if typeDef != nil { + return typeDef.TypeSpec.Name + } + case *ast.StructType: + newStructTypeDef := &ast.StructType{ + Struct: astExpr.Struct, + Incomplete: astExpr.Incomplete, + Fields: &ast.FieldList{ + Opening: astExpr.Fields.Opening, + Closing: astExpr.Fields.Closing, + }, + } + + for _, field := range astExpr.Fields.List { + newField := &ast.Field{ + Type: field.Type, + Doc: field.Doc, + Names: field.Names, + Tag: field.Tag, + Comment: field.Comment, + } + + newField.Type = pkgDefs.resolveGenericType(file, field.Type, genericParamTypeDefs) + + newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField) + } + return newStructTypeDef + } + return expr +} + +func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) { + switch fieldType := field.(type) { + case *ast.ArrayType: + fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt, genericParamTypeDefs) + return "[]" + fieldName, err + case *ast.StarExpr: + return getExtendedGenericFieldType(file, fieldType.X, genericParamTypeDefs) + case *ast.Ident: + if genericParamTypeDefs != nil { + if typeSpec, ok := genericParamTypeDefs[fieldType.Name]; ok { + return typeSpec.Name, nil + } + } + if fieldType.Obj == nil { + return fieldType.Name, nil + } + + tSpec := &TypeSpecDef{ + File: file, + TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec), + PkgPath: file.Name.Name, + } + return tSpec.TypeName(), nil + default: + return getFieldType(file, field, genericParamTypeDefs) + } +} + +func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) { + var fullName string + var baseName string + var err error + switch fieldType := field.(type) { + case *ast.IndexListExpr: + baseName, err = getGenericTypeName(file, fieldType.X) + if err != nil { + return "", err + } + fullName = baseName + "[" + + for _, index := range fieldType.Indices { + fieldName, err := getExtendedGenericFieldType(file, index, genericParamTypeDefs) + if err != nil { + return "", err + } + + fullName += fieldName + "," + } + + fullName = strings.TrimRight(fullName, ",") + "]" + case *ast.IndexExpr: + baseName, err = getGenericTypeName(file, fieldType.X) + if err != nil { + return "", err + } + + indexName, err := getExtendedGenericFieldType(file, fieldType.Index, genericParamTypeDefs) + if err != nil { + return "", err + } + + fullName = fmt.Sprintf("%s[%s]", baseName, indexName) + } + + if fullName == "" { + return "", fmt.Errorf("unknown field type %#v", field) + } + + var packageName string + if !strings.Contains(baseName, ".") { + if file.Name == nil { + return "", errors.New("file name is nil") + } + packageName, _ = getFieldType(file, file.Name, genericParamTypeDefs) + } + + return strings.TrimLeft(fmt.Sprintf("%s.%s", packageName, fullName), "."), nil +} + +func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) { + switch fieldType := field.(type) { + case *ast.Ident: + if fieldType.Obj == nil { + return fieldType.Name, nil + } + + tSpec := &TypeSpecDef{ + File: file, + TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec), + PkgPath: file.Name.Name, + } + return tSpec.TypeName(), nil + case *ast.ArrayType: + tSpec := &TypeSpecDef{ + File: file, + TypeSpec: fieldType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec), + PkgPath: file.Name.Name, + } + return tSpec.TypeName(), nil + case *ast.SelectorExpr: + return fmt.Sprintf("%s.%s", fieldType.X.(*ast.Ident).Name, fieldType.Sel.Name), nil + } + return "", fmt.Errorf("unknown type %#v", field) +} + +func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*spec.Schema, error) { + switch expr := typeExpr.(type) { + // suppress debug messages for these types + case *ast.InterfaceType: + case *ast.StructType: + case *ast.Ident: + case *ast.StarExpr: + case *ast.SelectorExpr: + case *ast.ArrayType: + case *ast.MapType: + case *ast.FuncType: + case *ast.IndexExpr, *ast.IndexListExpr: + name, err := getExtendedGenericFieldType(file, expr, nil) + if err == nil { + if schema, err := parser.getTypeSchema(name, file, false); err == nil { + return schema, nil + } + } + + parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead. (%s)\n", typeExpr, err) + default: + parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead.\n", typeExpr) + } + + return PrimitiveSchema(OBJECT), nil +} diff --git a/pkg/swag/generics_test.go b/pkg/swag/generics_test.go new file mode 100644 index 0000000..fecaa4d --- /dev/null +++ b/pkg/swag/generics_test.go @@ -0,0 +1,450 @@ +//go:build go1.18 +// +build go1.18 + +package swag + +import ( + "encoding/json" + "fmt" + "go/ast" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testLogger struct { + Messages []string +} + +func (t *testLogger) Printf(format string, v ...interface{}) { + t.Messages = append(t.Messages, fmt.Sprintf(format, v...)) +} + +func TestParseGenericsBasic(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_basic" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + p.Overrides = map[string]string{ + "types.Field[string]": "string", + "types.DoubleField[string,string]": "[]string", + "types.TrippleField[string,string]": "[][]string", + } + + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGenericsArrays(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_arrays" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGenericsNested(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_nested" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGenericsMultiLevelNesting(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_multi_level_nesting" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGenericsProperty(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_property" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGenericsNames(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_names" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGenericsPackageAlias(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_package_alias/internal" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New(SetParseDependency(1)) + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGenericsFunctionScoped(t *testing.T) { + t.Parallel() + + searchDir := "testdata/generics_function_scoped" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParametrizeStruct(t *testing.T) { + pd := PackagesDefinitions{ + packages: make(map[string]*PackageDefinitions), + uniqueDefinitions: make(map[string]*TypeSpecDef), + } + // valid + typeSpec := pd.parametrizeGenericType( + &ast.File{Name: &ast.Ident{Name: "test2"}}, + &TypeSpecDef{ + File: &ast.File{Name: &ast.Ident{Name: "test"}}, + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, []string]") + assert.NotNil(t, typeSpec) + assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name()) + assert.Equal(t, "test.Field-string-array_string", typeSpec.TypeName()) + + // definition contains one type params, but two type params are provided + typeSpec = pd.parametrizeGenericType( + &ast.File{Name: &ast.Ident{Name: "test2"}}, + &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, string]") + assert.Nil(t, typeSpec) + + // definition contains two type params, but only one is used + typeSpec = pd.parametrizeGenericType( + &ast.File{Name: &ast.Ident{Name: "test2"}}, + &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string]") + assert.Nil(t, typeSpec) + + // name is not a valid type name + typeSpec = pd.parametrizeGenericType( + &ast.File{Name: &ast.Ident{Name: "test2"}}, + &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string") + assert.Nil(t, typeSpec) + + typeSpec = pd.parametrizeGenericType( + &ast.File{Name: &ast.Ident{Name: "test2"}}, + &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, [string]") + assert.Nil(t, typeSpec) + + typeSpec = pd.parametrizeGenericType( + &ast.File{Name: &ast.Ident{Name: "test2"}}, + &TypeSpecDef{ + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }}, "test.Field[string, ]string]") + assert.Nil(t, typeSpec) +} + +func TestSplitGenericsTypeNames(t *testing.T) { + t.Parallel() + + field, params := splitGenericsTypeName("test.Field") + assert.Empty(t, field) + assert.Nil(t, params) + + field, params = splitGenericsTypeName("test.Field]") + assert.Empty(t, field) + assert.Nil(t, params) + + field, params = splitGenericsTypeName("test.Field[string") + assert.Empty(t, field) + assert.Nil(t, params) + + field, params = splitGenericsTypeName("test.Field[string] ") + assert.Equal(t, "test.Field", field) + assert.Equal(t, []string{"string"}, params) + + field, params = splitGenericsTypeName("test.Field[string, []string]") + assert.Equal(t, "test.Field", field) + assert.Equal(t, []string{"string", "[]string"}, params) + + field, params = splitGenericsTypeName("test.Field[test.Field[ string, []string] ]") + assert.Equal(t, "test.Field", field) + assert.Equal(t, []string{"test.Field[string,[]string]"}, params) +} + +func TestGetGenericFieldType(t *testing.T) { + field, err := getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}}, + }, + nil, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field[string]", field) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{}}, + &ast.IndexListExpr{ + X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}}, + }, + nil, + ) + assert.NoError(t, err) + assert.Equal(t, "Field[string]", field) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}}, + }, + nil, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field[string,int]", field) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.Ident{Name: "int"}}}, + }, + nil, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field[string,[]int]", field) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.BadExpr{}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}}, + }, + nil, + ) + assert.Error(t, err) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexListExpr{ + X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.BadExpr{}}}, + }, + nil, + ) + assert.Error(t, err) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}}, + nil, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field[string]", field) + + field, err = getFieldType( + &ast.File{Name: nil}, + &ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}}, + nil, + ) + assert.Error(t, err) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexExpr{X: &ast.BadExpr{}, Index: &ast.Ident{Name: "string"}}, + nil, + ) + assert.Error(t, err) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.BadExpr{}}, + nil, + ) + assert.Error(t, err) + + field, err = getFieldType( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, Index: &ast.Ident{Name: "string"}}, + nil, + ) + assert.NoError(t, err) + assert.Equal(t, "field.Name[string]", field) +} + +func TestGetGenericTypeName(t *testing.T) { + field, err := getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field", field) + + field, err = getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.ArrayType{Elt: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}}, + ) + assert.NoError(t, err) + assert.Equal(t, "test.Field", field) + + field, err = getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, + ) + assert.NoError(t, err) + assert.Equal(t, "field.Name", field) + + _, err = getGenericTypeName( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.BadExpr{}, + ) + assert.Error(t, err) +} + +func TestParseGenericTypeExpr(t *testing.T) { + t.Parallel() + + parser := New() + logger := &testLogger{} + SetDebugger(logger)(parser) + + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.InterfaceType{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.StructType{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.Ident{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.StarExpr{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.SelectorExpr{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.ArrayType{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.MapType{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.FuncType{}) + assert.Empty(t, logger.Messages) + _, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.BadExpr{}) + assert.NotEmpty(t, logger.Messages) + assert.Len(t, logger.Messages, 1) + + parser.packages.uniqueDefinitions["field.Name[string]"] = &TypeSpecDef{ + File: &ast.File{Name: &ast.Ident{Name: "test"}}, + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{Name: "Field"}, + TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}}}, + Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}}, + }, + } + spec, err := parser.parseTypeExpr( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, Index: &ast.Ident{Name: "string"}}, + false, + ) + assert.NotNil(t, spec) + assert.NoError(t, err) + + logger.Messages = []string{} + spec, err = parser.parseTypeExpr( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.IndexExpr{X: &ast.BadExpr{}, Index: &ast.Ident{Name: "string"}}, + false, + ) + assert.NotNil(t, spec) + assert.Equal(t, "object", spec.SchemaProps.Type[0]) + assert.NotEmpty(t, logger.Messages) + assert.Len(t, logger.Messages, 1) + + logger.Messages = []string{} + spec, err = parser.parseTypeExpr( + &ast.File{Name: &ast.Ident{Name: "test"}}, + &ast.BadExpr{}, + false, + ) + assert.NotNil(t, spec) + assert.Equal(t, "object", spec.SchemaProps.Type[0]) + assert.NotEmpty(t, logger.Messages) + assert.Len(t, logger.Messages, 1) +} diff --git a/pkg/swag/golist.go b/pkg/swag/golist.go new file mode 100644 index 0000000..fa0b2cd --- /dev/null +++ b/pkg/swag/golist.go @@ -0,0 +1,78 @@ +package swag + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "go/build" + "os/exec" + "path/filepath" +) + +func listPackages(ctx context.Context, dir string, env []string, args ...string) (pkgs []*build.Package, finalErr error) { + cmd := exec.CommandContext(ctx, "go", append([]string{"list", "-json", "-e"}, args...)...) + cmd.Env = env + cmd.Dir = dir + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + var stderrBuf bytes.Buffer + cmd.Stderr = &stderrBuf + defer func() { + if stderrBuf.Len() > 0 { + finalErr = fmt.Errorf("%v\n%s", finalErr, stderrBuf.Bytes()) + } + }() + + err = cmd.Start() + if err != nil { + return nil, err + } + dec := json.NewDecoder(stdout) + for dec.More() { + var pkg build.Package + err = dec.Decode(&pkg) + if err != nil { + return nil, err + } + pkgs = append(pkgs, &pkg) + } + err = cmd.Wait() + if err != nil { + return nil, err + } + return pkgs, nil +} + +func (parser *Parser) getAllGoFileInfoFromDepsByList(pkg *build.Package, parseFlag ParseFlag) error { + ignoreInternal := pkg.Goroot && !parser.ParseInternal + if ignoreInternal { // ignored internal + return nil + } + + if parser.skipPackageByPrefix(pkg.ImportPath) { + return nil // ignored by user-defined package path prefixes + } + + srcDir := pkg.Dir + var err error + for i := range pkg.GoFiles { + err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.GoFiles[i]), nil, parseFlag) + if err != nil { + return err + } + } + + // parse .go source files that import "C" + for i := range pkg.CgoFiles { + err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.CgoFiles[i]), nil, parseFlag) + if err != nil { + return err + } + } + + return nil +} diff --git a/pkg/swag/golist_test.go b/pkg/swag/golist_test.go new file mode 100644 index 0000000..4a12473 --- /dev/null +++ b/pkg/swag/golist_test.go @@ -0,0 +1,115 @@ +package swag + +import ( + "context" + "errors" + "fmt" + "go/build" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestListPackages(t *testing.T) { + cases := []struct { + name string + args []string + searchDir string + except error + }{ + { + name: "errorArgs", + args: []string{"-abc"}, + searchDir: "testdata/golist", + except: fmt.Errorf("exit status 2"), + }, + { + name: "normal", + args: []string{"-deps"}, + searchDir: "testdata/golist", + except: nil, + }, + { + name: "list error", + args: []string{"-deps"}, + searchDir: "testdata/golist_not_exist", + except: errors.New("searchDir not exist"), + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + _, err := listPackages(context.TODO(), c.searchDir, nil, c.args...) + if c.except != nil { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} + +func TestGetAllGoFileInfoFromDepsByList(t *testing.T) { + p := New(ParseUsingGoList(true)) + pwd, err := os.Getwd() + assert.NoError(t, err) + cases := []struct { + name string + buildPackage *build.Package + ignoreInternal bool + except error + }{ + { + name: "normal", + buildPackage: &build.Package{ + Name: "main", + ImportPath: "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/golist", + Dir: "testdata/golist", + GoFiles: []string{"main.go"}, + CgoFiles: []string{"api/api.go"}, + }, + except: nil, + }, + { + name: "ignore internal", + buildPackage: &build.Package{ + Goroot: true, + }, + ignoreInternal: true, + except: nil, + }, + { + name: "gofiles error", + buildPackage: &build.Package{ + Dir: "testdata/golist_not_exist", + GoFiles: []string{"main.go"}, + }, + except: errors.New("file not exist"), + }, + { + name: "cgofiles error", + buildPackage: &build.Package{ + Dir: "testdata/golist_not_exist", + CgoFiles: []string{"main.go"}, + }, + except: errors.New("file not exist"), + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if c.ignoreInternal { + p.ParseInternal = false + } + c.buildPackage.Dir = filepath.Join(pwd, c.buildPackage.Dir) + err := p.getAllGoFileInfoFromDepsByList(c.buildPackage, ParseModels) + if c.except != nil { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} diff --git a/pkg/swag/operation.go b/pkg/swag/operation.go new file mode 100644 index 0000000..4458170 --- /dev/null +++ b/pkg/swag/operation.go @@ -0,0 +1,1256 @@ +package swag + +import ( + "encoding/json" + "fmt" + "go/ast" + goparser "go/parser" + "go/token" + "net/http" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "github.com/go-openapi/spec" + "golang.org/x/tools/go/loader" +) + +// RouteProperties describes HTTP properties of a single router comment. +type RouteProperties struct { + HTTPMethod string + Path string + Deprecated bool +} + +// Operation describes a single API operation on a path. +// For more information: https://git.ipao.vip/rogeecn/atomctl/pkg/swag#api-operation +type Operation struct { + parser *Parser + codeExampleFilesDir string + spec.Operation + RouterProperties []RouteProperties + State string +} + +var mimeTypeAliases = map[string]string{ + "json": "application/json", + "xml": "text/xml", + "plain": "text/plain", + "html": "text/html", + "mpfd": "multipart/form-data", + "x-www-form-urlencoded": "application/x-www-form-urlencoded", + "json-api": "application/vnd.api+json", + "json-stream": "application/x-json-stream", + "octet-stream": "application/octet-stream", + "png": "image/png", + "jpeg": "image/jpeg", + "gif": "image/gif", +} + +var ( + mimeTypePattern = regexp.MustCompile("^[^/]+/[^/]+$") + securityPairSepPattern = regexp.MustCompile(`\|\||&&`) // || for compatibility with old version, && for clarity +) + +// NewOperation creates a new Operation with default properties. +// map[int]Response. +func NewOperation(parser *Parser, options ...func(*Operation)) *Operation { + if parser == nil { + parser = New() + } + + result := &Operation{ + parser: parser, + RouterProperties: []RouteProperties{}, + Operation: spec.Operation{ + OperationProps: spec.OperationProps{ + ID: "", + Description: "", + Summary: "", + Security: nil, + ExternalDocs: nil, + Deprecated: false, + Tags: []string{}, + Consumes: []string{}, + Produces: []string{}, + Schemes: []string{}, + Parameters: []spec.Parameter{}, + Responses: &spec.Responses{ + VendorExtensible: spec.VendorExtensible{ + Extensions: spec.Extensions{}, + }, + ResponsesProps: spec.ResponsesProps{ + Default: nil, + StatusCodeResponses: make(map[int]spec.Response), + }, + }, + }, + VendorExtensible: spec.VendorExtensible{ + Extensions: spec.Extensions{}, + }, + }, + codeExampleFilesDir: "", + } + + for _, option := range options { + option(result) + } + + return result +} + +// SetCodeExampleFilesDirectory sets the directory to search for codeExamples. +func SetCodeExampleFilesDirectory(directoryPath string) func(*Operation) { + return func(o *Operation) { + o.codeExampleFilesDir = directoryPath + } +} + +// ParseComment parses comment for given comment string and returns error if error occurs. +func (operation *Operation) ParseComment(comment string, astFile *ast.File) error { + commentLine := strings.TrimSpace(strings.TrimLeft(comment, "/")) + if len(commentLine) == 0 { + return nil + } + + fields := FieldsByAnySpace(commentLine, 2) + attribute := fields[0] + lowerAttribute := strings.ToLower(attribute) + var lineRemainder string + if len(fields) > 1 { + lineRemainder = fields[1] + } + switch lowerAttribute { + case stateAttr: + operation.ParseStateComment(lineRemainder) + case descriptionAttr: + operation.ParseDescriptionComment(lineRemainder) + case descriptionMarkdownAttr: + commentInfo, err := getMarkdownForTag(lineRemainder, operation.parser.markdownFileDir) + if err != nil { + return err + } + + operation.ParseDescriptionComment(string(commentInfo)) + case summaryAttr: + operation.Summary = lineRemainder + case idAttr: + operation.ID = lineRemainder + case tagsAttr: + operation.ParseTagsComment(lineRemainder) + case acceptAttr: + return operation.ParseAcceptComment(lineRemainder) + case produceAttr: + return operation.ParseProduceComment(lineRemainder) + case paramAttr: + return operation.ParseParamComment(lineRemainder, astFile) + case successAttr, failureAttr, responseAttr: + return operation.ParseResponseComment(lineRemainder, astFile) + case headerAttr: + return operation.ParseResponseHeaderComment(lineRemainder, astFile) + case routerAttr: + return operation.ParseRouterComment(lineRemainder, false) + case deprecatedRouterAttr: + return operation.ParseRouterComment(lineRemainder, true) + case securityAttr: + return operation.ParseSecurityComment(lineRemainder) + case deprecatedAttr: + operation.Deprecate() + case xCodeSamplesAttr: + return operation.ParseCodeSample(attribute, commentLine, lineRemainder) + default: + return operation.ParseMetadata(attribute, lowerAttribute, lineRemainder) + } + + return nil +} + +// ParseCodeSample parse code sample. +func (operation *Operation) ParseCodeSample(attribute, _, lineRemainder string) error { + if lineRemainder == "file" { + data, err := getCodeExampleForSummary(operation.Summary, operation.codeExampleFilesDir) + if err != nil { + return err + } + + var valueJSON interface{} + + err = json.Unmarshal(data, &valueJSON) + if err != nil { + return fmt.Errorf("annotation %s need a valid json value", attribute) + } + + // don't use the method provided by spec lib, because it will call toLower() on attribute names, which is wrongly + operation.Extensions[attribute[1:]] = valueJSON + + return nil + } + + // Fallback into existing logic + return operation.ParseMetadata(attribute, strings.ToLower(attribute), lineRemainder) +} + +// ParseStateComment parse state comment. +func (operation *Operation) ParseStateComment(lineRemainder string) { + operation.State = lineRemainder +} + +// ParseDescriptionComment parse description comment. +func (operation *Operation) ParseDescriptionComment(lineRemainder string) { + if operation.Description == "" { + operation.Description = lineRemainder + + return + } + + operation.Description += "\n" + lineRemainder +} + +// ParseMetadata parse metadata. +func (operation *Operation) ParseMetadata(attribute, lowerAttribute, lineRemainder string) error { + // parsing specific meta data extensions + if strings.HasPrefix(lowerAttribute, "@x-") { + if len(lineRemainder) == 0 { + return fmt.Errorf("annotation %s need a value", attribute) + } + + var valueJSON interface{} + + err := json.Unmarshal([]byte(lineRemainder), &valueJSON) + if err != nil { + return fmt.Errorf("annotation %s need a valid json value", attribute) + } + + // don't use the method provided by spec lib, because it will call toLower() on attribute names, which is wrongly + operation.Extensions[attribute[1:]] = valueJSON + } + + return nil +} + +var paramPattern = regexp.MustCompile(`(\S+)\s+(\w+)\s+([\S. ]+?)\s+(\w+)\s+"([^"]+)"`) + +func findInSlice(arr []string, target string) bool { + for _, str := range arr { + if str == target { + return true + } + } + + return false +} + +// ParseParamComment parses params return []string of param properties +// E.g. @Param queryText formData string true "The email for login" +// +// [param name] [paramType] [data type] [is mandatory?] [Comment] +// +// E.g. @Param some_id path int true "Some ID". +func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.File) error { + matches := paramPattern.FindStringSubmatch(commentLine) + if len(matches) != 6 { + return fmt.Errorf("missing required param comment parameters \"%s\"", commentLine) + } + + name := matches[1] + paramType := matches[2] + refType, format := TransToValidSchemeTypeWithFormat(matches[3]) + + // Detect refType + objectType := OBJECT + + if strings.HasPrefix(refType, "[]") { + objectType = ARRAY + refType = strings.TrimPrefix(refType, "[]") + refType, format = TransToValidSchemeTypeWithFormat(refType) + } else if IsPrimitiveType(refType) || + paramType == "formData" && refType == "file" { + objectType = PRIMITIVE + } + + var enums []interface{} + if !IsPrimitiveType(refType) { + schema, _ := operation.parser.getTypeSchema(refType, astFile, false) + if schema != nil && len(schema.Type) == 1 && schema.Enum != nil { + if objectType == OBJECT { + objectType = PRIMITIVE + } + refType, format = TransToValidSchemeTypeWithFormat(schema.Type[0]) + enums = schema.Enum + } + } + + requiredText := strings.ToLower(matches[4]) + required := requiredText == "true" || requiredText == requiredLabel + description := strings.Join(strings.Split(matches[5], "\\n"), "\n") + + param := createParameter(paramType, description, name, objectType, refType, format, required, enums, operation.parser.collectionFormatInQuery) + + switch paramType { + case "path", "header", "query", "formData": + switch objectType { + case ARRAY: + if !IsPrimitiveType(refType) && !(refType == "file" && paramType == "formData") { + return fmt.Errorf("%s is not supported array type for %s", refType, paramType) + } + case PRIMITIVE: + break + case OBJECT: + schema, err := operation.parser.getTypeSchema(refType, astFile, false) + if err != nil { + return err + } + + if len(schema.Properties) == 0 { + return nil + } + + items := schema.Properties.ToOrderedSchemaItems() + + for _, item := range items { + name, prop := item.Name, &item.Schema + if len(prop.Type) == 0 { + prop = operation.parser.getUnderlyingSchema(prop) + if len(prop.Type) == 0 { + continue + } + } + + nameOverrideType := paramType + // query also uses formData tags + if paramType == "query" { + nameOverrideType = "formData" + } + // load overridden type specific name from extensions if exists + if nameVal, ok := item.Schema.Extensions[nameOverrideType]; ok { + name = nameVal.(string) + } + + switch { + case prop.Type[0] == ARRAY: + if prop.Items.Schema == nil { + continue + } + itemSchema := prop.Items.Schema + if len(itemSchema.Type) == 0 { + itemSchema = operation.parser.getUnderlyingSchema(prop.Items.Schema) + } + if itemSchema == nil { + continue + } + if len(itemSchema.Type) == 0 { + continue + } + if !IsSimplePrimitiveType(itemSchema.Type[0]) { + continue + } + param = createParameter(paramType, prop.Description, name, prop.Type[0], itemSchema.Type[0], format, findInSlice(schema.Required, item.Name), itemSchema.Enum, operation.parser.collectionFormatInQuery) + + case IsSimplePrimitiveType(prop.Type[0]): + param = createParameter(paramType, prop.Description, name, PRIMITIVE, prop.Type[0], format, findInSlice(schema.Required, item.Name), nil, operation.parser.collectionFormatInQuery) + default: + operation.parser.debug.Printf("skip field [%s] in %s is not supported type for %s", name, refType, paramType) + continue + } + + param.Nullable = prop.Nullable + param.Format = prop.Format + param.Default = prop.Default + param.Example = prop.Example + param.Extensions = prop.Extensions + param.CommonValidations.Maximum = prop.Maximum + param.CommonValidations.Minimum = prop.Minimum + param.CommonValidations.ExclusiveMaximum = prop.ExclusiveMaximum + param.CommonValidations.ExclusiveMinimum = prop.ExclusiveMinimum + param.CommonValidations.MaxLength = prop.MaxLength + param.CommonValidations.MinLength = prop.MinLength + param.CommonValidations.Pattern = prop.Pattern + param.CommonValidations.MaxItems = prop.MaxItems + param.CommonValidations.MinItems = prop.MinItems + param.CommonValidations.UniqueItems = prop.UniqueItems + param.CommonValidations.MultipleOf = prop.MultipleOf + param.CommonValidations.Enum = prop.Enum + operation.Operation.Parameters = append(operation.Operation.Parameters, param) + } + + return nil + } + case "body": + if objectType == PRIMITIVE { + param.Schema = PrimitiveSchema(refType) + } else { + schema, err := operation.parseAPIObjectSchema(commentLine, objectType, refType, astFile) + if err != nil { + return err + } + + param.Schema = schema + } + default: + return fmt.Errorf("not supported paramType: %s", paramType) + } + + err := operation.parseParamAttribute(commentLine, objectType, refType, paramType, ¶m) + if err != nil { + return err + } + + operation.Operation.Parameters = append(operation.Operation.Parameters, param) + + return nil +} + +const ( + formTag = "form" + jsonTag = "json" + uriTag = "uri" + headerTag = "header" + bindingTag = "binding" + defaultTag = "default" + enumsTag = "enums" + exampleTag = "example" + schemaExampleTag = "schemaExample" + formatTag = "format" + titleTag = "title" + validateTag = "validate" + minimumTag = "minimum" + maximumTag = "maximum" + minLengthTag = "minLength" + maxLengthTag = "maxLength" + multipleOfTag = "multipleOf" + readOnlyTag = "readonly" + extensionsTag = "extensions" + collectionFormatTag = "collectionFormat" +) + +var regexAttributes = map[string]*regexp.Regexp{ + // for Enums(A, B) + enumsTag: regexp.MustCompile(`(?i)\s+enums\(.*\)`), + // for maximum(0) + maximumTag: regexp.MustCompile(`(?i)\s+maxinum|maximum\(.*\)`), + // for minimum(0) + minimumTag: regexp.MustCompile(`(?i)\s+mininum|minimum\(.*\)`), + // for default(0) + defaultTag: regexp.MustCompile(`(?i)\s+default\(.*\)`), + // for minlength(0) + minLengthTag: regexp.MustCompile(`(?i)\s+minlength\(.*\)`), + // for maxlength(0) + maxLengthTag: regexp.MustCompile(`(?i)\s+maxlength\(.*\)`), + // for format(email) + formatTag: regexp.MustCompile(`(?i)\s+format\(.*\)`), + // for extensions(x-example=test) + extensionsTag: regexp.MustCompile(`(?i)\s+extensions\(.*\)`), + // for collectionFormat(csv) + collectionFormatTag: regexp.MustCompile(`(?i)\s+collectionFormat\(.*\)`), + // example(0) + exampleTag: regexp.MustCompile(`(?i)\s+example\(.*\)`), + // schemaExample(0) + schemaExampleTag: regexp.MustCompile(`(?i)\s+schemaExample\(.*\)`), +} + +func (operation *Operation) parseParamAttribute(comment, objectType, schemaType, paramType string, param *spec.Parameter) error { + schemaType = TransToValidSchemeType(schemaType) + + for attrKey, re := range regexAttributes { + attr, err := findAttr(re, comment) + if err != nil { + continue + } + + switch attrKey { + case enumsTag: + err = setEnumParam(param, attr, objectType, schemaType, paramType) + case minimumTag, maximumTag: + err = setNumberParam(param, attrKey, schemaType, attr, comment) + case defaultTag: + err = setDefault(param, schemaType, attr) + case minLengthTag, maxLengthTag: + err = setStringParam(param, attrKey, schemaType, attr, comment) + case formatTag: + param.Format = attr + case exampleTag: + err = setExample(param, schemaType, attr) + case schemaExampleTag: + err = setSchemaExample(param, schemaType, attr) + case extensionsTag: + param.Extensions = setExtensionParam(attr) + case collectionFormatTag: + err = setCollectionFormatParam(param, attrKey, objectType, attr, comment) + } + + if err != nil { + return err + } + } + + return nil +} + +func findAttr(re *regexp.Regexp, commentLine string) (string, error) { + attr := re.FindString(commentLine) + + l, r := strings.Index(attr, "("), strings.Index(attr, ")") + if l == -1 || r == -1 { + return "", fmt.Errorf("can not find regex=%s, comment=%s", re.String(), commentLine) + } + + return strings.TrimSpace(attr[l+1 : r]), nil +} + +func setStringParam(param *spec.Parameter, name, schemaType, attr, commentLine string) error { + if schemaType != STRING { + return fmt.Errorf("%s is attribute to set to a number. comment=%s got=%s", name, commentLine, schemaType) + } + + n, err := strconv.ParseInt(attr, 10, 64) + if err != nil { + return fmt.Errorf("%s is allow only a number got=%s", name, attr) + } + + switch name { + case minLengthTag: + param.MinLength = &n + case maxLengthTag: + param.MaxLength = &n + } + + return nil +} + +func setNumberParam(param *spec.Parameter, name, schemaType, attr, commentLine string) error { + switch schemaType { + case INTEGER, NUMBER: + n, err := strconv.ParseFloat(attr, 64) + if err != nil { + return fmt.Errorf("maximum is allow only a number. comment=%s got=%s", commentLine, attr) + } + + switch name { + case minimumTag: + param.Minimum = &n + case maximumTag: + param.Maximum = &n + } + + return nil + default: + return fmt.Errorf("%s is attribute to set to a number. comment=%s got=%s", name, commentLine, schemaType) + } +} + +func setEnumParam(param *spec.Parameter, attr, objectType, schemaType, paramType string) error { + for _, e := range strings.Split(attr, ",") { + e = strings.TrimSpace(e) + + value, err := defineType(schemaType, e) + if err != nil { + return err + } + + switch objectType { + case ARRAY: + param.Items.Enum = append(param.Items.Enum, value) + default: + switch paramType { + case "body": + param.Schema.Enum = append(param.Schema.Enum, value) + default: + param.Enum = append(param.Enum, value) + } + } + } + + return nil +} + +func setExtensionParam(attr string) spec.Extensions { + extensions := spec.Extensions{} + + for _, val := range splitNotWrapped(attr, ',') { + parts := strings.SplitN(val, "=", 2) + if len(parts) == 2 { + extensions.Add(parts[0], parts[1]) + + continue + } + + if len(parts[0]) > 0 && string(parts[0][0]) == "!" { + extensions.Add(parts[0][1:], false) + + continue + } + + extensions.Add(parts[0], true) + } + + return extensions +} + +func setCollectionFormatParam(param *spec.Parameter, name, schemaType, attr, commentLine string) error { + if schemaType == ARRAY { + param.CollectionFormat = TransToValidCollectionFormat(attr) + + return nil + } + + return fmt.Errorf("%s is attribute to set to an array. comment=%s got=%s", name, commentLine, schemaType) +} + +func setDefault(param *spec.Parameter, schemaType string, value string) error { + val, err := defineType(schemaType, value) + if err != nil { + return nil // Don't set a default value if it's not valid + } + + param.Default = val + + return nil +} + +func setSchemaExample(param *spec.Parameter, schemaType string, value string) error { + val, err := defineType(schemaType, value) + if err != nil { + return nil // Don't set a example value if it's not valid + } + // skip schema + if param.Schema == nil { + return nil + } + + switch v := val.(type) { + case string: + // replaces \r \n \t in example string values. + param.Schema.Example = strings.NewReplacer(`\r`, "\r", `\n`, "\n", `\t`, "\t").Replace(v) + default: + param.Schema.Example = val + } + + return nil +} + +func setExample(param *spec.Parameter, schemaType string, value string) error { + val, err := defineType(schemaType, value) + if err != nil { + return nil // Don't set a example value if it's not valid + } + + param.Example = val + + return nil +} + +// defineType enum value define the type (object and array unsupported). +func defineType(schemaType string, value string) (v interface{}, err error) { + schemaType = TransToValidSchemeType(schemaType) + + switch schemaType { + case STRING: + return value, nil + case NUMBER: + v, err = strconv.ParseFloat(value, 64) + if err != nil { + return nil, fmt.Errorf("enum value %s can't convert to %s err: %s", value, schemaType, err) + } + case INTEGER: + v, err = strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("enum value %s can't convert to %s err: %s", value, schemaType, err) + } + case BOOLEAN: + v, err = strconv.ParseBool(value) + if err != nil { + return nil, fmt.Errorf("enum value %s can't convert to %s err: %s", value, schemaType, err) + } + default: + return nil, fmt.Errorf("%s is unsupported type in enum value %s", schemaType, value) + } + + return v, nil +} + +// ParseTagsComment parses comment for given `tag` comment string. +func (operation *Operation) ParseTagsComment(commentLine string) { + for _, tag := range strings.Split(commentLine, ",") { + operation.Tags = append(operation.Tags, strings.TrimSpace(tag)) + } +} + +// ParseAcceptComment parses comment for given `accept` comment string. +func (operation *Operation) ParseAcceptComment(commentLine string) error { + return parseMimeTypeList(commentLine, &operation.Consumes, "%v accept type can't be accepted") +} + +// ParseProduceComment parses comment for given `produce` comment string. +func (operation *Operation) ParseProduceComment(commentLine string) error { + return parseMimeTypeList(commentLine, &operation.Produces, "%v produce type can't be accepted") +} + +// parseMimeTypeList parses a list of MIME Types for a comment like +// `produce` (`Content-Type:` response header) or +// `accept` (`Accept:` request header). +func parseMimeTypeList(mimeTypeList string, typeList *[]string, format string) error { + for _, typeName := range strings.Split(mimeTypeList, ",") { + if mimeTypePattern.MatchString(typeName) { + *typeList = append(*typeList, typeName) + + continue + } + + aliasMimeType, ok := mimeTypeAliases[typeName] + if !ok { + return fmt.Errorf(format, typeName) + } + + *typeList = append(*typeList, aliasMimeType) + } + + return nil +} + +var ( + routerPattern = regexp.MustCompile(`^(/[\w./\-{}\(\)+:$]*)[[:blank:]]+\[(\w+)]`) + routerParamTypePattern = regexp.MustCompile(`:(\w+)(<.*?>)?`) +) + +// ParseRouterComment parses comment for given `router` comment string. +func (operation *Operation) ParseRouterComment(commentLine string, deprecated bool) error { + commentLine = routerParamTypePattern.ReplaceAllString(commentLine, "{$1}") + + matches := routerPattern.FindStringSubmatch(commentLine) + if len(matches) != 3 { + return fmt.Errorf("can not parse router comment \"%s\"", commentLine) + } + + signature := RouteProperties{ + Path: matches[1], + HTTPMethod: strings.ToUpper(matches[2]), + Deprecated: deprecated, + } + + if _, ok := allMethod[signature.HTTPMethod]; !ok { + return fmt.Errorf("invalid method: %s", signature.HTTPMethod) + } + + operation.RouterProperties = append(operation.RouterProperties, signature) + + return nil +} + +// ParseSecurityComment parses comment for given `security` comment string. +func (operation *Operation) ParseSecurityComment(commentLine string) error { + if len(commentLine) == 0 { + operation.Security = []map[string][]string{} + return nil + } + + var ( + securityMap = make(map[string][]string) + securitySource = commentLine[strings.Index(commentLine, "@Security")+1:] + ) + + for _, securityOption := range securityPairSepPattern.Split(securitySource, -1) { + securityOption = strings.TrimSpace(securityOption) + + left, right := strings.Index(securityOption, "["), strings.Index(securityOption, "]") + + if !(left == -1 && right == -1) { + scopes := securityOption[left+1 : right] + + var options []string + + for _, scope := range strings.Split(scopes, ",") { + options = append(options, strings.TrimSpace(scope)) + } + + securityKey := securityOption[0:left] + securityMap[securityKey] = append(securityMap[securityKey], options...) + } else { + securityKey := strings.TrimSpace(securityOption) + securityMap[securityKey] = []string{} + } + } + + operation.Security = append(operation.Security, securityMap) + + return nil +} + +// findTypeDef attempts to find the *ast.TypeSpec for a specific type given the +// type's name and the package's import path. +// TODO: improve finding external pkg. +func findTypeDef(importPath, typeName string) (*ast.TypeSpec, error) { + cwd, err := os.Getwd() + if err != nil { + return nil, err + } + + conf := loader.Config{ + ParserMode: goparser.SpuriousErrors, + Cwd: cwd, + } + + conf.Import(importPath) + + lprog, err := conf.Load() + if err != nil { + return nil, err + } + + // If the pkg is vendored, the actual pkg path is going to resemble + // something like "{importPath}/vendor/{importPath}" + for k := range lprog.AllPackages { + realPkgPath := k.Path() + + if strings.Contains(realPkgPath, "vendor/"+importPath) { + importPath = realPkgPath + } + } + + pkgInfo := lprog.Package(importPath) + + if pkgInfo == nil { + return nil, fmt.Errorf("package was nil") + } + + // TODO: possibly cache pkgInfo since it's an expensive operation + for i := range pkgInfo.Files { + for _, astDeclaration := range pkgInfo.Files[i].Decls { + generalDeclaration, ok := astDeclaration.(*ast.GenDecl) + if ok && generalDeclaration.Tok == token.TYPE { + for _, astSpec := range generalDeclaration.Specs { + typeSpec, ok := astSpec.(*ast.TypeSpec) + if ok { + if typeSpec.Name.String() == typeName { + return typeSpec, nil + } + } + } + } + } + } + + return nil, fmt.Errorf("type spec not found") +} + +var responsePattern = regexp.MustCompile(`^([\w,]+)\s+([\w{}]+)\s+([\w\-.\\{}=,\[\s\]]+)\s*(".*)?`) + +// ResponseType{data1=Type1,data2=Type2}. +var combinedPattern = regexp.MustCompile(`^([\w\-./\[\]]+){(.*)}$`) + +func (operation *Operation) parseObjectSchema(refType string, astFile *ast.File) (*spec.Schema, error) { + return parseObjectSchema(operation.parser, refType, astFile) +} + +func parseObjectSchema(parser *Parser, refType string, astFile *ast.File) (*spec.Schema, error) { + switch { + case refType == NIL: + return nil, nil + case refType == INTERFACE: + return &spec.Schema{}, nil + case refType == ANY: + return &spec.Schema{}, nil + case IsGolangPrimitiveType(refType): + return TransToValidPrimitiveSchema(refType), nil + case IsPrimitiveType(refType): + return PrimitiveSchema(refType), nil + case strings.HasPrefix(refType, "[]"): + schema, err := parseObjectSchema(parser, refType[2:], astFile) + if err != nil { + return nil, err + } + + return spec.ArrayProperty(schema), nil + case strings.HasPrefix(refType, "map["): + // ignore key type + idx := strings.Index(refType, "]") + if idx < 0 { + return nil, fmt.Errorf("invalid type: %s", refType) + } + + refType = refType[idx+1:] + if refType == INTERFACE || refType == ANY { + return spec.MapProperty(nil), nil + } + + schema, err := parseObjectSchema(parser, refType, astFile) + if err != nil { + return nil, err + } + + return spec.MapProperty(schema), nil + case strings.Contains(refType, "{"): + return parseCombinedObjectSchema(parser, refType, astFile) + default: + if parser != nil { // checking refType has existing in 'TypeDefinitions' + schema, err := parser.getTypeSchema(refType, astFile, true) + if err != nil { + return nil, err + } + + return schema, nil + } + + return RefSchema(refType), nil + } +} + +func parseFields(s string) []string { + nestLevel := 0 + + return strings.FieldsFunc(s, func(char rune) bool { + if char == '{' { + nestLevel++ + + return false + } else if char == '}' { + nestLevel-- + + return false + } + + return char == ',' && nestLevel == 0 + }) +} + +func parseCombinedObjectSchema(parser *Parser, refType string, astFile *ast.File) (*spec.Schema, error) { + matches := combinedPattern.FindStringSubmatch(refType) + if len(matches) != 3 { + return nil, fmt.Errorf("invalid type: %s", refType) + } + + schema, err := parseObjectSchema(parser, matches[1], astFile) + if err != nil { + return nil, err + } + + fields, props := parseFields(matches[2]), map[string]spec.Schema{} + + for _, field := range fields { + keyVal := strings.SplitN(field, "=", 2) + if len(keyVal) == 2 { + schema, err := parseObjectSchema(parser, keyVal[1], astFile) + if err != nil { + return nil, err + } + + if schema == nil { + schema = PrimitiveSchema(OBJECT) + } + + props[keyVal[0]] = *schema + } + } + + if len(props) == 0 { + return schema, nil + } + + if schema.Ref.GetURL() == nil && len(schema.Type) > 0 && schema.Type[0] == OBJECT && len(schema.Properties) == 0 && schema.AdditionalProperties == nil { + schema.Properties = props + return schema, nil + } + + return spec.ComposedSchema(*schema, spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{OBJECT}, + Properties: props, + }, + }), nil +} + +func (operation *Operation) parseAPIObjectSchema(commentLine, schemaType, refType string, astFile *ast.File) (*spec.Schema, error) { + if strings.HasSuffix(refType, ",") && strings.Contains(refType, "[") { + // regexp may have broken generic syntax. find closing bracket and add it back + allMatchesLenOffset := strings.Index(commentLine, refType) + len(refType) + lostPartEndIdx := strings.Index(commentLine[allMatchesLenOffset:], "]") + if lostPartEndIdx >= 0 { + refType += commentLine[allMatchesLenOffset : allMatchesLenOffset+lostPartEndIdx+1] + } + } + + switch schemaType { + case OBJECT: + if !strings.HasPrefix(refType, "[]") { + return operation.parseObjectSchema(refType, astFile) + } + + refType = refType[2:] + + fallthrough + case ARRAY: + schema, err := operation.parseObjectSchema(refType, astFile) + if err != nil { + return nil, err + } + + return spec.ArrayProperty(schema), nil + default: + return PrimitiveSchema(schemaType), nil + } +} + +// ParseResponseComment parses comment for given `response` comment string. +func (operation *Operation) ParseResponseComment(commentLine string, astFile *ast.File) error { + matches := responsePattern.FindStringSubmatch(commentLine) + if len(matches) != 5 { + err := operation.ParseEmptyResponseComment(commentLine) + if err != nil { + return operation.ParseEmptyResponseOnly(commentLine) + } + + return err + } + + description := strings.Trim(matches[4], "\"") + + schema, err := operation.parseAPIObjectSchema(commentLine, strings.Trim(matches[2], "{}"), strings.TrimSpace(matches[3]), astFile) + if err != nil { + return err + } + + for _, codeStr := range strings.Split(matches[1], ",") { + if strings.EqualFold(codeStr, defaultTag) { + operation.DefaultResponse().WithSchema(schema).WithDescription(description) + + continue + } + + code, err := strconv.Atoi(codeStr) + if err != nil { + return fmt.Errorf("can not parse response comment \"%s\"", commentLine) + } + + resp := spec.NewResponse().WithSchema(schema).WithDescription(description) + if description == "" { + resp.WithDescription(http.StatusText(code)) + } + + operation.AddResponse(code, resp) + } + + return nil +} + +func newHeaderSpec(schemaType, description string) spec.Header { + return spec.Header{ + SimpleSchema: spec.SimpleSchema{ + Type: schemaType, + }, + HeaderProps: spec.HeaderProps{ + Description: description, + }, + VendorExtensible: spec.VendorExtensible{ + Extensions: nil, + }, + CommonValidations: spec.CommonValidations{ + Maximum: nil, + ExclusiveMaximum: false, + Minimum: nil, + ExclusiveMinimum: false, + MaxLength: nil, + MinLength: nil, + Pattern: "", + MaxItems: nil, + MinItems: nil, + UniqueItems: false, + MultipleOf: nil, + Enum: nil, + }, + } +} + +// ParseResponseHeaderComment parses comment for given `response header` comment string. +func (operation *Operation) ParseResponseHeaderComment(commentLine string, _ *ast.File) error { + matches := responsePattern.FindStringSubmatch(commentLine) + if len(matches) != 5 { + return fmt.Errorf("can not parse response comment \"%s\"", commentLine) + } + + header := newHeaderSpec(strings.Trim(matches[2], "{}"), strings.Trim(matches[4], "\"")) + + headerKey := strings.TrimSpace(matches[3]) + + if strings.EqualFold(matches[1], "all") { + if operation.Responses.Default != nil { + operation.Responses.Default.Headers[headerKey] = header + } + + if operation.Responses.StatusCodeResponses != nil { + for code, response := range operation.Responses.StatusCodeResponses { + response.Headers[headerKey] = header + operation.Responses.StatusCodeResponses[code] = response + } + } + + return nil + } + + for _, codeStr := range strings.Split(matches[1], ",") { + if strings.EqualFold(codeStr, defaultTag) { + if operation.Responses.Default != nil { + operation.Responses.Default.Headers[headerKey] = header + } + + continue + } + + code, err := strconv.Atoi(codeStr) + if err != nil { + return fmt.Errorf("can not parse response comment \"%s\"", commentLine) + } + + if operation.Responses.StatusCodeResponses != nil { + response, responseExist := operation.Responses.StatusCodeResponses[code] + if responseExist { + response.Headers[headerKey] = header + + operation.Responses.StatusCodeResponses[code] = response + } + } + } + + return nil +} + +var emptyResponsePattern = regexp.MustCompile(`([\w,]+)\s+"(.*)"`) + +// ParseEmptyResponseComment parse only comment out status code and description,eg: @Success 200 "it's ok". +func (operation *Operation) ParseEmptyResponseComment(commentLine string) error { + matches := emptyResponsePattern.FindStringSubmatch(commentLine) + if len(matches) != 3 { + return fmt.Errorf("can not parse response comment \"%s\"", commentLine) + } + + description := strings.Trim(matches[2], "\"") + + for _, codeStr := range strings.Split(matches[1], ",") { + if strings.EqualFold(codeStr, defaultTag) { + operation.DefaultResponse().WithDescription(description) + + continue + } + + code, err := strconv.Atoi(codeStr) + if err != nil { + return fmt.Errorf("can not parse response comment \"%s\"", commentLine) + } + + operation.AddResponse(code, spec.NewResponse().WithDescription(description)) + } + + return nil +} + +// ParseEmptyResponseOnly parse only comment out status code ,eg: @Success 200. +func (operation *Operation) ParseEmptyResponseOnly(commentLine string) error { + for _, codeStr := range strings.Split(commentLine, ",") { + if strings.EqualFold(codeStr, defaultTag) { + _ = operation.DefaultResponse() + + continue + } + + code, err := strconv.Atoi(codeStr) + if err != nil { + return fmt.Errorf("can not parse response comment \"%s\"", commentLine) + } + + operation.AddResponse(code, spec.NewResponse().WithDescription(http.StatusText(code))) + } + + return nil +} + +// DefaultResponse return the default response member pointer. +func (operation *Operation) DefaultResponse() *spec.Response { + if operation.Responses.Default == nil { + operation.Responses.Default = &spec.Response{ + ResponseProps: spec.ResponseProps{ + Description: "", + Headers: make(map[string]spec.Header), + }, + } + } + + return operation.Responses.Default +} + +// AddResponse add a response for a code. +func (operation *Operation) AddResponse(code int, response *spec.Response) { + if response.Headers == nil { + response.Headers = make(map[string]spec.Header) + } + + operation.Responses.StatusCodeResponses[code] = *response +} + +// createParameter returns swagger spec.Parameter for given paramType, description, paramName, schemaType, required. +func createParameter(paramType, description, paramName, objectType, schemaType string, format string, required bool, enums []interface{}, collectionFormat string) spec.Parameter { + // //five possible parameter types. query, path, body, header, form + result := spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: paramName, + Description: description, + Required: required, + In: paramType, + }, + } + + if paramType == "body" { + return result + } + + switch objectType { + case ARRAY: + result.Type = objectType + result.CollectionFormat = collectionFormat + result.Items = &spec.Items{ + CommonValidations: spec.CommonValidations{ + Enum: enums, + }, + SimpleSchema: spec.SimpleSchema{ + Type: schemaType, + Format: format, + }, + } + case PRIMITIVE, OBJECT: + result.Type = schemaType + result.Enum = enums + result.Format = format + } + return result +} + +func getCodeExampleForSummary(summaryName string, dirPath string) ([]byte, error) { + dirEntries, err := os.ReadDir(dirPath) + if err != nil { + return nil, err + } + + for _, entry := range dirEntries { + if entry.IsDir() { + continue + } + + fileName := entry.Name() + + if !strings.Contains(fileName, ".json") { + continue + } + + if strings.Contains(fileName, summaryName) { + fullPath := filepath.Join(dirPath, fileName) + + commentInfo, err := os.ReadFile(fullPath) + if err != nil { + return nil, fmt.Errorf("Failed to read code example file %s error: %s ", fullPath, err) + } + + return commentInfo, nil + } + } + + return nil, fmt.Errorf("unable to find code example file for tag %s in the given directory", summaryName) +} diff --git a/pkg/swag/operation_test.go b/pkg/swag/operation_test.go new file mode 100644 index 0000000..e3b5828 --- /dev/null +++ b/pkg/swag/operation_test.go @@ -0,0 +1,2623 @@ +package swag + +import ( + "encoding/json" + "fmt" + "go/ast" + goparser "go/parser" + "go/token" + "os" + "path/filepath" + "testing" + + "github.com/go-openapi/spec" + "github.com/stretchr/testify/assert" +) + +func TestParseEmptyComment(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + err := operation.ParseComment("//", nil) + + assert.NoError(t, err) +} + +func TestParseTagsComment(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + err := operation.ParseComment(`/@Tags pet, store,user`, nil) + assert.NoError(t, err) + assert.Equal(t, operation.Tags, []string{"pet", "store", "user"}) +} + +func TestParseAcceptComment(t *testing.T) { + t.Parallel() + + comment := `/@Accept json,xml,plain,html,mpfd,x-www-form-urlencoded,json-api,json-stream,octet-stream,png,jpeg,gif,application/xhtml+xml,application/health+json` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Equal(t, + operation.Consumes, + []string{ + "application/json", + "text/xml", + "text/plain", + "text/html", + "multipart/form-data", + "application/x-www-form-urlencoded", + "application/vnd.api+json", + "application/x-json-stream", + "application/octet-stream", + "image/png", + "image/jpeg", + "image/gif", + "application/xhtml+xml", + "application/health+json", + }) +} + +func TestParseAcceptCommentErr(t *testing.T) { + t.Parallel() + + comment := `/@Accept unknown` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestParseProduceComment(t *testing.T) { + t.Parallel() + + expected := `{ + "produces": [ + "application/json", + "text/xml", + "text/plain", + "text/html", + "multipart/form-data", + "application/x-www-form-urlencoded", + "application/vnd.api+json", + "application/x-json-stream", + "application/octet-stream", + "image/png", + "image/jpeg", + "image/gif", + "application/health+json" + ] +}` + comment := `/@Produce json,xml,plain,html,mpfd,x-www-form-urlencoded,json-api,json-stream,octet-stream,png,jpeg,gif,application/health+json` + operation := new(Operation) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + b, _ := json.MarshalIndent(operation, "", " ") + assert.JSONEq(t, expected, string(b)) +} + +func TestParseProduceCommentErr(t *testing.T) { + t.Parallel() + + operation := new(Operation) + err := operation.ParseComment("/@Produce foo", nil) + assert.Error(t, err) +} + +func TestParseRouterComment(t *testing.T) { + t.Parallel() + + comment := `/@Router /customer/get-wishlist/{wishlist_id} [get]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.RouterProperties, 1) + assert.Equal(t, "/customer/get-wishlist/{wishlist_id}", operation.RouterProperties[0].Path) + assert.Equal(t, "GET", operation.RouterProperties[0].HTTPMethod) + + comment = `/@Router /customer/get-wishlist/{wishlist_id} [unknown]` + operation = NewOperation(nil) + err = operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestParseRouterMultipleComments(t *testing.T) { + t.Parallel() + + comment := `/@Router /customer/get-wishlist/{wishlist_id} [get]` + anotherComment := `/@Router /customer/get-the-wishlist/{wishlist_id} [post]` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + err = operation.ParseComment(anotherComment, nil) + assert.NoError(t, err) + + assert.Len(t, operation.RouterProperties, 2) + assert.Equal(t, "/customer/get-wishlist/{wishlist_id}", operation.RouterProperties[0].Path) + assert.Equal(t, "GET", operation.RouterProperties[0].HTTPMethod) + assert.Equal(t, "/customer/get-the-wishlist/{wishlist_id}", operation.RouterProperties[1].Path) + assert.Equal(t, "POST", operation.RouterProperties[1].HTTPMethod) +} + +func TestParseRouterOnlySlash(t *testing.T) { + t.Parallel() + + comment := `// @Router / [get]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.RouterProperties, 1) + assert.Equal(t, "/", operation.RouterProperties[0].Path) + assert.Equal(t, "GET", operation.RouterProperties[0].HTTPMethod) +} + +func TestParseRouterCommentWithPlusSign(t *testing.T) { + t.Parallel() + + comment := `/@Router /customer/get-wishlist/{proxy+} [post]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.RouterProperties, 1) + assert.Equal(t, "/customer/get-wishlist/{proxy+}", operation.RouterProperties[0].Path) + assert.Equal(t, "POST", operation.RouterProperties[0].HTTPMethod) +} + +func TestParseRouterCommentWithDollarSign(t *testing.T) { + t.Parallel() + + comment := `/@Router /customer/get-wishlist/{wishlist_id}$move [post]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.RouterProperties, 1) + assert.Equal(t, "/customer/get-wishlist/{wishlist_id}$move", operation.RouterProperties[0].Path) + assert.Equal(t, "POST", operation.RouterProperties[0].HTTPMethod) +} + +func TestParseRouterCommentWithParens(t *testing.T) { + t.Parallel() + + comment := `/@Router /customer({id}) [get]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.RouterProperties, 1) + assert.Equal(t, "/customer({id})", operation.RouterProperties[0].Path) + assert.Equal(t, "GET", operation.RouterProperties[0].HTTPMethod) +} + +func TestParseRouterCommentNoDollarSignAtPathStartErr(t *testing.T) { + t.Parallel() + + comment := `/@Router $customer/get-wishlist/{wishlist_id}$move [post]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestParseRouterCommentWithColonSign(t *testing.T) { + t.Parallel() + + comment := `/@Router /customer/get-wishlist/{wishlist_id}:move [post]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.RouterProperties, 1) + assert.Equal(t, "/customer/get-wishlist/{wishlist_id}:move", operation.RouterProperties[0].Path) + assert.Equal(t, "POST", operation.RouterProperties[0].HTTPMethod) +} + +func TestParseRouterCommentNoColonSignAtPathStartErr(t *testing.T) { + t.Parallel() + + comment := `/@Router :customer/get-wishlist/{wishlist_id}:move [post]` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestParseRouterCommentMethodSeparationErr(t *testing.T) { + t.Parallel() + + comment := `/@Router /api/{id}|,*[get` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestParseRouterCommentMethodMissingErr(t *testing.T) { + t.Parallel() + + comment := `/@Router /customer/get-wishlist/{wishlist_id}` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestOperation_ParseResponseWithDefault(t *testing.T) { + t.Parallel() + + comment := `@Success default {object} nil "An empty response"` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + assert.Equal(t, "An empty response", operation.Responses.Default.Description) + + comment = `@Success 200,default {string} Response "A response"` + operation = NewOperation(nil) + + err = operation.ParseComment(comment, nil) + assert.NoError(t, err) + + assert.Equal(t, "A response", operation.Responses.Default.Description) + assert.Equal(t, "A response", operation.Responses.StatusCodeResponses[200].Description) +} + +func TestParseResponseSuccessCommentWithEmptyResponse(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} nil "An empty response"` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `An empty response`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + expected := `{ + "responses": { + "200": { + "description": "An empty response" + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseFailureCommentWithEmptyResponse(t *testing.T) { + t.Parallel() + + comment := `@Failure 500 {object} nil` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(operation, "", " ") + expected := `{ + "responses": { + "500": { + "description": "Internal Server Error" + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithObjectType(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.OrderRow "Error message, if code != 200` + operation := NewOperation(nil) + operation.parser.addTestType("model.OrderRow") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "$ref": "#/definitions/model.OrderRow" + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithNestedPrimitiveType(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.CommonHeader{data=string,data2=int} "Error message, if code != 200` + operation := NewOperation(nil) + + operation.parser.addTestType("model.CommonHeader") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data": { + "type": "string" + }, + "data2": { + "type": "integer" + } + } + } + ] + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithNestedPrimitiveArrayType(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.CommonHeader{data=[]string,data2=[]int} "Error message, if code != 200` + operation := NewOperation(nil) + + operation.parser.addTestType("model.CommonHeader") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "type": "string" + } + }, + "data2": { + "type": "array", + "items": { + "type": "integer" + } + } + } + } + ] + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithNestedObjectType(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.CommonHeader{data=model.Payload,data2=model.Payload2} "Error message, if code != 200` + operation := NewOperation(nil) + operation.parser.addTestType("model.CommonHeader") + operation.parser.addTestType("model.Payload") + operation.parser.addTestType("model.Payload2") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data": { + "$ref": "#/definitions/model.Payload" + }, + "data2": { + "$ref": "#/definitions/model.Payload2" + } + } + } + ] + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithNestedArrayObjectType(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.CommonHeader{data=[]model.Payload,data2=[]model.Payload2} "Error message, if code != 200` + operation := NewOperation(nil) + + operation.parser.addTestType("model.CommonHeader") + operation.parser.addTestType("model.Payload") + operation.parser.addTestType("model.Payload2") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/model.Payload" + } + }, + "data2": { + "type": "array", + "items": { + "$ref": "#/definitions/model.Payload2" + } + } + } + } + ] + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithNestedFields(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.CommonHeader{data1=int,data2=[]int,data3=model.Payload,data4=[]model.Payload} "Error message, if code != 200` + operation := NewOperation(nil) + + operation.parser.addTestType("model.CommonHeader") + operation.parser.addTestType("model.Payload") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data1": { + "type": "integer" + }, + "data2": { + "type": "array", + "items": { + "type": "integer" + } + }, + "data3": { + "$ref": "#/definitions/model.Payload" + }, + "data4": { + "type": "array", + "items": { + "$ref": "#/definitions/model.Payload" + } + } + } + } + ] + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithDeepNestedFields(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.CommonHeader{data1=int,data2=[]int,data3=model.Payload{data1=int,data2=model.DeepPayload},data4=[]model.Payload{data1=[]int,data2=[]model.DeepPayload}} "Error message, if code != 200` + operation := NewOperation(nil) + + operation.parser.addTestType("model.CommonHeader") + operation.parser.addTestType("model.Payload") + operation.parser.addTestType("model.DeepPayload") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data1": { + "type": "integer" + }, + "data2": { + "type": "array", + "items": { + "type": "integer" + } + }, + "data3": { + "allOf": [ + { + "$ref": "#/definitions/model.Payload" + }, + { + "type": "object", + "properties": { + "data1": { + "type": "integer" + }, + "data2": { + "$ref": "#/definitions/model.DeepPayload" + } + } + } + ] + }, + "data4": { + "type": "array", + "items": { + "allOf": [ + { + "$ref": "#/definitions/model.Payload" + }, + { + "type": "object", + "properties": { + "data1": { + "type": "array", + "items": { + "type": "integer" + } + }, + "data2": { + "type": "array", + "items": { + "$ref": "#/definitions/model.DeepPayload" + } + } + } + } + ] + } + } + } + } + ] + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithNestedArrayMapFields(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} []map[string]model.CommonHeader{data1=[]map[string]model.Payload,data2=map[string][]int} "Error message, if code != 200` + operation := NewOperation(nil) + + operation.parser.addTestType("model.CommonHeader") + operation.parser.addTestType("model.Payload") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data1": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/model.Payload" + } + } + }, + "data2": { + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "integer" + } + } + } + } + } + ] + } + } + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithObjectTypeInSameFile(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} testOwner "Error message, if code != 200"` + operation := NewOperation(nil) + + operation.parser.addTestType("swag.testOwner") + + fset := token.NewFileSet() + astFile, err := goparser.ParseFile(fset, "operation_test.go", `package swag + type testOwner struct { + + } + `, goparser.ParseComments) + assert.NoError(t, err) + + err = operation.ParseComment(comment, astFile) + assert.NoError(t, err) + + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "$ref": "#/definitions/swag.testOwner" + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithObjectTypeAnonymousField(t *testing.T) { + // TODO: test Anonymous +} + +func TestParseResponseCommentWithObjectTypeErr(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {object} model.OrderRow "Error message, if code != 200"` + operation := NewOperation(nil) + + operation.parser.addTestType("model.notexist") + + err := operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestParseResponseCommentWithArrayType(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {array} model.OrderRow "Error message, if code != 200` + operation := NewOperation(nil) + operation.parser.addTestType("model.OrderRow") + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + response := operation.Responses.StatusCodeResponses[200] + assert.Equal(t, `Error message, if code != 200`, response.Description) + assert.Equal(t, spec.StringOrArray{"array"}, response.Schema.Type) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "Error message, if code != 200", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/model.OrderRow" + } + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithBasicType(t *testing.T) { + t.Parallel() + + comment := `@Success 200 {string} string "it's ok'"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "it's ok'", + "schema": { + "type": "string" + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithBasicTypeAndCodes(t *testing.T) { + t.Parallel() + + comment := `@Success 200,201,default {string} string "it's ok"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "it's ok", + "schema": { + "type": "string" + } + }, + "201": { + "description": "it's ok", + "schema": { + "type": "string" + } + }, + "default": { + "description": "it's ok", + "schema": { + "type": "string" + } + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseEmptyResponseComment(t *testing.T) { + t.Parallel() + + comment := `@Success 200 "it is ok"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "it is ok" + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseEmptyResponseCommentWithCodes(t *testing.T) { + t.Parallel() + + comment := `@Success 200,201,default "it is ok"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "it is ok" + }, + "201": { + "description": "it is ok" + }, + "default": { + "description": "it is ok" + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentWithHeader(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + err := operation.ParseComment(`@Success 200 "it's ok"`, nil) + assert.NoError(t, err, "ParseComment should not fail") + + err = operation.ParseComment(`@Header 200 {string} Token "qwerty"`, nil) + assert.NoError(t, err, "ParseComment should not fail") + + b, err := json.MarshalIndent(operation, "", " ") + assert.NoError(t, err) + + expected := `{ + "responses": { + "200": { + "description": "it's ok", + "headers": { + "Token": { + "type": "string", + "description": "qwerty" + } + } + } + } +}` + assert.Equal(t, expected, string(b)) + + err = operation.ParseComment(`@Header 200 "Mallformed"`, nil) + assert.Error(t, err, "ParseComment should not fail") + + err = operation.ParseComment(`@Header 200,asdsd {string} Token "qwerty"`, nil) + assert.Error(t, err, "ParseComment should not fail") +} + +func TestParseResponseCommentWithHeaderForCodes(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + + comment := `@Success 200,201,default "it's ok"` + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + comment = `@Header 200,201,default {string} Token "qwerty"` + err = operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + comment = `@Header all {string} Token2 "qwerty"` + err = operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + b, err := json.MarshalIndent(operation, "", " ") + assert.NoError(t, err) + + expected := `{ + "responses": { + "200": { + "description": "it's ok", + "headers": { + "Token": { + "type": "string", + "description": "qwerty" + }, + "Token2": { + "type": "string", + "description": "qwerty" + } + } + }, + "201": { + "description": "it's ok", + "headers": { + "Token": { + "type": "string", + "description": "qwerty" + }, + "Token2": { + "type": "string", + "description": "qwerty" + } + } + }, + "default": { + "description": "it's ok", + "headers": { + "Token": { + "type": "string", + "description": "qwerty" + }, + "Token2": { + "type": "string", + "description": "qwerty" + } + } + } + } +}` + assert.Equal(t, expected, string(b)) + + comment = `@Header 200 "Mallformed"` + err = operation.ParseComment(comment, nil) + assert.Error(t, err, "ParseComment should not fail") +} + +func TestParseResponseCommentWithHeaderOnlyAll(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + + comment := `@Success 200,201,default "it's ok"` + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + comment = `@Header all {string} Token "qwerty"` + err = operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + b, err := json.MarshalIndent(operation, "", " ") + assert.NoError(t, err) + + expected := `{ + "responses": { + "200": { + "description": "it's ok", + "headers": { + "Token": { + "type": "string", + "description": "qwerty" + } + } + }, + "201": { + "description": "it's ok", + "headers": { + "Token": { + "type": "string", + "description": "qwerty" + } + } + }, + "default": { + "description": "it's ok", + "headers": { + "Token": { + "type": "string", + "description": "qwerty" + } + } + } + } +}` + assert.Equal(t, expected, string(b)) + + comment = `@Header 200 "Mallformed"` + err = operation.ParseComment(comment, nil) + assert.Error(t, err, "ParseComment should not fail") +} + +func TestParseEmptyResponseOnlyCode(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + err := operation.ParseComment(`@Success 200`, nil) + assert.NoError(t, err, "ParseComment should not fail") + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "OK" + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseEmptyResponseOnlyCodes(t *testing.T) { + t.Parallel() + + comment := `@Success 200,201,default` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "ParseComment should not fail") + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `{ + "responses": { + "200": { + "description": "OK" + }, + "201": { + "description": "Created" + }, + "default": { + "description": "" + } + } +}` + assert.Equal(t, expected, string(b)) +} + +func TestParseResponseCommentParamMissing(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + + paramLenErrComment := `@Success notIntCode` + paramLenErr := operation.ParseComment(paramLenErrComment, nil) + assert.EqualError(t, paramLenErr, `can not parse response comment "notIntCode"`) + + paramLenErrComment = `@Success notIntCode {string} string "it ok"` + paramLenErr = operation.ParseComment(paramLenErrComment, nil) + assert.EqualError(t, paramLenErr, `can not parse response comment "notIntCode {string} string "it ok""`) + + paramLenErrComment = `@Success notIntCode "it ok"` + paramLenErr = operation.ParseComment(paramLenErrComment, nil) + assert.EqualError(t, paramLenErr, `can not parse response comment "notIntCode "it ok""`) +} + +func TestOperation_ParseParamComment(t *testing.T) { + t.Parallel() + + t.Run("integer", func(t *testing.T) { + t.Parallel() + for _, paramType := range []string{"header", "path", "query", "formData"} { + t.Run(paramType, func(t *testing.T) { + o := NewOperation(nil) + err := o.ParseComment(`@Param some_id `+paramType+` int true "Some ID"`, nil) + + assert.NoError(t, err) + assert.Equal(t, o.Parameters, []spec.Parameter{{ + SimpleSchema: spec.SimpleSchema{ + Type: "integer", + }, + ParamProps: spec.ParamProps{ + Name: "some_id", + Description: "Some ID", + In: paramType, + Required: true, + }, + }}) + }) + } + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + for _, paramType := range []string{"header", "path", "query", "formData"} { + t.Run(paramType, func(t *testing.T) { + o := NewOperation(nil) + err := o.ParseComment(`@Param some_string `+paramType+` string true "Some String"`, nil) + + assert.NoError(t, err) + assert.Equal(t, o.Parameters, []spec.Parameter{{ + SimpleSchema: spec.SimpleSchema{ + Type: "string", + }, + ParamProps: spec.ParamProps{ + Name: "some_string", + Description: "Some String", + In: paramType, + Required: true, + }, + }}) + }) + } + }) + + t.Run("object", func(t *testing.T) { + t.Parallel() + for _, paramType := range []string{"header", "path", "query", "formData"} { + t.Run(paramType, func(t *testing.T) { + // unknown object returns error + assert.Error(t, NewOperation(nil).ParseComment(`@Param some_object `+paramType+` main.Object true "Some Object"`, nil)) + + // verify objects are supported here + o := NewOperation(nil) + o.parser.addTestType("main.TestObject") + err := o.ParseComment(`@Param some_object `+paramType+` main.TestObject true "Some Object"`, nil) + assert.NoError(t, err) + }) + } + }) +} + +// Test ParseParamComment Query Params +func TestParseParamCommentBodyArray(t *testing.T) { + t.Parallel() + + comment := `@Param names body []string true "Users List"` + o := NewOperation(nil) + err := o.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Equal(t, o.Parameters, []spec.Parameter{{ + ParamProps: spec.ParamProps{ + Name: "names", + Description: "Users List", + In: "body", + Required: true, + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"array"}, + Items: &spec.SchemaOrArray{ + Schema: &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{"string"}, + }, + }, + }, + }, + }, + }, + }}) +} + +// Test ParseParamComment Params +func TestParseParamCommentArray(t *testing.T) { + paramTypes := []string{"header", "path", "query"} + + for _, paramType := range paramTypes { + t.Run(paramType, func(t *testing.T) { + operation := NewOperation(nil) + err := operation.ParseComment(`@Param names `+paramType+` []string true "Users List"`, nil) + + assert.NoError(t, err) + + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "array", + "items": { + "type": "string" + }, + "description": "Users List", + "name": "names", + "in": "` + paramType + `", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + err = operation.ParseComment(`@Param names `+paramType+` []model.User true "Users List"`, nil) + assert.Error(t, err) + }) + } +} + +// Test TestParseParamCommentDefaultValue Query Params +func TestParseParamCommentDefaultValue(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + err := operation.ParseComment(`@Param names query string true "Users List" default(test)`, nil) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "string", + "default": "test", + "description": "Users List", + "name": "names", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +// Test ParseParamComment Query Params +func TestParseParamCommentQueryArrayFormat(t *testing.T) { + t.Parallel() + + comment := `@Param names query []string true "Users List" collectionFormat(multi)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "array", + "items": { + "type": "string" + }, + "collectionFormat": "multi", + "description": "Users List", + "name": "names", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByID(t *testing.T) { + t.Parallel() + + comment := `@Param unsafe_id[lte] query int true "Unsafe query param"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "integer", + "description": "Unsafe query param", + "name": "unsafe_id[lte]", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentWithMultilineDescriptions(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query int true "First line\nSecond line\nThird line"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "integer", + "description": "First line\nSecond line\nThird line", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByQueryType(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query int true "Some ID"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "integer", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByBodyType(t *testing.T) { + t.Parallel() + + comment := `@Param some_id body model.OrderRow true "Some ID"` + operation := NewOperation(nil) + + operation.parser.addTestType("model.OrderRow") + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "description": "Some ID", + "name": "some_id", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/model.OrderRow" + } + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByBodyTextPlain(t *testing.T) { + t.Parallel() + + comment := `@Param text body string true "Text to process"` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "description": "Text to process", + "name": "text", + "in": "body", + "required": true, + "schema": { + "type": "string" + } + } +]` + assert.Equal(t, expected, string(b)) +} + +// TODO: fix this +func TestParseParamCommentByBodyEnumsText(t *testing.T) { + t.Parallel() + + comment := `@Param text body string true "description" Enums(ENUM1, ENUM2, ENUM3)` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "description": "description", + "name": "text", + "in": "body", + "required": true, + "schema": { + "type": "string", + "enum": [ + "ENUM1", + "ENUM2", + "ENUM3" + ] + } + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByBodyTypeWithDeepNestedFields(t *testing.T) { + t.Parallel() + + comment := `@Param body body model.CommonHeader{data=string,data2=int} true "test deep"` + operation := NewOperation(nil) + + operation.parser.addTestType("model.CommonHeader") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.Parameters, 1) + assert.Equal(t, "test deep", operation.Parameters[0].Description) + assert.True(t, operation.Parameters[0].Required) + + b, err := json.MarshalIndent(operation.Parameters, "", " ") + assert.NoError(t, err) + expected := `[ + { + "description": "test deep", + "name": "body", + "in": "body", + "required": true, + "schema": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data": { + "type": "string" + }, + "data2": { + "type": "integer" + } + } + } + ] + } + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByBodyTypeArrayOfPrimitiveGo(t *testing.T) { + t.Parallel() + + comment := `@Param some_id body []int true "Some ID"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "description": "Some ID", + "name": "some_id", + "in": "body", + "required": true, + "schema": { + "type": "array", + "items": { + "type": "integer" + } + } + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByBodyTypeArrayOfPrimitiveGoWithDeepNestedFields(t *testing.T) { + t.Parallel() + + comment := `@Param body body []model.CommonHeader{data=string,data2=int} true "test deep"` + operation := NewOperation(nil) + operation.parser.addTestType("model.CommonHeader") + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Len(t, operation.Parameters, 1) + assert.Equal(t, "test deep", operation.Parameters[0].Description) + assert.True(t, operation.Parameters[0].Required) + + b, err := json.MarshalIndent(operation.Parameters, "", " ") + assert.NoError(t, err) + expected := `[ + { + "description": "test deep", + "name": "body", + "in": "body", + "required": true, + "schema": { + "type": "array", + "items": { + "allOf": [ + { + "$ref": "#/definitions/model.CommonHeader" + }, + { + "type": "object", + "properties": { + "data": { + "type": "string" + }, + "data2": { + "type": "integer" + } + } + } + ] + } + } + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByBodyTypeErr(t *testing.T) { + t.Parallel() + + comment := `@Param some_id body model.OrderRow true "Some ID"` + operation := NewOperation(nil) + operation.parser.addTestType("model.notexist") + err := operation.ParseComment(comment, nil) + + assert.Error(t, err) +} + +func TestParseParamCommentByFormDataType(t *testing.T) { + t.Parallel() + + comment := `@Param file formData file true "this is a test file"` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "file", + "description": "this is a test file", + "name": "file", + "in": "formData", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByFormDataTypeUint64(t *testing.T) { + t.Parallel() + + comment := `@Param file formData uint64 true "this is a test file"` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "integer", + "format": "int64", + "description": "this is a test file", + "name": "file", + "in": "formData", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByNotSupportedType(t *testing.T) { + t.Parallel() + + comment := `@Param some_id not_supported int true "Some ID"` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.Error(t, err) +} + +func TestParseParamCommentNotMatch(t *testing.T) { + t.Parallel() + + comment := `@Param some_id body mock true` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.Error(t, err) +} + +func TestParseParamCommentByEnums(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query string true "Some ID" Enums(A, B, C)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "enum": [ + "A", + "B", + "C" + ], + "type": "string", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + comment = `@Param some_id query int true "Some ID" Enums(1, 2, 3)` + operation = NewOperation(nil) + err = operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ = json.MarshalIndent(operation.Parameters, "", " ") + expected = `[ + { + "enum": [ + 1, + 2, + 3 + ], + "type": "integer", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + comment = `@Param some_id query number true "Some ID" Enums(1.1, 2.2, 3.3)` + operation = NewOperation(nil) + err = operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ = json.MarshalIndent(operation.Parameters, "", " ") + expected = `[ + { + "enum": [ + 1.1, + 2.2, + 3.3 + ], + "type": "number", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + comment = `@Param some_id query bool true "Some ID" Enums(true, false)` + operation = NewOperation(nil) + err = operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ = json.MarshalIndent(operation.Parameters, "", " ") + expected = `[ + { + "enum": [ + true, + false + ], + "type": "boolean", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + operation = NewOperation(nil) + + comment = `@Param some_id query int true "Some ID" Enums(A, B, C)` + assert.Error(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query number true "Some ID" Enums(A, B, C)` + assert.Error(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query boolean true "Some ID" Enums(A, B, C)` + assert.Error(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query Document true "Some ID" Enums(A, B, C)` + assert.Error(t, operation.ParseComment(comment, nil)) +} + +func TestParseParamCommentByMaxLength(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query string true "Some ID" MaxLength(10)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "maxLength": 10, + "type": "string", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + comment = `@Param some_id query int true "Some ID" MaxLength(10)` + assert.Error(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query string true "Some ID" MaxLength(Goopher)` + assert.Error(t, operation.ParseComment(comment, nil)) +} + +func TestParseParamCommentByMinLength(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query string true "Some ID" MinLength(10)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "minLength": 10, + "type": "string", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + comment = `@Param some_id query int true "Some ID" MinLength(10)` + assert.Error(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query string true "Some ID" MinLength(Goopher)` + assert.Error(t, operation.ParseComment(comment, nil)) +} + +func TestParseParamCommentByMinimum(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query int true "Some ID" Minimum(10)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "minimum": 10, + "type": "integer", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + comment = `@Param some_id query int true "Some ID" Mininum(10)` + assert.NoError(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query string true "Some ID" Minimum(10)` + assert.Error(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query integer true "Some ID" Minimum(Goopher)` + assert.Error(t, operation.ParseComment(comment, nil)) +} + +func TestParseParamCommentByMaximum(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query int true "Some ID" Maximum(10)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "maximum": 10, + "type": "integer", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) + + comment = `@Param some_id query int true "Some ID" Maxinum(10)` + assert.NoError(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query string true "Some ID" Maximum(10)` + assert.Error(t, operation.ParseComment(comment, nil)) + + comment = `@Param some_id query integer true "Some ID" Maximum(Goopher)` + assert.Error(t, operation.ParseComment(comment, nil)) +} + +func TestParseParamCommentByDefault(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query int true "Some ID" Default(10)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "integer", + "default": 10, + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByExampleInt(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query int true "Some ID" Example(10)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "integer", + "example": 10, + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByExampleString(t *testing.T) { + t.Parallel() + + comment := `@Param some_id query string true "Some ID" Example(True feelings)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "string", + "example": "True feelings", + "description": "Some ID", + "name": "some_id", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentByExampleUnsupportedType(t *testing.T) { + t.Parallel() + var param spec.Parameter + + setExample(¶m, "something", "random value") + assert.Equal(t, param.Example, nil) + + setExample(¶m, STRING, "string value") + assert.Equal(t, param.Example, "string value") + + setExample(¶m, INTEGER, "10") + assert.Equal(t, param.Example, 10) + + setExample(¶m, NUMBER, "10") + assert.Equal(t, param.Example, float64(10)) +} + +func TestParseParamCommentBySchemaExampleString(t *testing.T) { + t.Parallel() + + comment := `@Param some_id body string true "Some ID" SchemaExample(True feelings)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "description": "Some ID", + "name": "some_id", + "in": "body", + "required": true, + "schema": { + "type": "string", + "example": "True feelings" + } + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentBySchemaExampleUnsupportedType(t *testing.T) { + t.Parallel() + var param spec.Parameter + + setSchemaExample(¶m, "something", "random value") + assert.Nil(t, param.Schema) + + setSchemaExample(¶m, STRING, "string value") + assert.Nil(t, param.Schema) + + param.Schema = &spec.Schema{} + setSchemaExample(¶m, STRING, "string value") + assert.Equal(t, "string value", param.Schema.Example) + + setSchemaExample(¶m, INTEGER, "10") + assert.Equal(t, 10, param.Schema.Example) + + setSchemaExample(¶m, NUMBER, "10") + assert.Equal(t, float64(10), param.Schema.Example) + + setSchemaExample(¶m, STRING, "string \\r\\nvalue") + assert.Equal(t, "string \r\nvalue", param.Schema.Example) +} + +func TestParseParamArrayWithEnums(t *testing.T) { + t.Parallel() + + comment := `@Param field query []string true "An enum collection" collectionFormat(csv) enums(also,valid)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "array", + "items": { + "enum": [ + "also", + "valid" + ], + "type": "string" + }, + "collectionFormat": "csv", + "description": "An enum collection", + "name": "field", + "in": "query", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseAndExtractionParamAttribute(t *testing.T) { + t.Parallel() + + op := NewOperation(nil) + numberParam := spec.Parameter{} + err := op.parseParamAttribute( + " default(1) maximum(100) minimum(0) format(csv)", + "", + NUMBER, + "", + &numberParam, + ) + assert.NoError(t, err) + assert.Equal(t, float64(0), *numberParam.Minimum) + assert.Equal(t, float64(100), *numberParam.Maximum) + assert.Equal(t, "csv", numberParam.SimpleSchema.Format) + assert.Equal(t, float64(1), numberParam.Default) + + err = op.parseParamAttribute(" minlength(1)", "", NUMBER, "", nil) + assert.Error(t, err) + + err = op.parseParamAttribute(" maxlength(1)", "", NUMBER, "", nil) + assert.Error(t, err) + + stringParam := spec.Parameter{} + err = op.parseParamAttribute( + " default(test) maxlength(100) minlength(0) format(csv)", + "", + STRING, + "", + &stringParam, + ) + assert.NoError(t, err) + assert.Equal(t, int64(0), *stringParam.MinLength) + assert.Equal(t, int64(100), *stringParam.MaxLength) + assert.Equal(t, "csv", stringParam.SimpleSchema.Format) + err = op.parseParamAttribute(" minimum(0)", "", STRING, "", nil) + assert.Error(t, err) + + err = op.parseParamAttribute(" maximum(0)", "", STRING, "", nil) + assert.Error(t, err) + + arrayParram := spec.Parameter{} + err = op.parseParamAttribute(" collectionFormat(tsv)", ARRAY, STRING, "", &arrayParram) + assert.Equal(t, "tsv", arrayParram.CollectionFormat) + assert.NoError(t, err) + + err = op.parseParamAttribute(" collectionFormat(tsv)", STRING, STRING, "", nil) + assert.Error(t, err) + + err = op.parseParamAttribute(" default(0)", "", ARRAY, "", nil) + assert.NoError(t, err) +} + +func TestParseParamCommentByExtensions(t *testing.T) { + comment := `@Param some_id path int true "Some ID" extensions(x-example=test,x-custom=Goopher,x-custom2)` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + b, _ := json.MarshalIndent(operation.Parameters, "", " ") + expected := `[ + { + "type": "integer", + "x-custom": "Goopher", + "x-custom2": true, + "x-example": "test", + "description": "Some ID", + "name": "some_id", + "in": "path", + "required": true + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParseParamStructCodeExample(t *testing.T) { + t.Parallel() + + fset := token.NewFileSet() + ast, err := goparser.ParseFile(fset, "operation_test.go", `package swag + import structs "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/param_structs" + `, goparser.ParseComments) + assert.NoError(t, err) + + parser := New() + err = parser.parseFile("git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/param_structs", "testdata/param_structs/structs.go", nil, ParseModels) + assert.NoError(t, err) + _, err = parser.packages.ParseTypes() + assert.NoError(t, err) + + validateParameters := func(operation *Operation, params ...spec.Parameter) { + assert.Equal(t, len(params), len(operation.Parameters)) + + for _, param := range params { + found := false + for _, p := range operation.Parameters { + if p.Name == param.Name { + assert.Equal(t, param.ParamProps, p.ParamProps) + assert.Equal(t, param.CommonValidations, p.CommonValidations) + assert.Equal(t, param.SimpleSchema, p.SimpleSchema) + found = true + break + } + } + assert.True(t, found, "found parameter %s", param.Name) + } + } + + // values used in validation checks + max := float64(10) + maxLen := int64(10) + min := float64(0) + + // query and form behave the same + for _, param := range []string{"query", "formData"} { + t.Run(param+" struct", func(t *testing.T) { + operation := NewOperation(parser) + comment := fmt.Sprintf(`@Param model %s structs.FormModel true "query params"`, param) + err = operation.ParseComment(comment, ast) + assert.NoError(t, err) + + validateParameters(operation, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "f", + Description: "", + In: param, + Required: true, + }, + CommonValidations: spec.CommonValidations{ + MaxLength: &maxLen, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "string", + }, + }, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "b", + Description: "B is another field", + In: param, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "boolean", + }, + }) + }) + } + + t.Run("header struct", func(t *testing.T) { + operation := NewOperation(parser) + comment := `@Param auth header structs.AuthHeader true "auth header"` + err = operation.ParseComment(comment, ast) + assert.NoError(t, err) + + validateParameters(operation, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "X-Auth-Token", + Description: "Token is the auth token", + In: "header", + Required: true, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "string", + }, + }, spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "anotherHeader", + Description: "AnotherHeader is another header", + In: "header", + }, + CommonValidations: spec.CommonValidations{ + Maximum: &max, + Minimum: &min, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "integer", + }, + }) + }) + + t.Run("path struct", func(t *testing.T) { + operation := NewOperation(parser) + comment := `@Param path path structs.PathModel true "path params"` + err = operation.ParseComment(comment, ast) + assert.NoError(t, err) + + validateParameters(operation, + spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "id", + Description: "ID is the id", + In: "path", + Required: true, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "integer", + }, + }, spec.Parameter{ + ParamProps: spec.ParamProps{ + Name: "name", + Description: "", + In: "path", + }, + CommonValidations: spec.CommonValidations{ + MaxLength: &maxLen, + }, + SimpleSchema: spec.SimpleSchema{ + Type: "string", + }, + }) + }) +} + +func TestParseIdComment(t *testing.T) { + t.Parallel() + + comment := `@Id myOperationId` + operation := NewOperation(nil) + err := operation.ParseComment(comment, nil) + + assert.NoError(t, err) + assert.Equal(t, "myOperationId", operation.ID) +} + +func TestFindTypeDefCoreLib(t *testing.T) { + t.Parallel() + + s, err := findTypeDef("net/http", "Request") + assert.NoError(t, err) + assert.NotNil(t, s) +} + +func TestFindTypeDefExternalPkg(t *testing.T) { + t.Parallel() + + s, err := findTypeDef("github.com/KyleBanks/depth", "Tree") + assert.NoError(t, err) + assert.NotNil(t, s) +} + +func TestFindTypeDefInvalidPkg(t *testing.T) { + t.Parallel() + + s, err := findTypeDef("does-not-exist", "foo") + assert.Error(t, err) + assert.Nil(t, s) +} + +func TestParseSecurityComment(t *testing.T) { + t.Parallel() + + comment := `@Security OAuth2Implicit[read, write]` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + assert.Equal(t, operation.Security, []map[string][]string{ + { + "OAuth2Implicit": {"read", "write"}, + }, + }) +} + +func TestParseSecurityCommentSimple(t *testing.T) { + t.Parallel() + + comment := `@Security ApiKeyAuth` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + assert.Equal(t, operation.Security, []map[string][]string{ + { + "ApiKeyAuth": {}, + }, + }) +} + +func TestParseSecurityCommentAnd(t *testing.T) { + t.Parallel() + + comment := `@Security OAuth2Implicit[read, write] && Firebase[]` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + expect := []map[string][]string{ + { + "OAuth2Implicit": {"read", "write"}, + "Firebase": {""}, + }, + } + assert.Equal(t, operation.Security, expect) + + oldVersionComment := `@Security OAuth2Implicit[read, write] || Firebase[]` + operation = NewOperation(nil) + err = operation.ParseComment(oldVersionComment, nil) + assert.NoError(t, err) + assert.Equal(t, operation.Security, expect) +} + +func TestParseMultiDescription(t *testing.T) { + t.Parallel() + + comment := `@Description line one` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + comment = `@Tags multi` + err = operation.ParseComment(comment, nil) + assert.NoError(t, err) + + comment = `@Description line two x` + err = operation.ParseComment(comment, nil) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(operation, "", " ") + + expected := `"description": "line one\nline two x"` + assert.Contains(t, string(b), expected) +} + +func TestParseDescriptionMarkdown(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + operation.parser.markdownFileDir = "example/markdown" + + comment := `@description.markdown admin.md` + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + comment = `@description.markdown missing.md` + + err = operation.ParseComment(comment, nil) + assert.Error(t, err) +} + +func TestParseSummary(t *testing.T) { + t.Parallel() + + comment := `@summary line one` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + comment = `@Summary line one` + err = operation.ParseComment(comment, nil) + assert.NoError(t, err) +} + +func TestParseDeprecationDescription(t *testing.T) { + t.Parallel() + + comment := `@Deprecated` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + + if !operation.Deprecated { + t.Error("Failed to parse @deprecated comment") + } +} + +func TestParseExtentions(t *testing.T) { + t.Parallel() + // Fail if there are no args for attributes. + { + comment := `@x-amazon-apigateway-integration` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.EqualError(t, err, "annotation @x-amazon-apigateway-integration need a value") + } + + // Fail if args of attributes are broken. + { + comment := `@x-amazon-apigateway-integration ["broken"}]` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.EqualError(t, err, "annotation @x-amazon-apigateway-integration need a valid json value") + } + + // OK + { + comment := `@x-amazon-apigateway-integration {"uri": "${some_arn}", "passthroughBehavior": "when_no_match", "httpMethod": "POST", "type": "aws_proxy"}` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Equal(t, operation.Extensions["x-amazon-apigateway-integration"], + map[string]interface{}{ + "httpMethod": "POST", + "passthroughBehavior": "when_no_match", + "type": "aws_proxy", + "uri": "${some_arn}", + }) + } + + // Test x-tagGroups + { + comment := `@x-tagGroups [{"name":"Natural Persons","tags":["Person","PersonRisk","PersonDocuments"]}]` + operation := NewOperation(nil) + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err) + assert.Equal(t, operation.Extensions["x-tagGroups"], + []interface{}{map[string]interface{}{ + "name": "Natural Persons", + "tags": []interface{}{ + "Person", + "PersonRisk", + "PersonDocuments", + }, + }}) + } +} + +func TestFindInSlice(t *testing.T) { + t.Parallel() + + assert.True(t, findInSlice([]string{"one", "two", "tree"}, "one")) + assert.True(t, findInSlice([]string{"tree", "two", "one"}, "one")) + assert.True(t, findInSlice([]string{"two", "one", "tree"}, "one")) + assert.False(t, findInSlice([]string{"one", "two", "tree"}, "four")) +} + +func TestParseResponseHeaderComment(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + operation.Responses = &spec.Responses{} + err := operation.ParseResponseComment(`default {string} string "other error"`, nil) + assert.NoError(t, err) + err = operation.ParseResponseHeaderComment(`all {string} Token "qwerty"`, nil) + assert.NoError(t, err) +} + +func TestParseObjectSchema(t *testing.T) { + t.Parallel() + + operation := NewOperation(nil) + + schema, err := operation.parseObjectSchema("interface{}", nil) + assert.NoError(t, err) + assert.Equal(t, schema, &spec.Schema{}) + + schema, err = operation.parseObjectSchema("any", nil) + assert.NoError(t, err) + assert.Equal(t, schema, &spec.Schema{}) + + schema, err = operation.parseObjectSchema("any{data=string}", nil) + assert.NoError(t, err) + assert.Equal(t, schema, + (&spec.Schema{}).WithAllOf(spec.Schema{}, *PrimitiveSchema(OBJECT).SetProperty("data", *PrimitiveSchema("string")))) + + schema, err = operation.parseObjectSchema("int", nil) + assert.NoError(t, err) + assert.Equal(t, schema, PrimitiveSchema(INTEGER)) + + schema, err = operation.parseObjectSchema("[]string", nil) + assert.NoError(t, err) + assert.Equal(t, schema, spec.ArrayProperty(PrimitiveSchema(STRING))) + + schema, err = operation.parseObjectSchema("[]int", nil) + assert.NoError(t, err) + assert.Equal(t, schema, spec.ArrayProperty(PrimitiveSchema(INTEGER))) + + _, err = operation.parseObjectSchema("[]bleah", nil) + assert.Error(t, err) + + schema, err = operation.parseObjectSchema("map[]string", nil) + assert.NoError(t, err) + assert.Equal(t, schema, spec.MapProperty(PrimitiveSchema(STRING))) + + schema, err = operation.parseObjectSchema("map[]int", nil) + assert.NoError(t, err) + assert.Equal(t, schema, spec.MapProperty(PrimitiveSchema(INTEGER))) + + schema, err = operation.parseObjectSchema("map[]interface{}", nil) + assert.NoError(t, err) + assert.Equal(t, schema, spec.MapProperty(nil)) + + _, err = operation.parseObjectSchema("map[string", nil) + assert.Error(t, err) + + _, err = operation.parseObjectSchema("map[]bleah", nil) + assert.Error(t, err) + + operation.parser = New() + operation.parser.packages = &PackagesDefinitions{ + uniqueDefinitions: map[string]*TypeSpecDef{ + "model.User": { + File: &ast.File{ + Name: &ast.Ident{ + Name: "user.go", + }, + }, + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{ + Name: "User", + }, + }, + }, + }, + } + _, err = operation.parseObjectSchema("model.User", nil) + assert.NoError(t, err) + + operation.parser = nil + schema, err = operation.parseObjectSchema("user.Model", nil) + assert.NoError(t, err) + assert.Equal(t, schema, RefSchema("user.Model")) +} + +func TestParseCodeSamples(t *testing.T) { + t.Parallel() + const comment = `@x-codeSamples file` + t.Run("Find sample by file", func(t *testing.T) { + operation := NewOperation(nil, SetCodeExampleFilesDirectory("testdata/code_examples")) + operation.Summary = "example" + + err := operation.ParseComment(comment, nil) + assert.NoError(t, err, "no error should be thrown") + assert.Equal(t, operation.Summary, "example") + assert.Equal(t, operation.Extensions["x-codeSamples"], + map[string]interface{}{"lang": "JavaScript", "source": "console.log('Hello World');"}) + }) + + t.Run("With broken file sample", func(t *testing.T) { + operation := NewOperation(nil, SetCodeExampleFilesDirectory("testdata/code_examples")) + operation.Summary = "broken" + + err := operation.ParseComment(comment, nil) + assert.Error(t, err, "no error should be thrown") + }) + + t.Run("Example file not found", func(t *testing.T) { + operation := NewOperation(nil, SetCodeExampleFilesDirectory("testdata/code_examples")) + operation.Summary = "badExample" + + err := operation.ParseComment(comment, nil) + assert.Error(t, err, "error was expected, as file does not exist") + }) + + t.Run("Without line reminder", func(t *testing.T) { + comment := `@x-codeSamples` + operation := NewOperation(nil, SetCodeExampleFilesDirectory("testdata/code_examples")) + operation.Summary = "example" + + err := operation.ParseComment(comment, nil) + assert.Error(t, err, "no error should be thrown") + }) + + t.Run(" broken dir", func(t *testing.T) { + operation := NewOperation(nil, SetCodeExampleFilesDirectory("testdata/fake_examples")) + operation.Summary = "code" + + err := operation.ParseComment(comment, nil) + assert.Error(t, err, "no error should be thrown") + }) +} + +func TestParseDeprecatedRouter(t *testing.T) { + p := New() + searchDir := "./testdata/deprecated_router" + if err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth); err != nil { + t.Error("Failed to parse api: " + err.Error()) + } + + b, _ := json.MarshalIndent(p.swagger, "", " ") + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + assert.Equal(t, expected, b) +} diff --git a/pkg/swag/package.go b/pkg/swag/package.go new file mode 100644 index 0000000..6c3129e --- /dev/null +++ b/pkg/swag/package.go @@ -0,0 +1,187 @@ +package swag + +import ( + "go/ast" + "go/token" + "reflect" + "strconv" + "strings" +) + +// PackageDefinitions files and definition in a package. +type PackageDefinitions struct { + // files in this package, map key is file's relative path starting package path + Files map[string]*ast.File + + // definitions in this package, map key is typeName + TypeDefinitions map[string]*TypeSpecDef + + // const variables in this package, map key is the name + ConstTable map[string]*ConstVariable + + // const variables in order in this package + OrderedConst []*ConstVariable + + // package name + Name string + + // package path + Path string +} + +// ConstVariableGlobalEvaluator an interface used to evaluate enums across packages +type ConstVariableGlobalEvaluator interface { + EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) + EvaluateConstValueByName(file *ast.File, pkgPath, constVariableName string, recursiveStack map[string]struct{}) (interface{}, ast.Expr) + FindTypeSpec(typeName string, file *ast.File) *TypeSpecDef +} + +// NewPackageDefinitions new a PackageDefinitions object +func NewPackageDefinitions(name, pkgPath string) *PackageDefinitions { + return &PackageDefinitions{ + Name: name, + Path: pkgPath, + Files: make(map[string]*ast.File), + TypeDefinitions: make(map[string]*TypeSpecDef), + ConstTable: make(map[string]*ConstVariable), + } +} + +// AddFile add a file +func (pkg *PackageDefinitions) AddFile(pkgPath string, file *ast.File) *PackageDefinitions { + pkg.Files[pkgPath] = file + return pkg +} + +// AddTypeSpec add a type spec. +func (pkg *PackageDefinitions) AddTypeSpec(name string, typeSpec *TypeSpecDef) *PackageDefinitions { + pkg.TypeDefinitions[name] = typeSpec + return pkg +} + +// AddConst add a const variable. +func (pkg *PackageDefinitions) AddConst(astFile *ast.File, valueSpec *ast.ValueSpec) *PackageDefinitions { + for i := 0; i < len(valueSpec.Names) && i < len(valueSpec.Values); i++ { + variable := &ConstVariable{ + Name: valueSpec.Names[i], + Type: valueSpec.Type, + Value: valueSpec.Values[i], + Comment: valueSpec.Comment, + File: astFile, + } + pkg.ConstTable[valueSpec.Names[i].Name] = variable + pkg.OrderedConst = append(pkg.OrderedConst, variable) + } + return pkg +} + +func (pkg *PackageDefinitions) evaluateConstValue(file *ast.File, iota int, expr ast.Expr, globalEvaluator ConstVariableGlobalEvaluator, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { + switch valueExpr := expr.(type) { + case *ast.Ident: + if valueExpr.Name == "iota" { + return iota, nil + } + if pkg.ConstTable != nil { + if cv, ok := pkg.ConstTable[valueExpr.Name]; ok { + return globalEvaluator.EvaluateConstValue(pkg, cv, recursiveStack) + } + } + case *ast.SelectorExpr: + pkgIdent, ok := valueExpr.X.(*ast.Ident) + if !ok { + return nil, nil + } + return globalEvaluator.EvaluateConstValueByName(file, pkgIdent.Name, valueExpr.Sel.Name, recursiveStack) + case *ast.BasicLit: + switch valueExpr.Kind { + case token.INT: + // handle underscored number, such as 1_000_000 + if strings.ContainsRune(valueExpr.Value, '_') { + valueExpr.Value = strings.Replace(valueExpr.Value, "_", "", -1) + } + if len(valueExpr.Value) >= 2 && valueExpr.Value[0] == '0' { + var start, base = 2, 8 + switch valueExpr.Value[1] { + case 'x', 'X': + //hex + base = 16 + case 'b', 'B': + //binary + base = 2 + default: + //octet + start = 1 + } + if x, err := strconv.ParseInt(valueExpr.Value[start:], base, 64); err == nil { + return int(x), nil + } else if x, err := strconv.ParseUint(valueExpr.Value[start:], base, 64); err == nil { + return x, nil + } else { + panic(err) + } + } + + //a basic literal integer is int type in default, or must have an explicit converting type in front + if x, err := strconv.ParseInt(valueExpr.Value, 10, 64); err == nil { + return int(x), nil + } else if x, err := strconv.ParseUint(valueExpr.Value, 10, 64); err == nil { + return x, nil + } else { + panic(err) + } + case token.STRING: + if valueExpr.Value[0] == '`' { + return valueExpr.Value[1 : len(valueExpr.Value)-1], nil + } + return EvaluateEscapedString(valueExpr.Value[1 : len(valueExpr.Value)-1]), nil + case token.CHAR: + return EvaluateEscapedChar(valueExpr.Value[1 : len(valueExpr.Value)-1]), nil + } + case *ast.UnaryExpr: + x, evalType := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack) + if x == nil { + return x, evalType + } + return EvaluateUnary(x, valueExpr.Op, evalType) + case *ast.BinaryExpr: + x, evalTypex := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack) + y, evalTypey := pkg.evaluateConstValue(file, iota, valueExpr.Y, globalEvaluator, recursiveStack) + if x == nil || y == nil { + return nil, nil + } + return EvaluateBinary(x, y, valueExpr.Op, evalTypex, evalTypey) + case *ast.ParenExpr: + return pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack) + case *ast.CallExpr: + //data conversion + if len(valueExpr.Args) != 1 { + return nil, nil + } + arg := valueExpr.Args[0] + if ident, ok := valueExpr.Fun.(*ast.Ident); ok { + name := ident.Name + if name == "uintptr" { + name = "uint" + } + value, _ := pkg.evaluateConstValue(file, iota, arg, globalEvaluator, recursiveStack) + if IsGolangPrimitiveType(name) { + value = EvaluateDataConversion(value, name) + return value, nil + } else if name == "len" { + return reflect.ValueOf(value).Len(), nil + } + typeDef := globalEvaluator.FindTypeSpec(name, file) + if typeDef == nil { + return nil, nil + } + return value, valueExpr.Fun + } else if selector, ok := valueExpr.Fun.(*ast.SelectorExpr); ok { + typeDef := globalEvaluator.FindTypeSpec(fullTypeName(selector.X.(*ast.Ident).Name, selector.Sel.Name), file) + if typeDef == nil { + return nil, nil + } + return arg, typeDef.TypeSpec.Type + } + } + return nil, nil +} diff --git a/pkg/swag/packages.go b/pkg/swag/packages.go new file mode 100644 index 0000000..f63b112 --- /dev/null +++ b/pkg/swag/packages.go @@ -0,0 +1,613 @@ +package swag + +import ( + "fmt" + "go/ast" + goparser "go/parser" + "go/token" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + + "golang.org/x/tools/go/loader" +) + +// PackagesDefinitions map[package import path]*PackageDefinitions. +type PackagesDefinitions struct { + files map[*ast.File]*AstFileInfo + packages map[string]*PackageDefinitions + uniqueDefinitions map[string]*TypeSpecDef + parseDependency ParseFlag + debug Debugger +} + +// NewPackagesDefinitions create object PackagesDefinitions. +func NewPackagesDefinitions() *PackagesDefinitions { + return &PackagesDefinitions{ + files: make(map[*ast.File]*AstFileInfo), + packages: make(map[string]*PackageDefinitions), + uniqueDefinitions: make(map[string]*TypeSpecDef), + } +} + +// ParseFile parse a source file. +func (pkgDefs *PackagesDefinitions) ParseFile(packageDir, path string, src interface{}, flag ParseFlag) error { + // positions are relative to FileSet + fileSet := token.NewFileSet() + astFile, err := goparser.ParseFile(fileSet, path, src, goparser.ParseComments) + if err != nil { + return fmt.Errorf("failed to parse file %s, error:%+v", path, err) + } + return pkgDefs.collectAstFile(fileSet, packageDir, path, astFile, flag) +} + +// collectAstFile collect ast.file. +func (pkgDefs *PackagesDefinitions) collectAstFile(fileSet *token.FileSet, packageDir, path string, astFile *ast.File, flag ParseFlag) error { + if pkgDefs.files == nil { + pkgDefs.files = make(map[*ast.File]*AstFileInfo) + } + + if pkgDefs.packages == nil { + pkgDefs.packages = make(map[string]*PackageDefinitions) + } + + // return without storing the file if we lack a packageDir + if packageDir == "" { + return nil + } + + path, err := filepath.Abs(path) + if err != nil { + return err + } + + dependency, ok := pkgDefs.packages[packageDir] + if ok { + // return without storing the file if it already exists + _, exists := dependency.Files[path] + if exists { + return nil + } + + dependency.Files[path] = astFile + } else { + pkgDefs.packages[packageDir] = NewPackageDefinitions(astFile.Name.Name, packageDir).AddFile(path, astFile) + } + + pkgDefs.files[astFile] = &AstFileInfo{ + FileSet: fileSet, + File: astFile, + Path: path, + PackagePath: packageDir, + ParseFlag: flag, + } + + return nil +} + +// RangeFiles for range the collection of ast.File in alphabetic order. +func (pkgDefs *PackagesDefinitions) RangeFiles(handle func(info *AstFileInfo) error) error { + sortedFiles := make([]*AstFileInfo, 0, len(pkgDefs.files)) + for _, info := range pkgDefs.files { + // ignore package path prefix with 'vendor' or $GOROOT, + // because the router info of api will not be included these files. + if strings.HasPrefix(info.PackagePath, "vendor") || (runtime.GOROOT() != "" && strings.HasPrefix(info.Path, runtime.GOROOT()+string(filepath.Separator))) { + continue + } + sortedFiles = append(sortedFiles, info) + } + + sort.Slice(sortedFiles, func(i, j int) bool { + return strings.Compare(sortedFiles[i].Path, sortedFiles[j].Path) < 0 + }) + + for _, info := range sortedFiles { + err := handle(info) + if err != nil { + return err + } + } + + return nil +} + +// ParseTypes parse types +// @Return parsed definitions. +func (pkgDefs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, error) { + parsedSchemas := make(map[*TypeSpecDef]*Schema) + for astFile, info := range pkgDefs.files { + pkgDefs.parseTypesFromFile(astFile, info.PackagePath, parsedSchemas) + pkgDefs.parseFunctionScopedTypesFromFile(astFile, info.PackagePath, parsedSchemas) + } + pkgDefs.removeAllNotUniqueTypes() + pkgDefs.evaluateAllConstVariables() + pkgDefs.collectConstEnums(parsedSchemas) + return parsedSchemas, nil +} + +func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) { + for _, astDeclaration := range astFile.Decls { + generalDeclaration, ok := astDeclaration.(*ast.GenDecl) + if !ok { + continue + } + if generalDeclaration.Tok == token.TYPE { + for _, astSpec := range generalDeclaration.Specs { + if typeSpec, ok := astSpec.(*ast.TypeSpec); ok { + typeSpecDef := &TypeSpecDef{ + PkgPath: packagePath, + File: astFile, + TypeSpec: typeSpec, + } + + if idt, ok := typeSpec.Type.(*ast.Ident); ok && IsGolangPrimitiveType(idt.Name) && parsedSchemas != nil { + parsedSchemas[typeSpecDef] = &Schema{ + PkgPath: typeSpecDef.PkgPath, + Name: astFile.Name.Name, + Schema: TransToValidPrimitiveSchema(idt.Name), + } + } + + if pkgDefs.uniqueDefinitions == nil { + pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef) + } + + fullName := typeSpecDef.TypeName() + + anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName] + if ok { + if anotherTypeDef == nil { + typeSpecDef.NotUnique = true + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef + } else if typeSpecDef.PkgPath != anotherTypeDef.PkgPath { + pkgDefs.uniqueDefinitions[fullName] = nil + anotherTypeDef.NotUnique = true + pkgDefs.uniqueDefinitions[anotherTypeDef.TypeName()] = anotherTypeDef + anotherTypeDef.SetSchemaName() + + typeSpecDef.NotUnique = true + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef + } + } else { + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef + } + + typeSpecDef.SetSchemaName() + + if pkgDefs.packages[typeSpecDef.PkgPath] == nil { + pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name, typeSpecDef.PkgPath).AddTypeSpec(typeSpecDef.Name(), typeSpecDef) + } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok { + pkgDefs.packages[typeSpecDef.PkgPath].AddTypeSpec(typeSpecDef.Name(), typeSpecDef) + } + } + } + } else if generalDeclaration.Tok == token.CONST { + // collect consts + pkgDefs.collectConstVariables(astFile, packagePath, generalDeclaration) + } + } +} + +func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) { + for _, astDeclaration := range astFile.Decls { + funcDeclaration, ok := astDeclaration.(*ast.FuncDecl) + if ok && funcDeclaration.Body != nil { + functionScopedTypes := make(map[string]*TypeSpecDef) + for _, stmt := range funcDeclaration.Body.List { + if declStmt, ok := (stmt).(*ast.DeclStmt); ok { + if genDecl, ok := (declStmt.Decl).(*ast.GenDecl); ok && genDecl.Tok == token.TYPE { + for _, astSpec := range genDecl.Specs { + if typeSpec, ok := astSpec.(*ast.TypeSpec); ok { + typeSpecDef := &TypeSpecDef{ + PkgPath: packagePath, + File: astFile, + TypeSpec: typeSpec, + ParentSpec: astDeclaration, + } + + if idt, ok := typeSpec.Type.(*ast.Ident); ok && IsGolangPrimitiveType(idt.Name) && parsedSchemas != nil { + parsedSchemas[typeSpecDef] = &Schema{ + PkgPath: typeSpecDef.PkgPath, + Name: astFile.Name.Name, + Schema: TransToValidPrimitiveSchema(idt.Name), + } + } + + fullName := typeSpecDef.TypeName() + if structType, ok := typeSpecDef.TypeSpec.Type.(*ast.StructType); ok { + for _, field := range structType.Fields.List { + var idt *ast.Ident + var ok bool + switch field.Type.(type) { + case *ast.Ident: + idt, ok = field.Type.(*ast.Ident) + case *ast.StarExpr: + idt, ok = field.Type.(*ast.StarExpr).X.(*ast.Ident) + case *ast.ArrayType: + idt, ok = field.Type.(*ast.ArrayType).Elt.(*ast.Ident) + } + if ok && !IsGolangPrimitiveType(idt.Name) { + if functype, ok := functionScopedTypes[idt.Name]; ok { + idt.Name = functype.TypeName() + } + } + } + } + + if pkgDefs.uniqueDefinitions == nil { + pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef) + } + + anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName] + if ok { + if anotherTypeDef == nil { + typeSpecDef.NotUnique = true + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef + } else if typeSpecDef.PkgPath != anotherTypeDef.PkgPath { + pkgDefs.uniqueDefinitions[fullName] = nil + anotherTypeDef.NotUnique = true + pkgDefs.uniqueDefinitions[anotherTypeDef.TypeName()] = anotherTypeDef + anotherTypeDef.SetSchemaName() + + typeSpecDef.NotUnique = true + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef + } + } else { + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef + functionScopedTypes[typeSpec.Name.Name] = typeSpecDef + } + + typeSpecDef.SetSchemaName() + + if pkgDefs.packages[typeSpecDef.PkgPath] == nil { + pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name, typeSpecDef.PkgPath).AddTypeSpec(fullName, typeSpecDef) + } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[fullName]; !ok { + pkgDefs.packages[typeSpecDef.PkgPath].AddTypeSpec(fullName, typeSpecDef) + } + } + } + + } + } + } + } + } +} + +func (pkgDefs *PackagesDefinitions) collectConstVariables(astFile *ast.File, packagePath string, generalDeclaration *ast.GenDecl) { + pkg, ok := pkgDefs.packages[packagePath] + if !ok { + pkg = NewPackageDefinitions(astFile.Name.Name, packagePath) + pkgDefs.packages[packagePath] = pkg + } + + var lastValueSpec *ast.ValueSpec + for _, astSpec := range generalDeclaration.Specs { + valueSpec, ok := astSpec.(*ast.ValueSpec) + if !ok { + continue + } + if len(valueSpec.Names) == 1 && len(valueSpec.Values) == 1 { + lastValueSpec = valueSpec + } else if len(valueSpec.Names) == 1 && len(valueSpec.Values) == 0 && valueSpec.Type == nil && lastValueSpec != nil { + valueSpec.Type = lastValueSpec.Type + valueSpec.Values = lastValueSpec.Values + } + pkg.AddConst(astFile, valueSpec) + } +} + +func (pkgDefs *PackagesDefinitions) evaluateAllConstVariables() { + for _, pkg := range pkgDefs.packages { + for _, constVar := range pkg.OrderedConst { + pkgDefs.EvaluateConstValue(pkg, constVar, nil) + } + } +} + +// EvaluateConstValue evaluate a const variable. +func (pkgDefs *PackagesDefinitions) EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { + if expr, ok := cv.Value.(ast.Expr); ok { + defer func() { + if err := recover(); err != nil { + if fi, ok := pkgDefs.files[cv.File]; ok { + pos := fi.FileSet.Position(cv.Name.NamePos) + pkgDefs.debug.Printf("warning: failed to evaluate const %s at %s:%d:%d, %v", cv.Name.Name, fi.Path, pos.Line, pos.Column, err) + } + } + }() + if recursiveStack == nil { + recursiveStack = make(map[string]struct{}) + } + fullConstName := fullTypeName(pkg.Path, cv.Name.Name) + if _, ok = recursiveStack[fullConstName]; ok { + return nil, nil + } + recursiveStack[fullConstName] = struct{}{} + + value, evalType := pkg.evaluateConstValue(cv.File, cv.Name.Obj.Data.(int), expr, pkgDefs, recursiveStack) + if cv.Type == nil && evalType != nil { + cv.Type = evalType + } + if value != nil { + cv.Value = value + } + return value, cv.Type + } + return cv.Value, cv.Type +} + +// EvaluateConstValueByName evaluate a const variable by name. +func (pkgDefs *PackagesDefinitions) EvaluateConstValueByName(file *ast.File, pkgName, constVariableName string, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { + matchedPkgPaths, externalPkgPaths := pkgDefs.findPackagePathFromImports(pkgName, file) + for _, pkgPath := range matchedPkgPaths { + if pkg, ok := pkgDefs.packages[pkgPath]; ok { + if cv, ok := pkg.ConstTable[constVariableName]; ok { + return pkgDefs.EvaluateConstValue(pkg, cv, recursiveStack) + } + } + } + if pkgDefs.parseDependency > 0 { + for _, pkgPath := range externalPkgPaths { + if err := pkgDefs.loadExternalPackage(pkgPath); err == nil { + if pkg, ok := pkgDefs.packages[pkgPath]; ok { + if cv, ok := pkg.ConstTable[constVariableName]; ok { + return pkgDefs.EvaluateConstValue(pkg, cv, recursiveStack) + } + } + } + } + } + return nil, nil +} + +func (pkgDefs *PackagesDefinitions) collectConstEnums(parsedSchemas map[*TypeSpecDef]*Schema) { + for _, pkg := range pkgDefs.packages { + for _, constVar := range pkg.OrderedConst { + if constVar.Type == nil { + continue + } + ident, ok := constVar.Type.(*ast.Ident) + if !ok || IsGolangPrimitiveType(ident.Name) { + continue + } + typeDef, ok := pkg.TypeDefinitions[ident.Name] + if !ok { + continue + } + + // delete it from parsed schemas, and will parse it again + if _, ok = parsedSchemas[typeDef]; ok { + delete(parsedSchemas, typeDef) + } + + if typeDef.Enums == nil { + typeDef.Enums = make([]EnumValue, 0) + } + + name := constVar.Name.Name + if _, ok = constVar.Value.(ast.Expr); ok { + continue + } + + enumValue := EnumValue{ + key: name, + Value: constVar.Value, + } + if constVar.Comment != nil && len(constVar.Comment.List) > 0 { + enumValue.Comment = constVar.Comment.List[0].Text + enumValue.Comment = strings.TrimPrefix(enumValue.Comment, "//") + enumValue.Comment = strings.TrimPrefix(enumValue.Comment, "/*") + enumValue.Comment = strings.TrimSuffix(enumValue.Comment, "*/") + enumValue.Comment = strings.TrimSpace(enumValue.Comment) + } + typeDef.Enums = append(typeDef.Enums, enumValue) + } + } +} + +func (pkgDefs *PackagesDefinitions) removeAllNotUniqueTypes() { + for key, ud := range pkgDefs.uniqueDefinitions { + if ud == nil { + delete(pkgDefs.uniqueDefinitions, key) + } + } +} + +func (pkgDefs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) *TypeSpecDef { + if pkgDefs.packages == nil { + return nil + } + + pd, found := pkgDefs.packages[pkgPath] + if found { + typeSpec, ok := pd.TypeDefinitions[typeName] + if ok { + return typeSpec + } + } + + return nil +} + +func (pkgDefs *PackagesDefinitions) loadExternalPackage(importPath string) error { + cwd, err := os.Getwd() + if err != nil { + return err + } + + conf := loader.Config{ + ParserMode: goparser.ParseComments, + Cwd: cwd, + } + + conf.Import(importPath) + + loaderProgram, err := conf.Load() + if err != nil { + return err + } + + for _, info := range loaderProgram.AllPackages { + pkgPath := strings.TrimPrefix(info.Pkg.Path(), "vendor/") + for _, astFile := range info.Files { + pkgDefs.parseTypesFromFile(astFile, pkgPath, nil) + } + } + + return nil +} + +// findPackagePathFromImports finds out the package path of a package via ranging imports of an ast.File +// @pkg the name of the target package +// @file current ast.File in which to search imports +// @return the package paths of a package of @pkg. +func (pkgDefs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File) (matchedPkgPaths, externalPkgPaths []string) { + if file == nil { + return + } + + if strings.ContainsRune(pkg, '.') { + pkg = strings.Split(pkg, ".")[0] + } + + matchLastPathPart := func(pkgPath string) bool { + paths := strings.Split(pkgPath, "/") + return paths[len(paths)-1] == pkg + } + + // prior to match named package + for _, imp := range file.Imports { + path := strings.Trim(imp.Path.Value, `"`) + if imp.Name != nil { + if imp.Name.Name == pkg { + // if name match, break loop and return + _, ok := pkgDefs.packages[path] + if ok { + matchedPkgPaths = []string{path} + externalPkgPaths = nil + } else { + externalPkgPaths = []string{path} + matchedPkgPaths = nil + } + break + } else if imp.Name.Name == "_" && len(pkg) > 0 { + // for unused types + pd, ok := pkgDefs.packages[path] + if ok { + if pd.Name == pkg { + matchedPkgPaths = append(matchedPkgPaths, path) + } + } else if matchLastPathPart(path) { + externalPkgPaths = append(externalPkgPaths, path) + } + } else if imp.Name.Name == "." && len(pkg) == 0 { + _, ok := pkgDefs.packages[path] + if ok { + matchedPkgPaths = append(matchedPkgPaths, path) + } else if len(pkg) == 0 || matchLastPathPart(path) { + externalPkgPaths = append(externalPkgPaths, path) + } + } + } else if pkgDefs.packages != nil && len(pkg) > 0 { + pd, ok := pkgDefs.packages[path] + if ok { + if pd.Name == pkg { + matchedPkgPaths = append(matchedPkgPaths, path) + } + } else if matchLastPathPart(path) { + externalPkgPaths = append(externalPkgPaths, path) + } + } + } + + if len(pkg) == 0 || file.Name.Name == pkg { + matchedPkgPaths = append(matchedPkgPaths, pkgDefs.files[file].PackagePath) + } + + return +} + +func (pkgDefs *PackagesDefinitions) findTypeSpecFromPackagePaths(matchedPkgPaths, externalPkgPaths []string, name string) (typeDef *TypeSpecDef) { + if pkgDefs.parseDependency > 0 { + for _, pkgPath := range externalPkgPaths { + if err := pkgDefs.loadExternalPackage(pkgPath); err == nil { + typeDef = pkgDefs.findTypeSpec(pkgPath, name) + if typeDef != nil { + return typeDef + } + } + } + } + + for _, pkgPath := range matchedPkgPaths { + typeDef = pkgDefs.findTypeSpec(pkgPath, name) + if typeDef != nil { + return typeDef + } + } + + return typeDef +} + +// FindTypeSpec finds out TypeSpecDef of a type by typeName +// @typeName the name of the target type, if it starts with a package name, find its own package path from imports on top of @file +// @file the ast.file in which @typeName is used +// @pkgPath the package path of @file. +func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File) *TypeSpecDef { + if IsGolangPrimitiveType(typeName) { + return nil + } + + if file == nil { // for test + return pkgDefs.uniqueDefinitions[typeName] + } + + parts := strings.Split(strings.Split(typeName, "[")[0], ".") + if len(parts) > 1 { + pkgPaths, externalPkgPaths := pkgDefs.findPackagePathFromImports(parts[0], file) + if len(externalPkgPaths) == 0 || pkgDefs.parseDependency == ParseNone { + typeDef, ok := pkgDefs.uniqueDefinitions[typeName] + if ok { + return typeDef + } + } + typeDef := pkgDefs.findTypeSpecFromPackagePaths(pkgPaths, externalPkgPaths, parts[1]) + return pkgDefs.parametrizeGenericType(file, typeDef, typeName) + } + + typeDef, ok := pkgDefs.uniqueDefinitions[fullTypeName(file.Name.Name, typeName)] + if ok { + return typeDef + } + + name := parts[0] + typeDef, ok = pkgDefs.uniqueDefinitions[fullTypeName(file.Name.Name, name)] + if !ok { + pkgPaths, externalPkgPaths := pkgDefs.findPackagePathFromImports("", file) + typeDef = pkgDefs.findTypeSpecFromPackagePaths(pkgPaths, externalPkgPaths, name) + } + + if typeDef != nil { + return pkgDefs.parametrizeGenericType(file, typeDef, typeName) + } + + // in case that comment //@name renamed the type with a name without a dot + for k, v := range pkgDefs.uniqueDefinitions { + if v == nil { + pkgDefs.debug.Printf("%s TypeSpecDef is nil", k) + continue + } + if v.SchemaName == typeName { + return v + } + } + + return nil +} diff --git a/pkg/swag/packages_test.go b/pkg/swag/packages_test.go new file mode 100644 index 0000000..fa576bd --- /dev/null +++ b/pkg/swag/packages_test.go @@ -0,0 +1,262 @@ +package swag + +import ( + "go/ast" + "go/token" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPackagesDefinitions_ParseFile(t *testing.T) { + pd := PackagesDefinitions{} + packageDir := "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/simple" + assert.NoError(t, pd.ParseFile(packageDir, "testdata/simple/main.go", nil, ParseAll)) + assert.Equal(t, 1, len(pd.packages)) + assert.Equal(t, 1, len(pd.files)) +} + +func TestPackagesDefinitions_collectAstFile(t *testing.T) { + pd := PackagesDefinitions{} + fileSet := token.NewFileSet() + assert.NoError(t, pd.collectAstFile(fileSet, "", "", nil, ParseAll)) + + firstFile := &ast.File{ + Name: &ast.Ident{Name: "main.go"}, + } + + packageDir := "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/simple" + assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile, ParseAll)) + assert.NotEmpty(t, pd.packages[packageDir]) + + absPath, _ := filepath.Abs("testdata/simple/" + firstFile.Name.String()) + astFileInfo := &AstFileInfo{ + FileSet: fileSet, + File: firstFile, + Path: absPath, + PackagePath: packageDir, + ParseFlag: ParseAll, + } + assert.Equal(t, pd.files[firstFile], astFileInfo) + + // Override + assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile, ParseAll)) + assert.Equal(t, pd.files[firstFile], astFileInfo) + + // Another file + secondFile := &ast.File{ + Name: &ast.Ident{Name: "api.go"}, + } + assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+secondFile.Name.String(), secondFile, ParseAll)) +} + +func TestPackagesDefinitions_rangeFiles(t *testing.T) { + pd := PackagesDefinitions{ + files: map[*ast.File]*AstFileInfo{ + { + Name: &ast.Ident{Name: "main.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "main.go"}}, + Path: "testdata/simple/main.go", + PackagePath: "main", + }, + { + Name: &ast.Ident{Name: "api.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "api.go"}}, + Path: "testdata/simple/api/api.go", + PackagePath: "api", + }, + }, + } + + i, expect := 0, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"} + _ = pd.RangeFiles(func(fileInfo *AstFileInfo) error { + assert.Equal(t, expect[i], fileInfo.Path) + i++ + return nil + }) +} + +func TestPackagesDefinitions_ParseTypes(t *testing.T) { + absPath, _ := filepath.Abs("") + + mainAST := ast.File{ + Name: &ast.Ident{Name: "main.go"}, + Decls: []ast.Decl{ + &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{ + &ast.TypeSpec{ + Name: &ast.Ident{Name: "Test"}, + Type: &ast.Ident{ + Name: "string", + }, + }, + }, + }, + }, + } + + pd := PackagesDefinitions{ + files: map[*ast.File]*AstFileInfo{ + &mainAST: { + File: &mainAST, + Path: filepath.Join(absPath, "testdata/simple/main.go"), + PackagePath: "main", + }, + { + Name: &ast.Ident{Name: "api.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "api.go"}}, + Path: filepath.Join(absPath, "testdata/simple/api/api.go"), + PackagePath: "api", + }, + }, + packages: make(map[string]*PackageDefinitions), + } + + _, err := pd.ParseTypes() + assert.NoError(t, err) +} + +func TestPackagesDefinitions_parseFunctionScopedTypesFromFile(t *testing.T) { + mainAST := &ast.File{ + Name: &ast.Ident{Name: "main.go"}, + Decls: []ast.Decl{ + &ast.FuncDecl{ + Name: ast.NewIdent("TestFuncDecl"), + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.DeclStmt{ + Decl: &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{ + &ast.TypeSpec{ + Name: ast.NewIdent("response"), + Type: ast.NewIdent("struct"), + }, + &ast.TypeSpec{ + Name: ast.NewIdent("stringResponse"), + Type: ast.NewIdent("string"), + }, + }, + }, + }, + }, + }, + }, + }, + } + + pd := PackagesDefinitions{ + packages: make(map[string]*PackageDefinitions), + } + + parsedSchema := make(map[*TypeSpecDef]*Schema) + pd.parseFunctionScopedTypesFromFile(mainAST, "main", parsedSchema) + + assert.Len(t, parsedSchema, 1) + + _, ok := pd.uniqueDefinitions["main.go.TestFuncDecl.response"] + assert.True(t, ok) + + _, ok = pd.packages["main"].TypeDefinitions["main.go.TestFuncDecl.response"] + assert.True(t, ok) +} + +func TestPackagesDefinitions_FindTypeSpec(t *testing.T) { + userDef := TypeSpecDef{ + File: &ast.File{ + Name: &ast.Ident{Name: "user.go"}, + }, + TypeSpec: &ast.TypeSpec{ + Name: ast.NewIdent("User"), + }, + PkgPath: "user", + } + pkg := PackagesDefinitions{ + uniqueDefinitions: map[string]*TypeSpecDef{ + "user.Model": &userDef, + }, + } + + var nilDef *TypeSpecDef + assert.Equal(t, nilDef, pkg.FindTypeSpec("int", nil)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("bool", nil)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("string", nil)) + + assert.Equal(t, &userDef, pkg.FindTypeSpec("user.Model", nil)) + assert.Equal(t, nilDef, pkg.FindTypeSpec("Model", nil)) +} + +func TestPackage_rangeFiles(t *testing.T) { + pd := NewPackagesDefinitions() + pd.files = map[*ast.File]*AstFileInfo{ + { + Name: &ast.Ident{Name: "main.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "main.go"}}, + Path: "testdata/simple/main.go", + PackagePath: "main", + }, + { + Name: &ast.Ident{Name: "api.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "api.go"}}, + Path: "testdata/simple/api/api.go", + PackagePath: "api", + }, + { + Name: &ast.Ident{Name: "foo.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "foo.go"}}, + Path: "vendor/foo/foo.go", + PackagePath: "vendor/foo", + }, + { + Name: &ast.Ident{Name: "bar.go"}, + }: { + File: &ast.File{Name: &ast.Ident{Name: "bar.go"}}, + Path: filepath.Join(runtime.GOROOT(), "bar.go"), + PackagePath: "bar", + }, + } + + var sorted []string + processor := func(fileInfo *AstFileInfo) error { + sorted = append(sorted, fileInfo.Path) + return nil + } + assert.NoError(t, pd.RangeFiles(processor)) + assert.Equal(t, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"}, sorted) + + assert.Error(t, pd.RangeFiles(func(fileInfo *AstFileInfo) error { + return ErrFuncTypeField + })) +} + +func TestPackagesDefinitions_findTypeSpec(t *testing.T) { + pd := PackagesDefinitions{} + var nilTypeSpec *TypeSpecDef + assert.Equal(t, nilTypeSpec, pd.findTypeSpec("model", "User")) + + userTypeSpec := TypeSpecDef{ + File: &ast.File{}, + TypeSpec: &ast.TypeSpec{}, + PkgPath: "model", + } + pd = PackagesDefinitions{ + packages: map[string]*PackageDefinitions{ + "model": { + TypeDefinitions: map[string]*TypeSpecDef{ + "User": &userTypeSpec, + }, + }, + }, + } + assert.Equal(t, &userTypeSpec, pd.findTypeSpec("model", "User")) + assert.Equal(t, nilTypeSpec, pd.findTypeSpec("others", "User")) +} diff --git a/pkg/swag/parser.go b/pkg/swag/parser.go new file mode 100644 index 0000000..7a0b59e --- /dev/null +++ b/pkg/swag/parser.go @@ -0,0 +1,1963 @@ +package swag + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "go/ast" + "go/build" + goparser "go/parser" + "go/token" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" + + "github.com/KyleBanks/depth" + "github.com/go-openapi/spec" +) + +const ( + // CamelCase indicates using CamelCase strategy for struct field. + CamelCase = "camelcase" + + // PascalCase indicates using PascalCase strategy for struct field. + PascalCase = "pascalcase" + + // SnakeCase indicates using SnakeCase strategy for struct field. + SnakeCase = "snakecase" + + idAttr = "@id" + acceptAttr = "@accept" + produceAttr = "@produce" + paramAttr = "@param" + successAttr = "@success" + failureAttr = "@failure" + responseAttr = "@response" + headerAttr = "@header" + tagsAttr = "@tags" + routerAttr = "@router" + deprecatedRouterAttr = "@deprecatedrouter" + summaryAttr = "@summary" + deprecatedAttr = "@deprecated" + securityAttr = "@security" + titleAttr = "@title" + conNameAttr = "@contact.name" + conURLAttr = "@contact.url" + conEmailAttr = "@contact.email" + licNameAttr = "@license.name" + licURLAttr = "@license.url" + versionAttr = "@version" + descriptionAttr = "@description" + descriptionMarkdownAttr = "@description.markdown" + secBasicAttr = "@securitydefinitions.basic" + secAPIKeyAttr = "@securitydefinitions.apikey" + secApplicationAttr = "@securitydefinitions.oauth2.application" + secImplicitAttr = "@securitydefinitions.oauth2.implicit" + secPasswordAttr = "@securitydefinitions.oauth2.password" + secAccessCodeAttr = "@securitydefinitions.oauth2.accesscode" + tosAttr = "@termsofservice" + extDocsDescAttr = "@externaldocs.description" + extDocsURLAttr = "@externaldocs.url" + xCodeSamplesAttr = "@x-codesamples" + scopeAttrPrefix = "@scope." + stateAttr = "@state" +) + +// ParseFlag determine what to parse +type ParseFlag int + +const ( + // ParseNone parse nothing + ParseNone ParseFlag = 0x00 + // ParseModels parse models + ParseModels = 0x01 + // ParseOperations parse operations + ParseOperations = 0x02 + // ParseAll parse operations and models + ParseAll = ParseOperations | ParseModels +) + +var ( + // ErrRecursiveParseStruct recursively parsing struct. + ErrRecursiveParseStruct = errors.New("recursively parsing struct") + + // ErrFuncTypeField field type is func. + ErrFuncTypeField = errors.New("field type is func") + + // ErrFailedConvertPrimitiveType Failed to convert for swag to interpretable type. + ErrFailedConvertPrimitiveType = errors.New("swag property: failed convert primitive type") + + // ErrSkippedField .swaggo specifies field should be skipped. + ErrSkippedField = errors.New("field is skipped by global overrides") +) + +var allMethod = map[string]struct{}{ + http.MethodGet: {}, + http.MethodPut: {}, + http.MethodPost: {}, + http.MethodDelete: {}, + http.MethodOptions: {}, + http.MethodHead: {}, + http.MethodPatch: {}, +} + +// Parser implements a parser for Go source files. +type Parser struct { + // swagger represents the root document object for the API specification + swagger *spec.Swagger + + // packages store entities of APIs, definitions, file, package path etc. and their relations + packages *PackagesDefinitions + + // parsedSchemas store schemas which have been parsed from ast.TypeSpec + parsedSchemas map[*TypeSpecDef]*Schema + + // outputSchemas store schemas which will be export to swagger + outputSchemas map[*TypeSpecDef]*Schema + + // PropNamingStrategy naming strategy + PropNamingStrategy string + + // ParseVendor parse vendor folder + ParseVendor bool + + // ParseDependencies whether swag should be parse outside dependency folder: 0 none, 1 models, 2 operations, 3 all + ParseDependency ParseFlag + + // ParseInternal whether swag should parse internal packages + ParseInternal bool + + // Strict whether swag should error or warn when it detects cases which are most likely user errors + Strict bool + + // RequiredByDefault set validation required for all fields by default + RequiredByDefault bool + + // structStack stores full names of the structures that were already parsed or are being parsed now + structStack []*TypeSpecDef + + // markdownFileDir holds the path to the folder, where markdown files are stored + markdownFileDir string + + // codeExampleFilesDir holds path to the folder, where code example files are stored + codeExampleFilesDir string + + // collectionFormatInQuery set the default collectionFormat otherwise then 'csv' for array in query params + collectionFormatInQuery string + + // excludes excludes dirs and files in SearchDir + excludes map[string]struct{} + + // packagePrefix is a list of package path prefixes, packages that do not + // match any one of them will be excluded when searching. + packagePrefix []string + + // tells parser to include only specific extension + parseExtension string + + // debugging output goes here + debug Debugger + + // fieldParserFactory create FieldParser + fieldParserFactory FieldParserFactory + + // Overrides allows global replacements of types. A blank replacement will be skipped. + Overrides map[string]string + + // parseGoList whether swag use go list to parse dependency + parseGoList bool + + // tags to filter the APIs after + tags map[string]struct{} + + // HostState is the state of the host + HostState string + + // ParseFuncBody whether swag should parse api info inside of funcs + ParseFuncBody bool +} + +// FieldParserFactory create FieldParser. +type FieldParserFactory func(ps *Parser, field *ast.Field) FieldParser + +// FieldParser parse struct field. +type FieldParser interface { + ShouldSkip() bool + FieldNames() ([]string, error) + FormName() string + HeaderName() string + PathName() string + CustomSchema() (*spec.Schema, error) + ComplementSchema(schema *spec.Schema) error + IsRequired() (bool, error) +} + +// Debugger is the interface that wraps the basic Printf method. +type Debugger interface { + Printf(format string, v ...interface{}) +} + +// New creates a new Parser with default properties. +func New(options ...func(*Parser)) *Parser { + parser := &Parser{ + swagger: &spec.Swagger{ + SwaggerProps: spec.SwaggerProps{ + Info: &spec.Info{ + InfoProps: spec.InfoProps{ + Contact: &spec.ContactInfo{}, + License: nil, + }, + VendorExtensible: spec.VendorExtensible{ + Extensions: spec.Extensions{}, + }, + }, + Paths: &spec.Paths{ + Paths: make(map[string]spec.PathItem), + VendorExtensible: spec.VendorExtensible{ + Extensions: nil, + }, + }, + Definitions: make(map[string]spec.Schema), + SecurityDefinitions: make(map[string]*spec.SecurityScheme), + }, + VendorExtensible: spec.VendorExtensible{ + Extensions: nil, + }, + }, + packages: NewPackagesDefinitions(), + debug: log.New(os.Stdout, "", log.LstdFlags), + parsedSchemas: make(map[*TypeSpecDef]*Schema), + outputSchemas: make(map[*TypeSpecDef]*Schema), + excludes: make(map[string]struct{}), + tags: make(map[string]struct{}), + fieldParserFactory: newTagBaseFieldParser, + Overrides: make(map[string]string), + } + + for _, option := range options { + option(parser) + } + + parser.packages.debug = parser.debug + + return parser +} + +// SetParseDependency sets whether to parse the dependent packages. +func SetParseDependency(parseDependency int) func(*Parser) { + return func(p *Parser) { + p.ParseDependency = ParseFlag(parseDependency) + if p.packages != nil { + p.packages.parseDependency = p.ParseDependency + } + } +} + +// SetMarkdownFileDirectory sets the directory to search for markdown files. +func SetMarkdownFileDirectory(directoryPath string) func(*Parser) { + return func(p *Parser) { + p.markdownFileDir = directoryPath + } +} + +// SetCodeExamplesDirectory sets the directory to search for code example files. +func SetCodeExamplesDirectory(directoryPath string) func(*Parser) { + return func(p *Parser) { + p.codeExampleFilesDir = directoryPath + } +} + +// SetExcludedDirsAndFiles sets directories and files to be excluded when searching. +func SetExcludedDirsAndFiles(excludes string) func(*Parser) { + return func(p *Parser) { + for _, f := range strings.Split(excludes, ",") { + f = strings.TrimSpace(f) + if f != "" { + f = filepath.Clean(f) + p.excludes[f] = struct{}{} + } + } + } +} + +// SetPackagePrefix sets a list of package path prefixes from a comma-separated +// string, packages that do not match any one of them will be excluded when +// searching. +func SetPackagePrefix(packagePrefix string) func(*Parser) { + return func(p *Parser) { + for _, f := range strings.Split(packagePrefix, ",") { + f = strings.TrimSpace(f) + if f != "" { + p.packagePrefix = append(p.packagePrefix, f) + } + } + } +} + +// SetTags sets the tags to be included +func SetTags(include string) func(*Parser) { + return func(p *Parser) { + for _, f := range strings.Split(include, ",") { + f = strings.TrimSpace(f) + if f != "" { + p.tags[f] = struct{}{} + } + } + } +} + +// SetParseExtension parses only those operations which match given extension +func SetParseExtension(parseExtension string) func(*Parser) { + return func(p *Parser) { + p.parseExtension = parseExtension + } +} + +// SetStrict sets whether swag should error or warn when it detects cases which are most likely user errors. +func SetStrict(strict bool) func(*Parser) { + return func(p *Parser) { + p.Strict = strict + } +} + +// SetDebugger allows the use of user-defined implementations. +func SetDebugger(logger Debugger) func(parser *Parser) { + return func(p *Parser) { + if logger != nil { + p.debug = logger + } + } +} + +// SetFieldParserFactory allows the use of user-defined implementations. +func SetFieldParserFactory(factory FieldParserFactory) func(parser *Parser) { + return func(p *Parser) { + p.fieldParserFactory = factory + } +} + +// SetOverrides allows the use of user-defined global type overrides. +func SetOverrides(overrides map[string]string) func(parser *Parser) { + return func(p *Parser) { + for k, v := range overrides { + p.Overrides[k] = v + } + } +} + +// SetCollectionFormat set default collection format +func SetCollectionFormat(collectionFormat string) func(*Parser) { + return func(p *Parser) { + p.collectionFormatInQuery = collectionFormat + } +} + +// ParseUsingGoList sets whether swag use go list to parse dependency +func ParseUsingGoList(enabled bool) func(parser *Parser) { + return func(p *Parser) { + p.parseGoList = enabled + } +} + +// ParseAPI parses general api info for given searchDir and mainAPIFile. +func (parser *Parser) ParseAPI(searchDir string, mainAPIFile string, parseDepth int) error { + return parser.ParseAPIMultiSearchDir([]string{searchDir}, mainAPIFile, parseDepth) +} + +// skipPackageByPrefix returns true the given pkgpath does not match +// any user-defined package path prefixes. +func (parser *Parser) skipPackageByPrefix(pkgpath string) bool { + if len(parser.packagePrefix) == 0 { + return false + } + for _, prefix := range parser.packagePrefix { + if strings.HasPrefix(pkgpath, prefix) { + return false + } + } + return true +} + +// ParseAPIMultiSearchDir is like ParseAPI but for multiple search dirs. +func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile string, parseDepth int) error { + for _, searchDir := range searchDirs { + parser.debug.Printf("Generate general API Info, search dir:%s", searchDir) + + packageDir, err := getPkgName(searchDir) + if err != nil { + parser.debug.Printf("warning: failed to get package name in dir: %s, error: %s", searchDir, err.Error()) + } + + err = parser.getAllGoFileInfo(packageDir, searchDir) + if err != nil { + return err + } + } + + absMainAPIFilePath, err := filepath.Abs(filepath.Join(searchDirs[0], mainAPIFile)) + if err != nil { + return err + } + + // Use 'go list' command instead of depth.Resolve() + if parser.ParseDependency > 0 { + if parser.parseGoList { + pkgs, err := listPackages(context.Background(), filepath.Dir(absMainAPIFilePath), nil, "-deps") + if err != nil { + return fmt.Errorf("pkg %s cannot find all dependencies, %s", filepath.Dir(absMainAPIFilePath), err) + } + + length := len(pkgs) + for i := 0; i < length; i++ { + err := parser.getAllGoFileInfoFromDepsByList(pkgs[i], parser.ParseDependency) + if err != nil { + return err + } + } + } else { + var t depth.Tree + t.ResolveInternal = true + t.MaxDepth = parseDepth + + pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath)) + if err != nil { + return err + } + + err = t.Resolve(pkgName) + if err != nil { + return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err) + } + for i := 0; i < len(t.Root.Deps); i++ { + err := parser.getAllGoFileInfoFromDeps(&t.Root.Deps[i], parser.ParseDependency) + if err != nil { + return err + } + } + } + } + + err = parser.ParseGeneralAPIInfo(absMainAPIFilePath) + if err != nil { + return err + } + + parser.parsedSchemas, err = parser.packages.ParseTypes() + if err != nil { + return err + } + + err = parser.packages.RangeFiles(parser.ParseRouterAPIInfo) + if err != nil { + return err + } + + return parser.checkOperationIDUniqueness() +} + +func getPkgName(searchDir string) (string, error) { + cmd := exec.Command("go", "list", "-f={{.ImportPath}}") + cmd.Dir = searchDir + + var stdout, stderr strings.Builder + + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute go list command, %s, stdout:%s, stderr:%s", err, stdout.String(), stderr.String()) + } + + outStr, _ := stdout.String(), stderr.String() + + if outStr[0] == '_' { // will shown like _/{GOPATH}/src/{YOUR_PACKAGE} when NOT enable GO MODULE. + outStr = strings.TrimPrefix(outStr, "_"+build.Default.GOPATH+"/src/") + } + + f := strings.Split(outStr, "\n") + + outStr = f[0] + + return outStr, nil +} + +// ParseGeneralAPIInfo parses general api info for given mainAPIFile path. +func (parser *Parser) ParseGeneralAPIInfo(mainAPIFile string) error { + fileTree, err := goparser.ParseFile(token.NewFileSet(), mainAPIFile, nil, goparser.ParseComments) + if err != nil { + return fmt.Errorf("cannot parse source files %s: %s", mainAPIFile, err) + } + + parser.swagger.Swagger = "2.0" + + for _, comment := range fileTree.Comments { + comments := strings.Split(comment.Text(), "\n") + if !isGeneralAPIComment(comments) { + continue + } + + err = parseGeneralAPIInfo(parser, comments) + if err != nil { + return err + } + } + + return nil +} + +func parseGeneralAPIInfo(parser *Parser, comments []string) error { + previousAttribute := "" + var tag *spec.Tag + // parsing classic meta data model + for line := 0; line < len(comments); line++ { + commentLine := comments[line] + commentLine = strings.TrimSpace(commentLine) + if len(commentLine) == 0 { + continue + } + fields := FieldsByAnySpace(commentLine, 2) + + attribute := fields[0] + var value string + if len(fields) > 1 { + value = fields[1] + } + + switch attr := strings.ToLower(attribute); attr { + case versionAttr, titleAttr, tosAttr, licNameAttr, licURLAttr, conNameAttr, conURLAttr, conEmailAttr: + setSwaggerInfo(parser.swagger, attr, value) + case descriptionAttr: + if previousAttribute == attribute { + parser.swagger.Info.Description += "\n" + value + + continue + } + + setSwaggerInfo(parser.swagger, attr, value) + case descriptionMarkdownAttr: + commentInfo, err := getMarkdownForTag("api", parser.markdownFileDir) + if err != nil { + return err + } + + setSwaggerInfo(parser.swagger, descriptionAttr, string(commentInfo)) + + case "@host": + parser.swagger.Host = value + case "@hoststate": + fields = FieldsByAnySpace(commentLine, 3) + if len(fields) != 3 { + return fmt.Errorf("%s needs 3 arguments", attribute) + } + if parser.HostState == fields[1] { + parser.swagger.Host = fields[2] + } + case "@basepath": + parser.swagger.BasePath = value + + case acceptAttr: + err := parser.ParseAcceptComment(value) + if err != nil { + return err + } + case produceAttr: + err := parser.ParseProduceComment(value) + if err != nil { + return err + } + case "@schemes": + parser.swagger.Schemes = strings.Split(value, " ") + case "@tag.name": + if parser.matchTag(value) { + parser.swagger.Tags = append(parser.swagger.Tags, spec.Tag{ + TagProps: spec.TagProps{ + Name: value, + }, + }) + tag = &parser.swagger.Tags[len(parser.swagger.Tags)-1] + } else { + tag = nil + } + case "@tag.description": + if tag != nil { + tag.TagProps.Description = value + } + case "@tag.description.markdown": + if tag != nil { + commentInfo, err := getMarkdownForTag(tag.TagProps.Name, parser.markdownFileDir) + if err != nil { + return err + } + + tag.TagProps.Description = string(commentInfo) + } + case "@tag.docs.url": + if tag != nil { + tag.TagProps.ExternalDocs = &spec.ExternalDocumentation{ + URL: value, + } + } + case "@tag.docs.description": + if tag != nil { + if tag.TagProps.ExternalDocs == nil { + return fmt.Errorf("%s needs to come after a @tags.docs.url", attribute) + } + + tag.TagProps.ExternalDocs.Description = value + } + case secBasicAttr, secAPIKeyAttr, secApplicationAttr, secImplicitAttr, secPasswordAttr, secAccessCodeAttr: + scheme, err := parseSecAttributes(attribute, comments, &line) + if err != nil { + return err + } + + parser.swagger.SecurityDefinitions[value] = scheme + + case securityAttr: + parser.swagger.Security = append(parser.swagger.Security, parseSecurity(value)) + + case "@query.collection.format": + parser.collectionFormatInQuery = TransToValidCollectionFormat(value) + + case extDocsDescAttr, extDocsURLAttr: + if parser.swagger.ExternalDocs == nil { + parser.swagger.ExternalDocs = new(spec.ExternalDocumentation) + } + switch attr { + case extDocsDescAttr: + parser.swagger.ExternalDocs.Description = value + case extDocsURLAttr: + parser.swagger.ExternalDocs.URL = value + } + + default: + if strings.HasPrefix(attribute, "@x-") { + extensionName := attribute[1:] + + extExistsInSecurityDef := false + // for each security definition + for _, v := range parser.swagger.SecurityDefinitions { + // check if extension exists + _, extExistsInSecurityDef = v.VendorExtensible.Extensions.GetString(extensionName) + // if it exists in at least one, then we stop iterating + if extExistsInSecurityDef { + break + } + } + + // if it is present on security def, don't add it again + if extExistsInSecurityDef { + break + } + + if len(value) == 0 { + return fmt.Errorf("annotation %s need a value", attribute) + } + + var valueJSON interface{} + err := json.Unmarshal([]byte(value), &valueJSON) + if err != nil { + return fmt.Errorf("annotation %s need a valid json value", attribute) + } + + if strings.Contains(extensionName, "logo") { + parser.swagger.Info.Extensions.Add(extensionName, valueJSON) + } else { + if parser.swagger.Extensions == nil { + parser.swagger.Extensions = make(map[string]interface{}) + } + + parser.swagger.Extensions[attribute[1:]] = valueJSON + } + } else if strings.HasPrefix(attribute, "@tag.x-") { + extensionName := attribute[5:] + + if len(value) == 0 { + return fmt.Errorf("annotation %s need a value", attribute) + } + + if tag.Extensions == nil { + tag.Extensions = make(map[string]interface{}) + } + + // tag.Extensions.Add(extensionName, value) works wrong (transforms extensionName to lower case) + // needed to save case for ReDoc + // https://redocly.com/docs/api-reference-docs/specification-extensions/x-display-name/ + tag.Extensions[extensionName] = value + } + } + + previousAttribute = attribute + } + + return nil +} + +func setSwaggerInfo(swagger *spec.Swagger, attribute, value string) { + switch attribute { + case versionAttr: + swagger.Info.Version = value + case titleAttr: + swagger.Info.Title = value + case tosAttr: + swagger.Info.TermsOfService = value + case descriptionAttr: + swagger.Info.Description = value + case conNameAttr: + swagger.Info.Contact.Name = value + case conEmailAttr: + swagger.Info.Contact.Email = value + case conURLAttr: + swagger.Info.Contact.URL = value + case licNameAttr: + swagger.Info.License = initIfEmpty(swagger.Info.License) + swagger.Info.License.Name = value + case licURLAttr: + swagger.Info.License = initIfEmpty(swagger.Info.License) + swagger.Info.License.URL = value + } +} + +func parseSecAttributes(context string, lines []string, index *int) (*spec.SecurityScheme, error) { + const ( + in = "@in" + name = "@name" + descriptionAttr = "@description" + tokenURL = "@tokenurl" + authorizationURL = "@authorizationurl" + ) + + var search []string + + attribute := strings.ToLower(FieldsByAnySpace(lines[*index], 2)[0]) + switch attribute { + case secBasicAttr: + return spec.BasicAuth(), nil + case secAPIKeyAttr: + search = []string{in, name} + case secApplicationAttr, secPasswordAttr: + search = []string{tokenURL} + case secImplicitAttr: + search = []string{authorizationURL} + case secAccessCodeAttr: + search = []string{tokenURL, authorizationURL} + } + + // For the first line we get the attributes in the context parameter, so we skip to the next one + *index++ + + attrMap, scopes := make(map[string]string), make(map[string]string) + extensions, description := make(map[string]interface{}), "" + +loopline: + for ; *index < len(lines); *index++ { + v := strings.TrimSpace(lines[*index]) + if len(v) == 0 { + continue + } + + fields := FieldsByAnySpace(v, 2) + securityAttr := strings.ToLower(fields[0]) + var value string + if len(fields) > 1 { + value = fields[1] + } + + for _, findterm := range search { + if securityAttr == findterm { + attrMap[securityAttr] = value + continue loopline + } + } + + if isExists, err := isExistsScope(securityAttr); err != nil { + return nil, err + } else if isExists { + scopes[securityAttr[len(scopeAttrPrefix):]] = value + continue + } + + if strings.HasPrefix(securityAttr, "@x-") { + // Add the custom attribute without the @ + extensions[securityAttr[1:]] = value + continue + } + + // Not mandatory field + if securityAttr == descriptionAttr { + if description != "" { + description += "\n" + } + description += value + } + + // next securityDefinitions + if strings.Index(securityAttr, "@securitydefinitions.") == 0 { + // Go back to the previous line and break + *index-- + + break + } + } + + if len(attrMap) != len(search) { + return nil, fmt.Errorf("%s is %v required", context, search) + } + + var scheme *spec.SecurityScheme + + switch attribute { + case secAPIKeyAttr: + scheme = spec.APIKeyAuth(attrMap[name], attrMap[in]) + case secApplicationAttr: + scheme = spec.OAuth2Application(attrMap[tokenURL]) + case secImplicitAttr: + scheme = spec.OAuth2Implicit(attrMap[authorizationURL]) + case secPasswordAttr: + scheme = spec.OAuth2Password(attrMap[tokenURL]) + case secAccessCodeAttr: + scheme = spec.OAuth2AccessToken(attrMap[authorizationURL], attrMap[tokenURL]) + } + + scheme.Description = description + + for extKey, extValue := range extensions { + scheme.AddExtension(extKey, extValue) + } + + for scope, scopeDescription := range scopes { + scheme.AddScope(scope, scopeDescription) + } + + return scheme, nil +} + +func parseSecurity(commentLine string) map[string][]string { + securityMap := make(map[string][]string) + + for _, securityOption := range securityPairSepPattern.Split(commentLine, -1) { + securityOption = strings.TrimSpace(securityOption) + + left, right := strings.Index(securityOption, "["), strings.Index(securityOption, "]") + + if !(left == -1 && right == -1) { + scopes := securityOption[left+1 : right] + + var options []string + + for _, scope := range strings.Split(scopes, ",") { + options = append(options, strings.TrimSpace(scope)) + } + + securityKey := securityOption[0:left] + securityMap[securityKey] = append(securityMap[securityKey], options...) + } else { + securityKey := strings.TrimSpace(securityOption) + securityMap[securityKey] = []string{} + } + } + + return securityMap +} + +func initIfEmpty(license *spec.License) *spec.License { + if license == nil { + return new(spec.License) + } + + return license +} + +// ParseAcceptComment parses comment for given `accept` comment string. +func (parser *Parser) ParseAcceptComment(commentLine string) error { + return parseMimeTypeList(commentLine, &parser.swagger.Consumes, "%v accept type can't be accepted") +} + +// ParseProduceComment parses comment for given `produce` comment string. +func (parser *Parser) ParseProduceComment(commentLine string) error { + return parseMimeTypeList(commentLine, &parser.swagger.Produces, "%v produce type can't be accepted") +} + +func isGeneralAPIComment(comments []string) bool { + for _, commentLine := range comments { + commentLine = strings.TrimSpace(commentLine) + if len(commentLine) == 0 { + continue + } + attribute := strings.ToLower(FieldsByAnySpace(commentLine, 2)[0]) + switch attribute { + // The @summary, @router, @success, @failure annotation belongs to Operation + case summaryAttr, routerAttr, successAttr, failureAttr, responseAttr: + return false + } + } + + return true +} + +func getMarkdownForTag(tagName string, dirPath string) ([]byte, error) { + if tagName == "" { + // this happens when parsing the @description.markdown attribute + // it will be called properly another time with tagName="api" + // so we can safely return an empty byte slice here + return make([]byte, 0), nil + } + + dirEntries, err := os.ReadDir(dirPath) + if err != nil { + return nil, err + } + + for _, entry := range dirEntries { + if entry.IsDir() { + continue + } + + fileName := entry.Name() + + expectedFileName := tagName + if !strings.HasSuffix(tagName, ".md") { + expectedFileName = tagName + ".md" + } + + if fileName == expectedFileName { + fullPath := filepath.Join(dirPath, fileName) + + commentInfo, err := os.ReadFile(fullPath) + if err != nil { + return nil, fmt.Errorf("Failed to read markdown file %s error: %s ", fullPath, err) + } + + return commentInfo, nil + } + } + + return nil, fmt.Errorf("Unable to find markdown file for tag %s in the given directory", tagName) +} + +func isExistsScope(scope string) (bool, error) { + s := strings.Fields(scope) + for _, v := range s { + if strings.HasPrefix(v, scopeAttrPrefix) { + if strings.Contains(v, ",") { + return false, fmt.Errorf("@scope can't use comma(,) get=" + v) + } + } + } + + return strings.HasPrefix(scope, scopeAttrPrefix), nil +} + +func getTagsFromComment(comment string) (tags []string) { + commentLine := strings.TrimSpace(strings.TrimLeft(comment, "/")) + if len(commentLine) == 0 { + return nil + } + + attribute := strings.Fields(commentLine)[0] + lineRemainder, lowerAttribute := strings.TrimSpace(commentLine[len(attribute):]), strings.ToLower(attribute) + + if lowerAttribute == tagsAttr { + for _, tag := range strings.Split(lineRemainder, ",") { + tags = append(tags, strings.TrimSpace(tag)) + } + } + return + +} + +func (parser *Parser) matchTag(tag string) bool { + if len(parser.tags) == 0 { + return true + } + + if _, has := parser.tags["!"+tag]; has { + return false + } + if _, has := parser.tags[tag]; has { + return true + } + + // If all tags are negation then we should return true + for key := range parser.tags { + if key[0] != '!' { + return false + } + } + return true +} + +func (parser *Parser) matchTags(comments []*ast.Comment) (match bool) { + if len(parser.tags) == 0 { + return true + } + + match = false + for _, comment := range comments { + for _, tag := range getTagsFromComment(comment.Text) { + if _, has := parser.tags["!"+tag]; has { + return false + } + if _, has := parser.tags[tag]; has { + match = true // keep iterating as it may contain a tag that is excluded + } + } + } + + if !match { + // If all tags are negation then we should return true + for key := range parser.tags { + if key[0] != '!' { + return false + } + } + } + return true +} + +func matchExtension(extensionToMatch string, comments []*ast.Comment) (match bool) { + if len(extensionToMatch) != 0 { + for _, comment := range comments { + commentLine := strings.TrimSpace(strings.TrimLeft(comment.Text, "/")) + fields := FieldsByAnySpace(commentLine, 2) + if len(fields) > 0 { + lowerAttribute := strings.ToLower(fields[0]) + + if lowerAttribute == fmt.Sprintf("@x-%s", strings.ToLower(extensionToMatch)) { + return true + } + } + } + return false + } + return true +} + +func getFuncDoc(decl any) (*ast.CommentGroup, bool) { + switch astDecl := decl.(type) { + case *ast.FuncDecl: // func name() {} + return astDecl.Doc, true + case *ast.GenDecl: // var name = namePointToFuncDirectlyOrIndirectly + if astDecl.Tok != token.VAR { + return nil, false + } + varSpec, ok := astDecl.Specs[0].(*ast.ValueSpec) + if !ok || len(varSpec.Values) != 1 { + return nil, false + } + _, ok = getFuncDoc(varSpec) + return astDecl.Doc, ok + case *ast.ValueSpec: + value, ok := astDecl.Values[0].(*ast.Ident) + if !ok || value == nil { + return nil, false + } + _, ok = getFuncDoc(value.Obj.Decl) + return astDecl.Doc, ok + } + return nil, false +} + +// ParseRouterAPIInfo parses router api info for given astFile. +func (parser *Parser) ParseRouterAPIInfo(fileInfo *AstFileInfo) error { + if (fileInfo.ParseFlag & ParseOperations) == ParseNone { + return nil + } + + // parse File.Comments instead of File.Decls.Doc if ParseFuncBody flag set to "true" + if parser.ParseFuncBody { + for _, astComments := range fileInfo.File.Comments { + if astComments.List != nil { + if err := parser.parseRouterAPIInfoComment(astComments.List, fileInfo); err != nil { + return err + } + } + } + + return nil + } + + for _, decl := range fileInfo.File.Decls { + funcDoc, ok := getFuncDoc(decl) + if ok && funcDoc != nil && funcDoc.List != nil { + if err := parser.parseRouterAPIInfoComment(funcDoc.List, fileInfo); err != nil { + return err + } + } + } + + return nil +} + +func (parser *Parser) parseRouterAPIInfoComment(comments []*ast.Comment, fileInfo *AstFileInfo) error { + if parser.matchTags(comments) && matchExtension(parser.parseExtension, comments) { + // for per 'function' comment, create a new 'Operation' object + operation := NewOperation(parser, SetCodeExampleFilesDirectory(parser.codeExampleFilesDir)) + for _, comment := range comments { + err := operation.ParseComment(comment.Text, fileInfo.File) + if err != nil { + return fmt.Errorf("ParseComment error in file %s for comment: '%s': %+v", fileInfo.Path, comment.Text, err) + } + if operation.State != "" && operation.State != parser.HostState { + return nil + } + } + err := processRouterOperation(parser, operation) + if err != nil { + return err + } + } + + return nil +} + +func refRouteMethodOp(item *spec.PathItem, method string) (op **spec.Operation) { + switch method { + case http.MethodGet: + op = &item.Get + case http.MethodPost: + op = &item.Post + case http.MethodDelete: + op = &item.Delete + case http.MethodPut: + op = &item.Put + case http.MethodPatch: + op = &item.Patch + case http.MethodHead: + op = &item.Head + case http.MethodOptions: + op = &item.Options + } + + return +} + +func processRouterOperation(parser *Parser, operation *Operation) error { + for _, routeProperties := range operation.RouterProperties { + var ( + pathItem spec.PathItem + ok bool + ) + + pathItem, ok = parser.swagger.Paths.Paths[routeProperties.Path] + if !ok { + pathItem = spec.PathItem{} + } + + op := refRouteMethodOp(&pathItem, routeProperties.HTTPMethod) + + // check if we already have an operation for this path and method + if *op != nil { + err := fmt.Errorf("route %s %s is declared multiple times", routeProperties.HTTPMethod, routeProperties.Path) + if parser.Strict { + return err + } + + parser.debug.Printf("warning: %s\n", err) + } + + if len(operation.RouterProperties) > 1 { + newOp := *operation + var validParams []spec.Parameter + for _, param := range newOp.Operation.OperationProps.Parameters { + if param.In == "path" && !strings.Contains(routeProperties.Path, param.Name) { + // This path param is not actually contained in the path, skip adding it to the final params + continue + } + validParams = append(validParams, param) + } + newOp.Operation.OperationProps.Parameters = validParams + *op = &newOp.Operation + } else { + *op = &operation.Operation + } + + if routeProperties.Deprecated { + (*op).Deprecated = routeProperties.Deprecated + } + + parser.swagger.Paths.Paths[routeProperties.Path] = pathItem + } + + return nil +} + +func convertFromSpecificToPrimitive(typeName string) (string, error) { + name := typeName + if strings.ContainsRune(name, '.') { + name = strings.Split(name, ".")[1] + } + + switch strings.ToUpper(name) { + case "TIME", "OBJECTID", "UUID": + return STRING, nil + case "DECIMAL": + return NUMBER, nil + } + + return typeName, ErrFailedConvertPrimitiveType +} + +func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (*spec.Schema, error) { + if override, ok := parser.Overrides[typeName]; ok { + parser.debug.Printf("Override detected for %s: using %s instead", typeName, override) + return parseObjectSchema(parser, override, file) + } + + if IsInterfaceLike(typeName) { + return &spec.Schema{}, nil + } + if IsGolangPrimitiveType(typeName) { + return TransToValidPrimitiveSchema(typeName), nil + } + + schemaType, err := convertFromSpecificToPrimitive(typeName) + if err == nil { + return PrimitiveSchema(schemaType), nil + } + + typeSpecDef := parser.packages.FindTypeSpec(typeName, file) + if typeSpecDef == nil { + return nil, fmt.Errorf("cannot find type definition: %s", typeName) + } + + if override, ok := parser.Overrides[typeSpecDef.FullPath()]; ok { + if override == "" { + parser.debug.Printf("Override detected for %s: ignoring", typeSpecDef.FullPath()) + + return nil, ErrSkippedField + } + + parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override) + + separator := strings.LastIndex(override, ".") + if separator == -1 { + // treat as a swaggertype tag + parts := strings.Split(override, ",") + + return BuildCustomSchema(parts) + } + + typeSpecDef = parser.packages.findTypeSpec(override[0:separator], override[separator+1:]) + } + + schema, ok := parser.parsedSchemas[typeSpecDef] + if !ok { + var err error + + schema, err = parser.ParseDefinition(typeSpecDef) + if err != nil { + if err == ErrRecursiveParseStruct && ref { + return parser.getRefTypeSchema(typeSpecDef, schema), nil + } + return nil, fmt.Errorf("%s: %w", typeName, err) + } + } + + if ref { + if IsComplexSchema(schema.Schema) { + return parser.getRefTypeSchema(typeSpecDef, schema), nil + } + // if it is a simple schema, just return a copy + newSchema := *schema.Schema + return &newSchema, nil + } + + return schema.Schema, nil +} + +func (parser *Parser) getRefTypeSchema(typeSpecDef *TypeSpecDef, schema *Schema) *spec.Schema { + _, ok := parser.outputSchemas[typeSpecDef] + if !ok { + parser.swagger.Definitions[schema.Name] = spec.Schema{} + + if schema.Schema != nil { + parser.swagger.Definitions[schema.Name] = *schema.Schema + } + + parser.outputSchemas[typeSpecDef] = schema + } + + refSchema := RefSchema(schema.Name) + + return refSchema +} + +func (parser *Parser) isInStructStack(typeSpecDef *TypeSpecDef) bool { + for _, specDef := range parser.structStack { + if typeSpecDef == specDef { + return true + } + } + + return false +} + +// ParseDefinition parses given type spec that corresponds to the type under +// given name and package, and populates swagger schema definitions registry +// with a schema for the given type +func (parser *Parser) ParseDefinition(typeSpecDef *TypeSpecDef) (*Schema, error) { + typeName := typeSpecDef.TypeName() + schema, found := parser.parsedSchemas[typeSpecDef] + if found { + parser.debug.Printf("Skipping '%s', already parsed.", typeName) + + return schema, nil + } + + if parser.isInStructStack(typeSpecDef) { + parser.debug.Printf("Skipping '%s', recursion detected.", typeName) + + return &Schema{ + Name: typeSpecDef.SchemaName, + PkgPath: typeSpecDef.PkgPath, + Schema: PrimitiveSchema(OBJECT), + }, + ErrRecursiveParseStruct + } + + parser.structStack = append(parser.structStack, typeSpecDef) + + parser.debug.Printf("Generating %s", typeName) + + definition, err := parser.parseTypeExpr(typeSpecDef.File, typeSpecDef.TypeSpec.Type, false) + if err != nil { + parser.debug.Printf("Error parsing type definition '%s': %s", typeName, err) + return nil, err + } + + if definition.Description == "" { + err = parser.fillDefinitionDescription(definition, typeSpecDef.File, typeSpecDef) + if err != nil { + return nil, err + } + } + + if len(typeSpecDef.Enums) > 0 { + var varnames []string + var enumComments = make(map[string]string) + var enumDescriptions = make([]string, 0, len(typeSpecDef.Enums)) + for _, value := range typeSpecDef.Enums { + definition.Enum = append(definition.Enum, value.Value) + varnames = append(varnames, value.key) + if len(value.Comment) > 0 { + enumComments[value.key] = value.Comment + enumDescriptions = append(enumDescriptions, value.Comment) + } + } + if definition.Extensions == nil { + definition.Extensions = make(spec.Extensions) + } + definition.Extensions[enumVarNamesExtension] = varnames + if len(enumComments) > 0 { + definition.Extensions[enumCommentsExtension] = enumComments + definition.Extensions[enumDescriptionsExtension] = enumDescriptions + } + } + + schemaName := typeName + + if typeSpecDef.SchemaName != "" { + schemaName = typeSpecDef.SchemaName + } + + sch := Schema{ + Name: schemaName, + PkgPath: typeSpecDef.PkgPath, + Schema: definition, + } + parser.parsedSchemas[typeSpecDef] = &sch + + // update an empty schema as a result of recursion + s2, found := parser.outputSchemas[typeSpecDef] + if found { + parser.swagger.Definitions[s2.Name] = *definition + } + + return &sch, nil +} + +func fullTypeName(parts ...string) string { + return strings.Join(parts, ".") +} + +// fillDefinitionDescription additionally fills fields in definition (spec.Schema) +// TODO: If .go file contains many types, it may work for a long time +func (parser *Parser) fillDefinitionDescription(definition *spec.Schema, file *ast.File, typeSpecDef *TypeSpecDef) (err error) { + if file == nil { + return + } + for _, astDeclaration := range file.Decls { + generalDeclaration, ok := astDeclaration.(*ast.GenDecl) + if !ok || generalDeclaration.Tok != token.TYPE { + continue + } + + for _, astSpec := range generalDeclaration.Specs { + typeSpec, ok := astSpec.(*ast.TypeSpec) + if !ok || typeSpec != typeSpecDef.TypeSpec { + continue + } + var typeName string + if typeSpec.Name != nil { + typeName = typeSpec.Name.Name + } + definition.Description, err = + parser.extractDeclarationDescription(typeName, typeSpec.Doc, typeSpec.Comment, generalDeclaration.Doc) + if err != nil { + return + } + } + } + return nil +} + +// extractDeclarationDescription gets first description +// from attribute descriptionAttr in commentGroups (ast.CommentGroup) +func (parser *Parser) extractDeclarationDescription(typeName string, commentGroups ...*ast.CommentGroup) (string, error) { + var description string + + for _, commentGroup := range commentGroups { + if commentGroup == nil { + continue + } + + isHandlingDescription := false + + for _, comment := range commentGroup.List { + commentText := strings.TrimSpace(strings.TrimLeft(comment.Text, "/")) + if len(commentText) == 0 { + continue + } + fields := FieldsByAnySpace(commentText, 2) + attribute := fields[0] + + if attr := strings.ToLower(attribute); attr == descriptionMarkdownAttr { + if len(fields) > 1 { + typeName = fields[1] + } + if typeName == "" { + continue + } + desc, err := getMarkdownForTag(typeName, parser.markdownFileDir) + if err != nil { + return "", err + } + // if found markdown description, we will only use the markdown file content + return string(desc), nil + } else if attr != descriptionAttr { + if !isHandlingDescription { + continue + } + + break + } + + isHandlingDescription = true + description += " " + strings.TrimSpace(commentText[len(attribute):]) + } + } + + return strings.TrimLeft(description, " "), nil +} + +// parseTypeExpr parses given type expression that corresponds to the type under +// given name and package, and returns swagger schema for it. +func (parser *Parser) parseTypeExpr(file *ast.File, typeExpr ast.Expr, ref bool) (*spec.Schema, error) { + switch expr := typeExpr.(type) { + // type Foo interface{} + case *ast.InterfaceType: + return &spec.Schema{}, nil + + // type Foo struct {...} + case *ast.StructType: + return parser.parseStruct(file, expr.Fields) + + // type Foo Baz + case *ast.Ident: + return parser.getTypeSchema(expr.Name, file, ref) + + // type Foo *Baz + case *ast.StarExpr: + return parser.parseTypeExpr(file, expr.X, ref) + + // type Foo pkg.Bar + case *ast.SelectorExpr: + if xIdent, ok := expr.X.(*ast.Ident); ok { + return parser.getTypeSchema(fullTypeName(xIdent.Name, expr.Sel.Name), file, ref) + } + // type Foo []Baz + case *ast.ArrayType: + itemSchema, err := parser.parseTypeExpr(file, expr.Elt, true) + if err != nil { + return nil, err + } + + return spec.ArrayProperty(itemSchema), nil + // type Foo map[string]Bar + case *ast.MapType: + if _, ok := expr.Value.(*ast.InterfaceType); ok { + return spec.MapProperty(nil), nil + } + schema, err := parser.parseTypeExpr(file, expr.Value, true) + if err != nil { + return nil, err + } + + return spec.MapProperty(schema), nil + + case *ast.FuncType: + return nil, ErrFuncTypeField + // ... + } + + return parser.parseGenericTypeExpr(file, typeExpr) +} + +func (parser *Parser) parseStruct(file *ast.File, fields *ast.FieldList) (*spec.Schema, error) { + required, properties := make([]string, 0), make(map[string]spec.Schema) + + for _, field := range fields.List { + fieldProps, requiredFromAnon, err := parser.parseStructField(file, field) + if err != nil { + if errors.Is(err, ErrFuncTypeField) || errors.Is(err, ErrSkippedField) { + continue + } + + return nil, err + } + + if len(fieldProps) == 0 { + continue + } + + required = append(required, requiredFromAnon...) + + for k, v := range fieldProps { + properties[k] = v + } + } + + sort.Strings(required) + + return &spec.Schema{ + SchemaProps: spec.SchemaProps{ + Type: []string{OBJECT}, + Properties: properties, + Required: required, + }, + }, nil +} + +func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[string]spec.Schema, []string, error) { + if field.Tag != nil { + skip, ok := reflect.StructTag(strings.ReplaceAll(field.Tag.Value, "`", "")).Lookup("swaggerignore") + if ok && strings.EqualFold(skip, "true") { + return nil, nil, nil + } + } + + ps := parser.fieldParserFactory(parser, field) + + if ps.ShouldSkip() { + return nil, nil, nil + } + + fieldNames, err := ps.FieldNames() + if err != nil { + return nil, nil, err + } + + if len(fieldNames) == 0 { + typeName, err := getFieldType(file, field.Type, nil) + if err != nil { + return nil, nil, err + } + + schema, err := parser.getTypeSchema(typeName, file, false) + if err != nil { + return nil, nil, err + } + + if len(schema.Type) > 0 && schema.Type[0] == OBJECT { + if len(schema.Properties) == 0 { + return nil, nil, nil + } + + properties := map[string]spec.Schema{} + for k, v := range schema.Properties { + properties[k] = v + } + + return properties, schema.SchemaProps.Required, nil + } + // for alias type of non-struct types ,such as array,map, etc. ignore field tag. + return map[string]spec.Schema{typeName: *schema}, nil, nil + + } + + schema, err := ps.CustomSchema() + if err != nil { + return nil, nil, fmt.Errorf("%v: %w", fieldNames, err) + } + + if schema == nil { + typeName, err := getFieldType(file, field.Type, nil) + if err == nil { + // named type + schema, err = parser.getTypeSchema(typeName, file, true) + } else { + // unnamed type + schema, err = parser.parseTypeExpr(file, field.Type, false) + } + + if err != nil { + return nil, nil, fmt.Errorf("%v: %w", fieldNames, err) + } + } + + err = ps.ComplementSchema(schema) + if err != nil { + return nil, nil, fmt.Errorf("%v: %w", fieldNames, err) + } + + var tagRequired []string + + required, err := ps.IsRequired() + if err != nil { + return nil, nil, fmt.Errorf("%v: %w", fieldNames, err) + } + + if required { + tagRequired = append(tagRequired, fieldNames...) + } + + if schema.Extensions == nil { + schema.Extensions = make(spec.Extensions) + } + if formName := ps.FormName(); len(formName) > 0 { + schema.Extensions["formData"] = formName + } + if headerName := ps.HeaderName(); len(headerName) > 0 { + schema.Extensions["header"] = headerName + } + if pathName := ps.PathName(); len(pathName) > 0 { + schema.Extensions["path"] = pathName + } + fields := make(map[string]spec.Schema) + for _, name := range fieldNames { + fields[name] = *schema + } + return fields, tagRequired, nil +} + +func getFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) { + switch fieldType := field.(type) { + case *ast.Ident: + return fieldType.Name, nil + case *ast.SelectorExpr: + packageName, err := getFieldType(file, fieldType.X, genericParamTypeDefs) + if err != nil { + return "", err + } + + return fullTypeName(packageName, fieldType.Sel.Name), nil + case *ast.StarExpr: + fullName, err := getFieldType(file, fieldType.X, genericParamTypeDefs) + if err != nil { + return "", err + } + + return fullName, nil + default: + return getGenericFieldType(file, field, genericParamTypeDefs) + } +} + +func (parser *Parser) getUnderlyingSchema(schema *spec.Schema) *spec.Schema { + if schema == nil { + return nil + } + + if url := schema.Ref.GetURL(); url != nil { + if pos := strings.LastIndexByte(url.Fragment, '/'); pos >= 0 { + name := url.Fragment[pos+1:] + if schema, ok := parser.swagger.Definitions[name]; ok { + return &schema + } + } + } + + if len(schema.AllOf) > 0 { + merged := &spec.Schema{} + MergeSchema(merged, schema) + for _, s := range schema.AllOf { + MergeSchema(merged, parser.getUnderlyingSchema(&s)) + } + return merged + } + return nil +} + +// GetSchemaTypePath get path of schema type. +func (parser *Parser) GetSchemaTypePath(schema *spec.Schema, depth int) []string { + if schema == nil || depth == 0 { + return nil + } + + if underlying := parser.getUnderlyingSchema(schema); underlying != nil { + return parser.GetSchemaTypePath(underlying, depth) + } + + if len(schema.Type) > 0 { + switch schema.Type[0] { + case ARRAY: + depth-- + + s := []string{schema.Type[0]} + + return append(s, parser.GetSchemaTypePath(schema.Items.Schema, depth)...) + case OBJECT: + if schema.AdditionalProperties != nil && schema.AdditionalProperties.Schema != nil { + // for map + depth-- + + s := []string{schema.Type[0]} + + return append(s, parser.GetSchemaTypePath(schema.AdditionalProperties.Schema, depth)...) + } + } + + return []string{schema.Type[0]} + } + + return []string{ANY} +} + +// defineTypeOfExample example value define the type (object and array unsupported). +func defineTypeOfExample(schemaType, arrayType, exampleValue string) (interface{}, error) { + switch schemaType { + case STRING: + return exampleValue, nil + case NUMBER: + v, err := strconv.ParseFloat(exampleValue, 64) + if err != nil { + return nil, fmt.Errorf("example value %s can't convert to %s err: %s", exampleValue, schemaType, err) + } + + return v, nil + case INTEGER: + v, err := strconv.Atoi(exampleValue) + if err != nil { + return nil, fmt.Errorf("example value %s can't convert to %s err: %s", exampleValue, schemaType, err) + } + + return v, nil + case BOOLEAN: + v, err := strconv.ParseBool(exampleValue) + if err != nil { + return nil, fmt.Errorf("example value %s can't convert to %s err: %s", exampleValue, schemaType, err) + } + + return v, nil + case ARRAY: + values := strings.Split(exampleValue, ",") + result := make([]interface{}, 0) + for _, value := range values { + v, err := defineTypeOfExample(arrayType, "", value) + if err != nil { + return nil, err + } + + result = append(result, v) + } + + return result, nil + case OBJECT: + if arrayType == "" { + return nil, fmt.Errorf("%s is unsupported type in example value `%s`", schemaType, exampleValue) + } + + values := strings.Split(exampleValue, ",") + + result := map[string]interface{}{} + + for _, value := range values { + mapData := strings.SplitN(value, ":", 2) + + if len(mapData) == 2 { + v, err := defineTypeOfExample(arrayType, "", mapData[1]) + if err != nil { + return nil, err + } + + result[mapData[0]] = v + + continue + } + + return nil, fmt.Errorf("example value %s should format: key:value", exampleValue) + } + + return result, nil + } + + return nil, fmt.Errorf("%s is unsupported type in example value %s", schemaType, exampleValue) +} + +// GetAllGoFileInfo gets all Go source files information for given searchDir. +func (parser *Parser) getAllGoFileInfo(packageDir, searchDir string) error { + if parser.skipPackageByPrefix(packageDir) { + return nil // ignored by user-defined package path prefixes + } + return filepath.Walk(searchDir, func(path string, f os.FileInfo, _ error) error { + err := parser.Skip(path, f) + if err != nil { + return err + } + + if f.IsDir() { + return nil + } + + relPath, err := filepath.Rel(searchDir, path) + if err != nil { + return err + } + + return parser.parseFile(filepath.ToSlash(filepath.Dir(filepath.Clean(filepath.Join(packageDir, relPath)))), path, nil, ParseAll) + }) +} + +func (parser *Parser) getAllGoFileInfoFromDeps(pkg *depth.Pkg, parseFlag ParseFlag) error { + ignoreInternal := pkg.Internal && !parser.ParseInternal + if ignoreInternal || !pkg.Resolved { // ignored internal and not resolved dependencies + return nil + } + + if pkg.Raw != nil && parser.skipPackageByPrefix(pkg.Raw.ImportPath) { + return nil // ignored by user-defined package path prefixes + } + + // Skip cgo + if pkg.Raw == nil && pkg.Name == "C" { + return nil + } + + srcDir := pkg.Raw.Dir + + files, err := os.ReadDir(srcDir) // only parsing files in the dir(don't contain sub dir files) + if err != nil { + return err + } + + for _, f := range files { + if f.IsDir() { + continue + } + + path := filepath.Join(srcDir, f.Name()) + if err := parser.parseFile(pkg.Name, path, nil, parseFlag); err != nil { + return err + } + } + + for i := 0; i < len(pkg.Deps); i++ { + if err := parser.getAllGoFileInfoFromDeps(&pkg.Deps[i], parseFlag); err != nil { + return err + } + } + + return nil +} + +func (parser *Parser) parseFile(packageDir, path string, src interface{}, flag ParseFlag) error { + if strings.HasSuffix(strings.ToLower(path), "_test.go") || filepath.Ext(path) != ".go" { + return nil + } + + return parser.packages.ParseFile(packageDir, path, src, flag) +} + +func (parser *Parser) checkOperationIDUniqueness() error { + // operationsIds contains all operationId annotations to check it's unique + operationsIds := make(map[string]string) + + for path, item := range parser.swagger.Paths.Paths { + var method, id string + + for method = range allMethod { + op := refRouteMethodOp(&item, method) + if *op != nil { + id = (**op).ID + + break + } + } + + if id == "" { + continue + } + + current := fmt.Sprintf("%s %s", method, path) + + previous, ok := operationsIds[id] + if ok { + return fmt.Errorf( + "duplicated @id annotation '%s' found in '%s', previously declared in: '%s'", + id, current, previous) + } + + operationsIds[id] = current + } + + return nil +} + +// Skip returns filepath.SkipDir error if match vendor and hidden folder. +func (parser *Parser) Skip(path string, f os.FileInfo) error { + return walkWith(parser.excludes, parser.ParseVendor)(path, f) +} + +func walkWith(excludes map[string]struct{}, parseVendor bool) func(path string, fileInfo os.FileInfo) error { + return func(path string, f os.FileInfo) error { + if f.IsDir() { + if !parseVendor && f.Name() == "vendor" || // ignore "vendor" + f.Name() == "docs" || // exclude docs + len(f.Name()) > 1 && f.Name()[0] == '.' && f.Name() != ".." { // exclude all hidden folder + return filepath.SkipDir + } + + if excludes != nil { + if _, ok := excludes[path]; ok { + return filepath.SkipDir + } + } + } + + return nil + } +} + +// GetSwagger returns *spec.Swagger which is the root document object for the API specification. +func (parser *Parser) GetSwagger() *spec.Swagger { + return parser.swagger +} + +// addTestType just for tests. +func (parser *Parser) addTestType(typename string) { + typeDef := &TypeSpecDef{} + parser.packages.uniqueDefinitions[typename] = typeDef + parser.parsedSchemas[typeDef] = &Schema{ + PkgPath: "", + Name: typename, + Schema: PrimitiveSchema(OBJECT), + } +} diff --git a/pkg/swag/parser_test.go b/pkg/swag/parser_test.go new file mode 100644 index 0000000..e986b92 --- /dev/null +++ b/pkg/swag/parser_test.go @@ -0,0 +1,4436 @@ +package swag + +import ( + "bytes" + "encoding/json" + "errors" + "go/ast" + goparser "go/parser" + "go/token" + "log" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/go-openapi/spec" + "github.com/stretchr/testify/assert" +) + +const defaultParseDepth = 100 + +const mainAPIFile = "main.go" + +func TestNew(t *testing.T) { + t.Run("SetMarkdownFileDirectory", func(t *testing.T) { + t.Parallel() + + expected := "docs/markdown" + p := New(SetMarkdownFileDirectory(expected)) + assert.Equal(t, expected, p.markdownFileDir) + }) + + t.Run("SetCodeExamplesDirectory", func(t *testing.T) { + t.Parallel() + + expected := "docs/examples" + p := New(SetCodeExamplesDirectory(expected)) + assert.Equal(t, expected, p.codeExampleFilesDir) + }) + + t.Run("SetStrict", func(t *testing.T) { + t.Parallel() + + p := New() + assert.Equal(t, false, p.Strict) + + p = New(SetStrict(true)) + assert.Equal(t, true, p.Strict) + }) + + t.Run("SetDebugger", func(t *testing.T) { + t.Parallel() + + logger := log.New(&bytes.Buffer{}, "", log.LstdFlags) + + p := New(SetDebugger(logger)) + assert.Equal(t, logger, p.debug) + }) + + t.Run("SetFieldParserFactory", func(t *testing.T) { + t.Parallel() + + p := New(SetFieldParserFactory(nil)) + assert.Nil(t, p.fieldParserFactory) + }) +} + +func TestSetOverrides(t *testing.T) { + t.Parallel() + + overrides := map[string]string{ + "foo": "bar", + } + + p := New(SetOverrides(overrides)) + assert.Equal(t, overrides, p.Overrides) +} + +func TestOverrides_getTypeSchema(t *testing.T) { + t.Parallel() + + overrides := map[string]string{ + "sql.NullString": "string", + } + + p := New(SetOverrides(overrides)) + + t.Run("Override sql.NullString by string", func(t *testing.T) { + t.Parallel() + + s, err := p.getTypeSchema("sql.NullString", nil, false) + if assert.NoError(t, err) { + assert.Truef(t, s.Type.Contains("string"), "type sql.NullString should be overridden by string") + } + }) + + t.Run("Missing Override for sql.NullInt64", func(t *testing.T) { + t.Parallel() + + _, err := p.getTypeSchema("sql.NullInt64", nil, false) + if assert.Error(t, err) { + assert.Equal(t, "cannot find type definition: sql.NullInt64", err.Error()) + } + }) +} + +func TestParser_ParseDefinition(t *testing.T) { + p := New() + + // Parsing existing type + definition := &TypeSpecDef{ + PkgPath: "github.com/swagger/swag", + File: &ast.File{ + Name: &ast.Ident{ + Name: "swag", + }, + }, + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{ + Name: "Test", + }, + }, + } + + expected := &Schema{} + p.parsedSchemas[definition] = expected + + schema, err := p.ParseDefinition(definition) + assert.NoError(t, err) + assert.Equal(t, expected, schema) + + // Parsing *ast.FuncType + definition = &TypeSpecDef{ + PkgPath: "github.com/swagger/swag/model", + File: &ast.File{ + Name: &ast.Ident{ + Name: "model", + }, + }, + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{ + Name: "Test", + }, + Type: &ast.FuncType{}, + }, + } + _, err = p.ParseDefinition(definition) + assert.Error(t, err) + + // Parsing *ast.FuncType with parent spec + definition = &TypeSpecDef{ + PkgPath: "github.com/swagger/swag/model", + File: &ast.File{ + Name: &ast.Ident{ + Name: "model", + }, + }, + TypeSpec: &ast.TypeSpec{ + Name: &ast.Ident{ + Name: "Test", + }, + Type: &ast.FuncType{}, + }, + ParentSpec: &ast.FuncDecl{ + Name: ast.NewIdent("TestFuncDecl"), + }, + } + _, err = p.ParseDefinition(definition) + assert.Error(t, err) + assert.Equal(t, "model.TestFuncDecl.Test", definition.TypeName()) +} + +func TestParser_ParseGeneralApiInfo(t *testing.T) { + t.Parallel() + + expected := `{ + "schemes": [ + "http", + "https" + ], + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.\nIt has a lot of beautiful features.", + "title": "Swagger Example API", + "termsOfService": "http://swagger.io/terms/", + "contact": { + "name": "API Support", + "url": "http://www.swagger.io/support", + "email": "support@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + }, + "version": "1.0", + "x-logo": { + "altText": "Petstore logo", + "backgroundColor": "#FFFFFF", + "url": "https://redocly.github.io/redoc/petstore-logo.png" + } + }, + "host": "petstore.swagger.io", + "basePath": "/v2", + "paths": {}, + "securityDefinitions": { + "ApiKeyAuth": { + "description": "some description", + "type": "apiKey", + "name": "Authorization", + "in": "header" + }, + "BasicAuth": { + "type": "basic" + }, + "OAuth2AccessCode": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information" + }, + "x-tokenname": "id_token" + }, + "OAuth2Application": { + "type": "oauth2", + "flow": "application", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + } + }, + "OAuth2Implicit": { + "type": "oauth2", + "flow": "implicit", + "authorizationUrl": "https://example.com/oauth/authorize", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + }, + "x-google-audiences": "some_audience.google.com" + }, + "OAuth2Password": { + "type": "oauth2", + "flow": "password", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "read": "Grants read access", + "write": "Grants write access" + } + } + }, + "externalDocs": { + "description": "OpenAPI", + "url": "https://swagger.io/resources/open-api" + }, + "x-google-endpoints": [ + { + "allowCors": true, + "name": "name.endpoints.environment.cloud.goog" + } + ], + "x-google-marks": "marks values" +}` + gopath := os.Getenv("GOPATH") + assert.NotNil(t, gopath) + + p := New() + + err := p.ParseGeneralAPIInfo("testdata/main.go") + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParser_ParseGeneralApiInfoTemplated(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "termsOfService": "http://swagger.io/terms/", + "contact": { + "name": "API Support", + "url": "http://www.swagger.io/support", + "email": "support@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + } + }, + "paths": {}, + "securityDefinitions": { + "ApiKeyAuth": { + "type": "apiKey", + "name": "Authorization", + "in": "header" + }, + "BasicAuth": { + "type": "basic" + }, + "OAuth2AccessCode": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information" + } + }, + "OAuth2Application": { + "type": "oauth2", + "flow": "application", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + } + }, + "OAuth2Implicit": { + "type": "oauth2", + "flow": "implicit", + "authorizationUrl": "https://example.com/oauth/authorize", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + } + }, + "OAuth2Password": { + "type": "oauth2", + "flow": "password", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "read": "Grants read access", + "write": "Grants write access" + } + } + }, + "externalDocs": { + "description": "OpenAPI", + "url": "https://swagger.io/resources/open-api" + }, + "x-google-endpoints": [ + { + "allowCors": true, + "name": "name.endpoints.environment.cloud.goog" + } + ], + "x-google-marks": "marks values" +}` + gopath := os.Getenv("GOPATH") + assert.NotNil(t, gopath) + + p := New() + + err := p.ParseGeneralAPIInfo("testdata/templated.go") + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParser_ParseGeneralApiInfoExtensions(t *testing.T) { + // should return an error because extension value is not a valid json + t.Run("Test invalid extension value", func(t *testing.T) { + t.Parallel() + + expected := "annotation @x-google-endpoints need a valid json value" + gopath := os.Getenv("GOPATH") + assert.NotNil(t, gopath) + + p := New() + + err := p.ParseGeneralAPIInfo("testdata/extensionsFail1.go") + if assert.Error(t, err) { + assert.Equal(t, expected, err.Error()) + } + }) + + // should return an error because extension don't have a value + t.Run("Test missing extension value", func(t *testing.T) { + t.Parallel() + + expected := "annotation @x-google-endpoints need a value" + gopath := os.Getenv("GOPATH") + assert.NotNil(t, gopath) + + p := New() + + err := p.ParseGeneralAPIInfo("testdata/extensionsFail2.go") + if assert.Error(t, err) { + assert.Equal(t, expected, err.Error()) + } + }) +} + +func TestParser_ParseGeneralApiInfoWithOpsInSameFile(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.\nIt has a lot of beautiful features.", + "title": "Swagger Example API", + "termsOfService": "http://swagger.io/terms/", + "contact": {}, + "version": "1.0" + }, + "paths": {} +}` + + gopath := os.Getenv("GOPATH") + assert.NotNil(t, gopath) + + p := New() + + err := p.ParseGeneralAPIInfo("testdata/single_file_api/main.go") + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParser_ParseGeneralAPIInfoMarkdown(t *testing.T) { + t.Parallel() + + p := New(SetMarkdownFileDirectory("testdata")) + mainAPIFile := "testdata/markdown.go" + err := p.ParseGeneralAPIInfo(mainAPIFile) + assert.NoError(t, err) + + expected := `{ + "swagger": "2.0", + "info": { + "description": "Swagger Example API Markdown Description", + "title": "Swagger Example API", + "termsOfService": "http://swagger.io/terms/", + "contact": {}, + "version": "1.0" + }, + "paths": {}, + "tags": [ + { + "description": "Users Tag Markdown Description", + "name": "users" + } + ] +}` + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) + + p = New() + + err = p.ParseGeneralAPIInfo(mainAPIFile) + assert.Error(t, err) +} + +func TestParser_ParseGeneralApiInfoFailed(t *testing.T) { + t.Parallel() + + gopath := os.Getenv("GOPATH") + assert.NotNil(t, gopath) + p := New() + assert.Error(t, p.ParseGeneralAPIInfo("testdata/noexist.go")) +} + +func TestParser_ParseAcceptComment(t *testing.T) { + t.Parallel() + + expected := []string{ + "application/json", + "text/xml", + "text/plain", + "text/html", + "multipart/form-data", + "application/x-www-form-urlencoded", + "application/vnd.api+json", + "application/x-json-stream", + "application/octet-stream", + "image/png", + "image/jpeg", + "image/gif", + "application/xhtml+xml", + "application/health+json", + } + + comment := `@Accept json,xml,plain,html,mpfd,x-www-form-urlencoded,json-api,json-stream,octet-stream,png,jpeg,gif,application/xhtml+xml,application/health+json` + + parser := New() + assert.NoError(t, parseGeneralAPIInfo(parser, []string{comment})) + assert.Equal(t, parser.swagger.Consumes, expected) + + assert.Error(t, parseGeneralAPIInfo(parser, []string{`@Accept cookies,candies`})) + + parser = New() + assert.NoError(t, parser.ParseAcceptComment(comment[len(acceptAttr)+1:])) + assert.Equal(t, parser.swagger.Consumes, expected) +} + +func TestParser_ParseProduceComment(t *testing.T) { + t.Parallel() + + expected := []string{ + "application/json", + "text/xml", + "text/plain", + "text/html", + "multipart/form-data", + "application/x-www-form-urlencoded", + "application/vnd.api+json", + "application/x-json-stream", + "application/octet-stream", + "image/png", + "image/jpeg", + "image/gif", + "application/xhtml+xml", + "application/health+json", + } + + comment := `@Produce json,xml,plain,html,mpfd,x-www-form-urlencoded,json-api,json-stream,octet-stream,png,jpeg,gif,application/xhtml+xml,application/health+json` + + parser := New() + assert.NoError(t, parseGeneralAPIInfo(parser, []string{comment})) + assert.Equal(t, parser.swagger.Produces, expected) + + assert.Error(t, parseGeneralAPIInfo(parser, []string{`@Produce cookies,candies`})) + + parser = New() + assert.NoError(t, parser.ParseProduceComment(comment[len(produceAttr)+1:])) + assert.Equal(t, parser.swagger.Produces, expected) +} + +func TestParser_ParseGeneralAPIInfoCollectionFormat(t *testing.T) { + t.Parallel() + + parser := New() + assert.NoError(t, parseGeneralAPIInfo(parser, []string{ + "@query.collection.format csv", + })) + assert.Equal(t, parser.collectionFormatInQuery, "csv") + + assert.NoError(t, parseGeneralAPIInfo(parser, []string{ + "@query.collection.format tsv", + })) + assert.Equal(t, parser.collectionFormatInQuery, "tsv") +} + +func TestParser_ParseGeneralAPITagGroups(t *testing.T) { + t.Parallel() + + parser := New() + assert.NoError(t, parseGeneralAPIInfo(parser, []string{ + "@x-tagGroups [{\"name\":\"General\",\"tags\":[\"lanes\",\"video-recommendations\"]}]", + })) + + expected := []interface{}{map[string]interface{}{"name": "General", "tags": []interface{}{"lanes", "video-recommendations"}}} + assert.Equal(t, parser.swagger.Extensions["x-tagGroups"], expected) +} + +func TestParser_ParseGeneralAPITagDocs(t *testing.T) { + t.Parallel() + + parser := New() + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@tag.name Test", + "@tag.docs.description Best example documentation", + })) + + parser = New() + err := parseGeneralAPIInfo(parser, []string{ + "@tag.name test", + "@tag.description A test Tag", + "@tag.docs.url https://example.com", + "@tag.docs.description Best example documentation", + "@tag.x-displayName Test group", + }) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(parser.GetSwagger().Tags, "", " ") + expected := `[ + { + "description": "A test Tag", + "name": "test", + "externalDocs": { + "description": "Best example documentation", + "url": "https://example.com" + }, + "x-displayName": "Test group" + } +]` + assert.Equal(t, expected, string(b)) +} + +func TestParser_ParseGeneralAPITagDocsWithTagFilters(t *testing.T) { + t.Parallel() + + filterTags := []string{"test1", "!test2"} + + comments := []string{ + "@tag.name test1", + "@tag.description A test1 Tag", + "@tag.docs.url https://example1.com", + "@tag.docs.description Best example1 documentation", + "@tag.name test2", + "@tag.description A test2 Tag", + "@tag.docs.url https://example2.com", + "@tag.docs.description Best example2 documentation", + } + + expected := `[ + { + "description": "A test1 Tag", + "name": "test1", + "externalDocs": { + "description": "Best example1 documentation", + "url": "https://example1.com" + } + } +]` + + for _, tag := range filterTags { + parser := New(SetTags(tag)) + err := parseGeneralAPIInfo(parser, comments) + assert.NoError(t, err) + b, _ := json.MarshalIndent(parser.GetSwagger().Tags, "", " ") + assert.Equal(t, expected, string(b)) + } +} + +func TestParser_ParseGeneralAPISecurity(t *testing.T) { + t.Run("ApiKey", func(t *testing.T) { + t.Parallel() + + parser := New() + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.apikey ApiKey", + })) + + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.apikey ApiKey", + "@in header", + })) + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.apikey ApiKey", + "@name X-API-KEY", + })) + + err := parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.apikey ApiKey", + "@in header", + "@name X-API-KEY", + "@description some", + "", + "@securitydefinitions.oauth2.accessCode OAuth2AccessCode", + "@tokenUrl https://example.com/oauth/token", + "@authorizationUrl https://example.com/oauth/authorize", + "@scope.admin foo", + }) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(parser.GetSwagger().SecurityDefinitions, "", " ") + expected := `{ + "ApiKey": { + "description": "some", + "type": "apiKey", + "name": "X-API-KEY", + "in": "header" + }, + "OAuth2AccessCode": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "foo" + } + } +}` + assert.Equal(t, expected, string(b)) + }) + + t.Run("OAuth2Application", func(t *testing.T) { + t.Parallel() + + parser := New() + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.application OAuth2Application", + })) + + err := parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.application OAuth2Application", + "@tokenUrl https://example.com/oauth/token", + }) + assert.NoError(t, err) + b, _ := json.MarshalIndent(parser.GetSwagger().SecurityDefinitions, "", " ") + expected := `{ + "OAuth2Application": { + "type": "oauth2", + "flow": "application", + "tokenUrl": "https://example.com/oauth/token" + } +}` + assert.Equal(t, expected, string(b)) + }) + + t.Run("OAuth2Implicit", func(t *testing.T) { + t.Parallel() + + parser := New() + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.implicit OAuth2Implicit", + })) + + err := parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.implicit OAuth2Implicit", + "@authorizationurl https://example.com/oauth/authorize", + }) + assert.NoError(t, err) + b, _ := json.MarshalIndent(parser.GetSwagger().SecurityDefinitions, "", " ") + expected := `{ + "OAuth2Implicit": { + "type": "oauth2", + "flow": "implicit", + "authorizationUrl": "https://example.com/oauth/authorize" + } +}` + assert.Equal(t, expected, string(b)) + }) + + t.Run("OAuth2Password", func(t *testing.T) { + t.Parallel() + + parser := New() + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.password OAuth2Password", + })) + + err := parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.password OAuth2Password", + "@tokenUrl https://example.com/oauth/token", + }) + assert.NoError(t, err) + b, _ := json.MarshalIndent(parser.GetSwagger().SecurityDefinitions, "", " ") + expected := `{ + "OAuth2Password": { + "type": "oauth2", + "flow": "password", + "tokenUrl": "https://example.com/oauth/token" + } +}` + assert.Equal(t, expected, string(b)) + }) + + t.Run("OAuth2AccessCode", func(t *testing.T) { + t.Parallel() + + parser := New() + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.accessCode OAuth2AccessCode", + })) + + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.accessCode OAuth2AccessCode", + "@tokenUrl https://example.com/oauth/token", + })) + + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.accessCode OAuth2AccessCode", + "@authorizationurl https://example.com/oauth/authorize", + })) + + err := parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.accessCode OAuth2AccessCode", + "@tokenUrl https://example.com/oauth/token", + "@authorizationurl https://example.com/oauth/authorize", + }) + assert.NoError(t, err) + b, _ := json.MarshalIndent(parser.GetSwagger().SecurityDefinitions, "", " ") + expected := `{ + "OAuth2AccessCode": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token" + } +}` + assert.Equal(t, expected, string(b)) + + assert.Error(t, parseGeneralAPIInfo(parser, []string{ + "@securitydefinitions.oauth2.accessCode OAuth2AccessCode", + "@tokenUrl https://example.com/oauth/token", + "@authorizationurl https://example.com/oauth/authorize", + "@scope.read,write Multiple scope", + })) + }) +} + +func TestParser_RefWithOtherPropertiesIsWrappedInAllOf(t *testing.T) { + t.Run("Readonly", func(t *testing.T) { + src := ` +package main + +type Teacher struct { + Name string +} //@name Teacher + +type Student struct { + Name string + Age int ` + "`readonly:\"true\"`" + ` + Teacher Teacher ` + "`readonly:\"true\"`" + ` + OtherTeacher Teacher +} //@name Student + +// @Success 200 {object} Student +// @Router /test [get] +func Fun() { + +} +` + expected := `{ + "info": { + "contact": {} + }, + "paths": { + "/test": { + "get": { + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/Student" + } + } + } + } + } + }, + "definitions": { + "Student": { + "type": "object", + "properties": { + "age": { + "type": "integer", + "readOnly": true + }, + "name": { + "type": "string" + }, + "otherTeacher": { + "$ref": "#/definitions/Teacher" + }, + "teacher": { + "allOf": [ + { + "$ref": "#/definitions/Teacher" + } + ], + "readOnly": true + } + } + }, + "Teacher": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + } +}` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) + }) +} + +func TestGetAllGoFileInfo(t *testing.T) { + t.Parallel() + + searchDir := "testdata/pet" + + p := New() + err := p.getAllGoFileInfo("testdata", searchDir) + + assert.NoError(t, err) + assert.Equal(t, 2, len(p.packages.files)) +} + +func TestParser_ParseType(t *testing.T) { + t.Parallel() + + searchDir := "testdata/simple/" + + p := New() + err := p.getAllGoFileInfo("testdata", searchDir) + assert.NoError(t, err) + + _, err = p.packages.ParseTypes() + + assert.NoError(t, err) + assert.NotNil(t, p.packages.uniqueDefinitions["api.Pet3"]) + assert.NotNil(t, p.packages.uniqueDefinitions["web.Pet"]) + assert.NotNil(t, p.packages.uniqueDefinitions["web.Pet2"]) +} + +func TestParseSimpleApi1(t *testing.T) { + t.Parallel() + + expected, err := os.ReadFile("testdata/simple/expected.json") + assert.NoError(t, err) + searchDir := "testdata/simple" + p := New() + p.PropNamingStrategy = PascalCase + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.JSONEq(t, string(expected), string(b)) +} + +func TestParseInterfaceAndError(t *testing.T) { + t.Parallel() + + expected, err := os.ReadFile("testdata/error/expected.json") + assert.NoError(t, err) + searchDir := "testdata/error" + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.JSONEq(t, string(expected), string(b)) +} + +func TestParseSimpleApi_ForSnakecase(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.", + "title": "Swagger Example API", + "termsOfService": "http://swagger.io/terms/", + "contact": { + "name": "API Support", + "url": "http://www.swagger.io/support", + "email": "support@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + }, + "version": "1.0" + }, + "host": "petstore.swagger.io", + "basePath": "/v2", + "paths": { + "/file/upload": { + "post": { + "description": "Upload file", + "consumes": [ + "multipart/form-data" + ], + "produces": [ + "application/json" + ], + "summary": "Upload file", + "operationId": "file.upload", + "parameters": [ + { + "type": "file", + "description": "this is a test file", + "name": "file", + "in": "formData", + "required": true + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + }, + "/testapi/get-string-by-int/{some_id}": { + "get": { + "description": "get string by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Add a new pet to the store", + "operationId": "get-string-by-int", + "parameters": [ + { + "type": "integer", + "format": "int64", + "description": "Some ID", + "name": "some_id", + "in": "path", + "required": true + }, + { + "description": "Some ID", + "name": "some_id", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/web.Pet" + } + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + }, + "/testapi/get-struct-array-by-string/{some_id}": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + }, + { + "BasicAuth": [] + }, + { + "OAuth2Application": [ + "write" + ] + }, + { + "OAuth2Implicit": [ + "read", + "admin" + ] + }, + { + "OAuth2AccessCode": [ + "read" + ] + }, + { + "OAuth2Password": [ + "admin" + ] + } + ], + "description": "get struct array by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "operationId": "get-struct-array-by-string", + "parameters": [ + { + "type": "string", + "description": "Some ID", + "name": "some_id", + "in": "path", + "required": true + }, + { + "enum": [ + 1, + 2, + 3 + ], + "type": "integer", + "description": "Category", + "name": "category", + "in": "query", + "required": true + }, + { + "minimum": 0, + "type": "integer", + "default": 0, + "description": "Offset", + "name": "offset", + "in": "query", + "required": true + }, + { + "maximum": 50, + "type": "integer", + "default": 10, + "description": "Limit", + "name": "limit", + "in": "query", + "required": true + }, + { + "maxLength": 50, + "minLength": 1, + "type": "string", + "default": "\"\"", + "description": "q", + "name": "q", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + } + }, + "definitions": { + "web.APIError": { + "type": "object", + "properties": { + "created_at": { + "type": "string" + }, + "error_code": { + "type": "integer" + }, + "error_message": { + "type": "string" + } + } + }, + "web.Pet": { + "type": "object", + "required": [ + "price" + ], + "properties": { + "birthday": { + "type": "integer" + }, + "category": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "example": 1 + }, + "name": { + "type": "string", + "example": "category_name" + }, + "photo_urls": { + "type": "array", + "items": { + "type": "string", + "format": "url" + }, + "example": [ + "http://test/image/1.jpg", + "http://test/image/2.jpg" + ] + }, + "small_category": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "id": { + "type": "integer", + "example": 1 + }, + "name": { + "type": "string", + "example": "detail_category_name" + }, + "photo_urls": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "http://test/image/1.jpg", + "http://test/image/2.jpg" + ] + } + } + } + } + }, + "coeffs": { + "type": "array", + "items": { + "type": "number" + } + }, + "custom_string": { + "type": "string" + }, + "custom_string_arr": { + "type": "array", + "items": { + "type": "string" + } + }, + "data": {}, + "decimal": { + "type": "number" + }, + "id": { + "type": "integer", + "format": "int64", + "example": 1 + }, + "is_alive": { + "type": "boolean", + "example": true + }, + "name": { + "type": "string", + "example": "poti" + }, + "null_int": { + "type": "integer" + }, + "pets": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Pet2" + } + }, + "pets2": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Pet2" + } + }, + "photo_urls": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "http://test/image/1.jpg", + "http://test/image/2.jpg" + ] + }, + "price": { + "type": "number", + "maximum": 130, + "minimum": 0, + "multipleOf": 0.01, + "example": 3.25 + }, + "status": { + "type": "string" + }, + "tags": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Tag" + } + }, + "uuid": { + "type": "string" + } + } + }, + "web.Pet2": { + "type": "object", + "properties": { + "deleted_at": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "middle_name": { + "type": "string" + } + } + }, + "web.RevValue": { + "type": "object", + "properties": { + "data": { + "type": "integer" + }, + "err": { + "type": "integer", + "format": "int32" + }, + "status": { + "type": "boolean" + } + } + }, + "web.Tag": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + }, + "pets": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Pet" + } + } + } + } + }, + "securityDefinitions": { + "ApiKeyAuth": { + "type": "apiKey", + "name": "Authorization", + "in": "header" + }, + "BasicAuth": { + "type": "basic" + }, + "OAuth2AccessCode": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information" + } + }, + "OAuth2Application": { + "type": "oauth2", + "flow": "application", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + } + }, + "OAuth2Implicit": { + "type": "oauth2", + "flow": "implicit", + "authorizationUrl": "https://example.com/oauth/authorize", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + } + }, + "OAuth2Password": { + "type": "oauth2", + "flow": "password", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "read": "Grants read access", + "write": "Grants write access" + } + } + } +}` + searchDir := "testdata/simple2" + p := New() + p.PropNamingStrategy = SnakeCase + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseSimpleApi_ForLowerCamelcase(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.", + "title": "Swagger Example API", + "termsOfService": "http://swagger.io/terms/", + "contact": { + "name": "API Support", + "url": "http://www.swagger.io/support", + "email": "support@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + }, + "version": "1.0" + }, + "host": "petstore.swagger.io", + "basePath": "/v2", + "paths": { + "/file/upload": { + "post": { + "description": "Upload file", + "consumes": [ + "multipart/form-data" + ], + "produces": [ + "application/json" + ], + "summary": "Upload file", + "operationId": "file.upload", + "parameters": [ + { + "type": "file", + "description": "this is a test file", + "name": "file", + "in": "formData", + "required": true + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + }, + "/testapi/get-string-by-int/{some_id}": { + "get": { + "description": "get string by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Add a new pet to the store", + "operationId": "get-string-by-int", + "parameters": [ + { + "type": "integer", + "format": "int64", + "description": "Some ID", + "name": "some_id", + "in": "path", + "required": true + }, + { + "description": "Some ID", + "name": "some_id", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/web.Pet" + } + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + }, + "/testapi/get-struct-array-by-string/{some_id}": { + "get": { + "security": [ + { + "ApiKeyAuth": [] + }, + { + "BasicAuth": [] + }, + { + "OAuth2Application": [ + "write" + ] + }, + { + "OAuth2Implicit": [ + "read", + "admin" + ] + }, + { + "OAuth2AccessCode": [ + "read" + ] + }, + { + "OAuth2Password": [ + "admin" + ] + } + ], + "description": "get struct array by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "operationId": "get-struct-array-by-string", + "parameters": [ + { + "type": "string", + "description": "Some ID", + "name": "some_id", + "in": "path", + "required": true + }, + { + "enum": [ + 1, + 2, + 3 + ], + "type": "integer", + "description": "Category", + "name": "category", + "in": "query", + "required": true + }, + { + "minimum": 0, + "type": "integer", + "default": 0, + "description": "Offset", + "name": "offset", + "in": "query", + "required": true + }, + { + "maximum": 50, + "type": "integer", + "default": 10, + "description": "Limit", + "name": "limit", + "in": "query", + "required": true + }, + { + "maxLength": 50, + "minLength": 1, + "type": "string", + "default": "\"\"", + "description": "q", + "name": "q", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + } + }, + "definitions": { + "web.APIError": { + "type": "object", + "properties": { + "createdAt": { + "type": "string" + }, + "errorCode": { + "type": "integer" + }, + "errorMessage": { + "type": "string" + } + } + }, + "web.Pet": { + "type": "object", + "properties": { + "category": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "example": 1 + }, + "name": { + "type": "string", + "example": "category_name" + }, + "photoURLs": { + "type": "array", + "items": { + "type": "string", + "format": "url" + }, + "example": [ + "http://test/image/1.jpg", + "http://test/image/2.jpg" + ] + }, + "smallCategory": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "example": 1 + }, + "name": { + "type": "string", + "example": "detail_category_name" + }, + "photoURLs": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "http://test/image/1.jpg", + "http://test/image/2.jpg" + ] + } + } + } + } + }, + "data": {}, + "decimal": { + "type": "number" + }, + "id": { + "type": "integer", + "format": "int64", + "example": 1 + }, + "isAlive": { + "type": "boolean", + "example": true + }, + "name": { + "type": "string", + "example": "poti" + }, + "pets": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Pet2" + } + }, + "pets2": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Pet2" + } + }, + "photoURLs": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "http://test/image/1.jpg", + "http://test/image/2.jpg" + ] + }, + "price": { + "type": "number", + "multipleOf": 0.01, + "example": 3.25 + }, + "status": { + "type": "string" + }, + "tags": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Tag" + } + }, + "uuid": { + "type": "string" + } + } + }, + "web.Pet2": { + "type": "object", + "properties": { + "deletedAt": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "middleName": { + "type": "string" + } + } + }, + "web.RevValue": { + "type": "object", + "properties": { + "data": { + "type": "integer" + }, + "err": { + "type": "integer", + "format": "int32" + }, + "status": { + "type": "boolean" + } + } + }, + "web.Tag": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + }, + "pets": { + "type": "array", + "items": { + "$ref": "#/definitions/web.Pet" + } + } + } + } + }, + "securityDefinitions": { + "ApiKeyAuth": { + "type": "apiKey", + "name": "Authorization", + "in": "header" + }, + "BasicAuth": { + "type": "basic" + }, + "OAuth2AccessCode": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information" + } + }, + "OAuth2Application": { + "type": "oauth2", + "flow": "application", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + } + }, + "OAuth2Implicit": { + "type": "oauth2", + "flow": "implicit", + "authorizationUrl": "https://example.com/oauth/authorize", + "scopes": { + "admin": "Grants read and write access to administrative information", + "write": "Grants write access" + } + }, + "OAuth2Password": { + "type": "oauth2", + "flow": "password", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Grants read and write access to administrative information", + "read": "Grants read access", + "write": "Grants write access" + } + } + } +}` + searchDir := "testdata/simple3" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseStructComment(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.", + "title": "Swagger Example API", + "contact": {}, + "version": "1.0" + }, + "host": "localhost:4000", + "basePath": "/api", + "paths": { + "/posts/{post_id}": { + "get": { + "description": "get string by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Add a new pet to the store", + "parameters": [ + { + "type": "integer", + "format": "int64", + "description": "Some ID", + "name": "post_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "$ref": "#/definitions/web.APIError" + } + } + } + } + } + }, + "definitions": { + "web.APIError": { + "description": "API error with information about it", + "type": "object", + "properties": { + "createdAt": { + "description": "Error time", + "type": "string" + }, + "error": { + "description": "Error an Api error", + "type": "string" + }, + "errorCtx": { + "description": "Error ` + "`" + `context` + "`" + ` tick comment", + "type": "string" + }, + "errorNo": { + "description": "Error ` + "`" + `number` + "`" + ` tick comment", + "type": "integer", + "format": "int64" + } + } + } + } +}` + searchDir := "testdata/struct_comment" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseNonExportedJSONFields(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server.", + "title": "Swagger Example API", + "contact": {}, + "version": "1.0" + }, + "host": "localhost:4000", + "basePath": "/api", + "paths": { + "/so-something": { + "get": { + "description": "Does something, but internal (non-exported) fields inside a struct won't be marshaled into JSON", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Call DoSomething", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/main.MyStruct" + } + } + } + } + } + }, + "definitions": { + "main.MyStruct": { + "type": "object", + "properties": { + "data": { + "description": "Post data", + "type": "object", + "properties": { + "name": { + "description": "Post tag", + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "id": { + "type": "integer", + "format": "int64", + "example": 1 + }, + "name": { + "description": "Post name", + "type": "string", + "example": "poti" + } + } + } + } +}` + + searchDir := "testdata/non_exported_json_fields" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParsePetApi(t *testing.T) { + t.Parallel() + + expected := `{ + "schemes": [ + "http", + "https" + ], + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server. You can find out more about Swagger at [http://swagger.io](http://swagger.io) or on [irc.freenode.net, #swagger](http://swagger.io/irc/). For this sample, you can use the api key 'special-key' to test the authorization filters.", + "title": "Swagger Petstore", + "termsOfService": "http://swagger.io/terms/", + "contact": { + "email": "apiteam@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + }, + "version": "1.0" + }, + "host": "petstore.swagger.io", + "basePath": "/v2", + "paths": {} +}` + searchDir := "testdata/pet" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseModelAsTypeAlias(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.", + "title": "Swagger Example API", + "termsOfService": "http://swagger.io/terms/", + "contact": { + "name": "API Support", + "url": "http://www.swagger.io/support", + "email": "support@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + }, + "version": "1.0" + }, + "host": "petstore.swagger.io", + "basePath": "/v2", + "paths": { + "/testapi/time-as-time-container": { + "get": { + "description": "test container with time and time alias", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Get container with time and time alias", + "operationId": "time-as-time-container", + "responses": { + "200": { + "description": "ok", + "schema": { + "$ref": "#/definitions/data.TimeContainer" + } + } + } + } + } + }, + "definitions": { + "data.TimeContainer": { + "type": "object", + "properties": { + "created_at": { + "type": "string" + }, + "name": { + "type": "string" + }, + "timestamp": { + "type": "string" + } + } + } + } +}` + searchDir := "testdata/alias_type" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseComposition(t *testing.T) { + t.Parallel() + + searchDir := "testdata/composition" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + + // windows will fail: \r\n \n + assert.Equal(t, string(expected), string(b)) +} + +func TestParseImportAliases(t *testing.T) { + t.Parallel() + + searchDir := "testdata/alias_import" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + // windows will fail: \r\n \n + assert.Equal(t, string(expected), string(b)) +} + +func TestParseTypeOverrides(t *testing.T) { + t.Parallel() + + searchDir := "testdata/global_override" + p := New(SetOverrides(map[string]string{ + "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/global_override/types.Application": "string", + "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/global_override/types.Application2": "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/global_override/othertypes.Application", + "git.ipao.vip/rogeecn/atomctl/pkg/swag/testdata/global_override/types.ShouldSkip": "", + })) + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + // windows will fail: \r\n \n + assert.Equal(t, string(expected), string(b)) +} + +func TestGlobalSecurity(t *testing.T) { + t.Parallel() + + searchDir := "testdata/global_security" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, string(expected), string(b)) +} + +func TestParseNested(t *testing.T) { + t.Parallel() + + searchDir := "testdata/nested" + p := New(SetParseDependency(1)) + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, string(expected), string(b)) +} + +func TestParseDuplicated(t *testing.T) { + t.Parallel() + + searchDir := "testdata/duplicated" + p := New(SetParseDependency(1)) + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.Errorf(t, err, "duplicated @id declarations successfully found") +} + +func TestParseDuplicatedOtherMethods(t *testing.T) { + t.Parallel() + + searchDir := "testdata/duplicated2" + p := New(SetParseDependency(1)) + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.Errorf(t, err, "duplicated @id declarations successfully found") +} + +func TestParseDuplicatedFunctionScoped(t *testing.T) { + t.Parallel() + + searchDir := "testdata/duplicated_function_scoped" + p := New(SetParseDependency(1)) + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.Errorf(t, err, "duplicated @id declarations successfully found") +} + +func TestParseConflictSchemaName(t *testing.T) { + t.Parallel() + + searchDir := "testdata/conflict_name" + p := New(SetParseDependency(1)) + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseExternalModels(t *testing.T) { + searchDir := "testdata/external_models/main" + mainAPIFile := "main.go" + p := New(SetParseDependency(1)) + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + // ioutil.WriteFile("./testdata/external_models/main/expected.json",b,0777) + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParseGoList(t *testing.T) { + mainAPIFile := "main.go" + p := New(ParseUsingGoList(true), SetParseDependency(1)) + go111moduleEnv := os.Getenv("GO111MODULE") + + cases := []struct { + name string + gomodule bool + searchDir string + err error + run func(searchDir string) error + }{ + { + name: "disableGOMODULE", + gomodule: false, + searchDir: "testdata/golist_disablemodule", + run: func(searchDir string) error { + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + { + name: "enableGOMODULE", + gomodule: true, + searchDir: "testdata/golist", + run: func(searchDir string) error { + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + { + name: "invalid_main", + gomodule: true, + searchDir: "testdata/golist_invalid", + err: errors.New("no such file or directory"), + run: func(searchDir string) error { + return p.ParseAPI(searchDir, "invalid/main.go", defaultParseDepth) + }, + }, + { + name: "internal_invalid_pkg", + gomodule: true, + searchDir: "testdata/golist_invalid", + err: errors.New("expected 'package', found This"), + run: func(searchDir string) error { + mockErrGoFile := "testdata/golist_invalid/err.go" + f, err := os.OpenFile(mockErrGoFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write([]byte(`package invalid + +function a() {}`)) + if err != nil { + return err + } + defer os.Remove(mockErrGoFile) + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + { + name: "invalid_pkg", + gomodule: true, + searchDir: "testdata/golist_invalid", + err: errors.New("expected 'package', found This"), + run: func(searchDir string) error { + mockErrGoFile := "testdata/invalid_external_pkg/invalid/err.go" + f, err := os.OpenFile(mockErrGoFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write([]byte(`package invalid + +function a() {}`)) + if err != nil { + return err + } + defer os.Remove(mockErrGoFile) + return p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if c.gomodule { + os.Setenv("GO111MODULE", "on") + } else { + os.Setenv("GO111MODULE", "off") + } + err := c.run(c.searchDir) + os.Setenv("GO111MODULE", go111moduleEnv) + if c.err == nil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + }) + } +} + +func TestParser_ParseStructArrayObject(t *testing.T) { + t.Parallel() + + src := ` +package api + +type Response struct { + Code int + Table [][]string + Data []struct{ + Field1 uint + Field2 string + } +} + +// @Success 200 {object} Response +// @Router /api/{id} [get] +func Test(){ +} +` + expected := `{ + "api.Response": { + "type": "object", + "properties": { + "code": { + "type": "integer" + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field1": { + "type": "integer" + }, + "field2": { + "type": "string" + } + } + } + }, + "table": { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + } +}` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + out, err := json.MarshalIndent(p.swagger.Definitions, "", " ") + assert.NoError(t, err) + assert.Equal(t, expected, string(out)) +} + +func TestParser_ParseEmbededStruct(t *testing.T) { + t.Parallel() + + src := ` +package api + +type Response struct { + rest.ResponseWrapper +} + +// @Success 200 {object} Response +// @Router /api/{id} [get] +func Test(){ +} +` + restsrc := ` +package rest + +type ResponseWrapper struct { + Status string + Code int + Messages []string + Result interface{} +} +` + expected := `{ + "api.Response": { + "type": "object", + "properties": { + "code": { + "type": "integer" + }, + "messages": { + "type": "array", + "items": { + "type": "string" + } + }, + "result": {}, + "status": { + "type": "string" + } + } + } +}` + parser := New(SetParseDependency(1)) + + _ = parser.packages.ParseFile("api", "api/api.go", src, ParseAll) + + _ = parser.packages.ParseFile("rest", "rest/rest.go", restsrc, ParseAll) + + _, err := parser.packages.ParseTypes() + assert.NoError(t, err) + + err = parser.packages.RangeFiles(parser.ParseRouterAPIInfo) + assert.NoError(t, err) + + out, err := json.MarshalIndent(parser.swagger.Definitions, "", " ") + assert.NoError(t, err) + assert.Equal(t, expected, string(out)) +} + +func TestParser_ParseStructPointerMembers(t *testing.T) { + t.Parallel() + + src := ` +package api + +type Child struct { + Name string +} + +type Parent struct { + Test1 *string //test1 + Test2 *Child //test2 +} + +// @Success 200 {object} Parent +// @Router /api/{id} [get] +func Test(){ +} +` + + expected := `{ + "api.Child": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "api.Parent": { + "type": "object", + "properties": { + "test1": { + "description": "test1", + "type": "string" + }, + "test2": { + "description": "test2", + "allOf": [ + { + "$ref": "#/definitions/api.Child" + } + ] + } + } + } +}` + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + out, err := json.MarshalIndent(p.swagger.Definitions, "", " ") + assert.NoError(t, err) + assert.Equal(t, expected, string(out)) +} + +func TestParser_ParseStructMapMember(t *testing.T) { + t.Parallel() + + src := ` +package api + +type MyMapType map[string]string + +type Child struct { + Name string +} + +type Parent struct { + Test1 map[string]interface{} //test1 + Test2 map[string]string //test2 + Test3 map[string]*string //test3 + Test4 map[string]Child //test4 + Test5 map[string]*Child //test5 + Test6 MyMapType //test6 + Test7 []Child //test7 + Test8 []*Child //test8 + Test9 []map[string]string //test9 +} + +// @Success 200 {object} Parent +// @Router /api/{id} [get] +func Test(){ +} +` + expected := `{ + "api.Child": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "api.MyMapType": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "api.Parent": { + "type": "object", + "properties": { + "test1": { + "description": "test1", + "type": "object", + "additionalProperties": true + }, + "test2": { + "description": "test2", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "test3": { + "description": "test3", + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "test4": { + "description": "test4", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/api.Child" + } + }, + "test5": { + "description": "test5", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/api.Child" + } + }, + "test6": { + "description": "test6", + "allOf": [ + { + "$ref": "#/definitions/api.MyMapType" + } + ] + }, + "test7": { + "description": "test7", + "type": "array", + "items": { + "$ref": "#/definitions/api.Child" + } + }, + "test8": { + "description": "test8", + "type": "array", + "items": { + "$ref": "#/definitions/api.Child" + } + }, + "test9": { + "description": "test9", + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } +}` + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + out, err := json.MarshalIndent(p.swagger.Definitions, "", " ") + assert.NoError(t, err) + assert.Equal(t, expected, string(out)) +} + +func TestParser_ParseRouterApiInfoErr(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Accept unknown +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.Error(t, err) +} + +func TestParser_ParseRouterApiGet(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [get] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Get) +} + +func TestParser_ParseRouterApiPOST(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [post] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Post) +} + +func TestParser_ParseRouterApiDELETE(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [delete] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Delete) +} + +func TestParser_ParseRouterApiPUT(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [put] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Put) +} + +func TestParser_ParseRouterApiPATCH(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [patch] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Patch) +} + +func TestParser_ParseRouterApiHead(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [head] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Head) +} + +func TestParser_ParseRouterApiOptions(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [options] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Options) +} + +func TestParser_ParseRouterApiMultipleRoutesForSameFunction(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/v1/{id} [get] +// @Router /api/v2/{id} [post] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/v1/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Get) + + val, ok = ps["/api/v2/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Post) +} + +func TestParser_ParseRouterApiMultiple(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/{id} [get] +func Test1(){ +} + +// @Router /api/{id} [patch] +func Test2(){ +} + +// @Router /api/{id} [delete] +func Test3(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Get) + assert.NotNil(t, val.Patch) + assert.NotNil(t, val.Delete) +} + +func TestParser_ParseRouterApiMultiplePathsWithMultipleParams(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Success 200 +// @Param group_id path int true "Group ID" +// @Param user_id path int true "User ID" +// @Router /examples/groups/{group_id}/user/{user_id}/address [get] +// @Router /examples/user/{user_id}/address [get] +func Test(){ +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/examples/groups/{group_id}/user/{user_id}/address"] + + assert.True(t, ok) + assert.Equal(t, 2, len(val.Get.Parameters)) + + val, ok = ps["/examples/user/{user_id}/address"] + + assert.True(t, ok) + assert.Equal(t, 1, len(val.Get.Parameters)) +} + +// func TestParseDeterministic(t *testing.T) { +// mainAPIFile := "main.go" +// for _, searchDir := range []string{ +// "testdata/simple", +// "testdata/model_not_under_root/cmd", +// } { +// t.Run(searchDir, func(t *testing.T) { +// var expected string + +// // run the same code 100 times and check that the output is the same every time +// for i := 0; i < 100; i++ { +// p := New() +// p.PropNamingStrategy = PascalCase +// err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) +// b, _ := json.MarshalIndent(p.swagger, "", " ") +// assert.NotEqual(t, "", string(b)) + +// if expected == "" { +// expected = string(b) +// } + +// assert.Equal(t, expected, string(b)) +// } +// }) +// } +// } + +func TestParser_ParseRouterApiDuplicateRoute(t *testing.T) { + t.Parallel() + + src := ` +package api + +import ( + "net/http" +) + +// @Router /api/endpoint [get] +func FunctionOne(w http.ResponseWriter, r *http.Request) { + //write your code +} + +// @Router /api/endpoint [get] +func FunctionTwo(w http.ResponseWriter, r *http.Request) { + //write your code +} + +` + p := New(SetStrict(true)) + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.EqualError(t, err, "route GET /api/endpoint is declared multiple times") + + p = New() + err = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) +} + +func TestApiParseTag(t *testing.T) { + t.Parallel() + + searchDir := "testdata/tags" + p := New(SetMarkdownFileDirectory(searchDir)) + p.PropNamingStrategy = PascalCase + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + + if len(p.swagger.Tags) != 3 { + t.Error("Number of tags did not match") + } + + dogs := p.swagger.Tags[0] + if dogs.TagProps.Name != "dogs" || dogs.TagProps.Description != "Dogs are cool" { + t.Error("Failed to parse dogs name or description") + } + + cats := p.swagger.Tags[1] + if cats.TagProps.Name != "cats" || cats.TagProps.Description != "Cats are the devil" { + t.Error("Failed to parse cats name or description") + } + + if cats.TagProps.ExternalDocs.URL != "https://google.de" || cats.TagProps.ExternalDocs.Description != "google is super useful to find out that cats are evil!" { + t.Error("URL: ", cats.TagProps.ExternalDocs.URL) + t.Error("Description: ", cats.TagProps.ExternalDocs.Description) + t.Error("Failed to parse cats external documentation") + } +} + +func TestApiParseTag_NonExistendTag(t *testing.T) { + t.Parallel() + + searchDir := "testdata/tags_nonexistend_tag" + p := New(SetMarkdownFileDirectory(searchDir)) + p.PropNamingStrategy = PascalCase + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.Error(t, err) +} + +func TestParseTagMarkdownDescription(t *testing.T) { + t.Parallel() + + searchDir := "testdata/tags" + p := New(SetMarkdownFileDirectory(searchDir)) + p.PropNamingStrategy = PascalCase + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + if err != nil { + t.Error("Failed to parse api description: " + err.Error()) + } + + if len(p.swagger.Tags) != 3 { + t.Error("Number of tags did not match") + } + + apes := p.swagger.Tags[2] + if apes.TagProps.Description == "" { + t.Error("Failed to parse tag description markdown file") + } +} + +func TestParseApiMarkdownDescription(t *testing.T) { + t.Parallel() + + searchDir := "testdata/tags" + p := New(SetMarkdownFileDirectory(searchDir)) + p.PropNamingStrategy = PascalCase + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + if err != nil { + t.Error("Failed to parse api description: " + err.Error()) + } + + if p.swagger.Info.Description == "" { + t.Error("Failed to parse api description: " + err.Error()) + } +} + +func TestIgnoreInvalidPkg(t *testing.T) { + t.Parallel() + + searchDir := "testdata/deps_having_invalid_pkg" + p := New() + if err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth); err != nil { + t.Error("Failed to ignore valid pkg: " + err.Error()) + } +} + +func TestFixes432(t *testing.T) { + t.Parallel() + + searchDir := "testdata/fixes-432" + mainAPIFile := "cmd/main.go" + + p := New() + if err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth); err != nil { + t.Error("Failed to ignore valid pkg: " + err.Error()) + } +} + +func TestParseOutsideDependencies(t *testing.T) { + t.Parallel() + + searchDir := "testdata/pare_outside_dependencies" + mainAPIFile := "cmd/main.go" + + p := New(SetParseDependency(1)) + if err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth); err != nil { + t.Error("Failed to parse api: " + err.Error()) + } +} + +func TestParseStructParamCommentByQueryType(t *testing.T) { + t.Parallel() + + src := ` +package main + +type Student struct { + Name string + Age int + Teachers []string + SkipField map[string]string +} + +// @Param request query Student true "query params" +// @Success 200 +// @Router /test [get] +func Fun() { + +} +` + expected := `{ + "info": { + "contact": {} + }, + "paths": { + "/test": { + "get": { + "parameters": [ + { + "type": "integer", + "name": "age", + "in": "query" + }, + { + "type": "string", + "name": "name", + "in": "query" + }, + { + "type": "array", + "items": { + "type": "string" + }, + "name": "teachers", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK" + } + } + } + } + } +}` + + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + _, err = p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseParamCommentExtension(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Param request query string true "query params" extensions(x-example=[0, 9],x-foo=bar) +// @Success 200 +// @Router /test [get] +func Fun() { + +} +` + expected := `{ + "info": { + "contact": {} + }, + "paths": { + "/test": { + "get": { + "parameters": [ + { + "type": "string", + "x-example": "[0, 9]", + "x-foo": "bar", + "description": "query params", + "name": "request", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + } + } + } + } + } +}` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.JSONEq(t, expected, string(b)) +} + +func TestParseRenamedStructDefinition(t *testing.T) { + t.Parallel() + + src := ` +package main + +type Child struct { + Name string +}//@name Student + +type Parent struct { + Name string + Child Child +}//@name Teacher + +// @Param request body Parent true "query params" +// @Success 200 {object} Parent +// @Router /test [get] +func Fun() { + +} +` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + assert.NoError(t, err) + teacher, ok := p.swagger.Definitions["Teacher"] + assert.True(t, ok) + ref := teacher.Properties["child"].SchemaProps.Ref + assert.Equal(t, "#/definitions/Student", ref.String()) + _, ok = p.swagger.Definitions["Student"] + assert.True(t, ok) + path, ok := p.swagger.Paths.Paths["/test"] + assert.True(t, ok) + assert.Equal(t, "#/definitions/Teacher", path.Get.Parameters[0].Schema.Ref.String()) + ref = path.Get.Responses.ResponsesProps.StatusCodeResponses[200].ResponseProps.Schema.Ref + assert.Equal(t, "#/definitions/Teacher", ref.String()) +} + +func TestParseTabFormattedRenamedStructDefinition(t *testing.T) { + t.Parallel() + + src := "package main\n" + + "\n" + + "type Child struct {\n" + + "\tName string\n" + + "}\t//\t@name\tPupil\n" + + "\n" + + "// @Success 200 {object} Pupil\n" + + "func Fun() { }" + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + _, ok := p.swagger.Definitions["Pupil"] + assert.True(t, ok) +} + +func TestParseFunctionScopedStructDefinition(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Param request body main.Fun.request true "query params" +// @Success 200 {object} main.Fun.response +// @Router /test [post] +func Fun() { + type request struct { + Name string + } + + type response struct { + Name string + Child string + } +} +` + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + _, ok := p.swagger.Definitions["main.Fun.response"] + assert.True(t, ok) +} + +func TestParseFunctionScopedComplexStructDefinition(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Param request body main.Fun.request true "query params" +// @Success 200 {object} main.Fun.response +// @Router /test [post] +func Fun() { + type request struct { + Name string + } + + type grandChild struct { + Name string + } + + type pointerChild struct { + Name string + } + + type arrayChild struct { + Name string + } + + type child struct { + GrandChild grandChild + PointerChild *pointerChild + ArrayChildren []arrayChild + } + + type response struct { + Children []child + } +} +` + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + _, ok := p.swagger.Definitions["main.Fun.response"] + assert.True(t, ok) + _, ok = p.swagger.Definitions["main.Fun.child"] + assert.True(t, ok) + _, ok = p.swagger.Definitions["main.Fun.grandChild"] + assert.True(t, ok) + _, ok = p.swagger.Definitions["main.Fun.pointerChild"] + assert.True(t, ok) + _, ok = p.swagger.Definitions["main.Fun.arrayChild"] + assert.True(t, ok) +} + +func TestParseFunctionScopedStructRequestResponseJSON(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Param request body main.Fun.request true "query params" +// @Success 200 {object} main.Fun.response +// @Router /test [post] +func Fun() { + type request struct { + Name string + } + + type response struct { + Name string + Child string + } +} +` + expected := `{ + "info": { + "contact": {} + }, + "paths": { + "/test": { + "post": { + "parameters": [ + { + "description": "query params", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/main.Fun.request" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/main.Fun.response" + } + } + } + } + } + }, + "definitions": { + "main.Fun.request": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "main.Fun.response": { + "type": "object", + "properties": { + "child": { + "type": "string" + }, + "name": { + "type": "string" + } + } + } + } +}` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseFunctionScopedComplexStructRequestResponseJSON(t *testing.T) { + t.Parallel() + + src := ` +package main + +type PublicChild struct { + Name string +} + +// @Param request body main.Fun.request true "query params" +// @Success 200 {object} main.Fun.response +// @Router /test [post] +func Fun() { + type request struct { + Name string + } + + type grandChild struct { + Name string + } + + type child struct { + GrandChild grandChild + } + + type response struct { + Children []child + PublicChild PublicChild + } +} +` + expected := `{ + "info": { + "contact": {} + }, + "paths": { + "/test": { + "post": { + "parameters": [ + { + "description": "query params", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/main.Fun.request" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/main.Fun.response" + } + } + } + } + } + }, + "definitions": { + "main.Fun.child": { + "type": "object", + "properties": { + "grandChild": { + "$ref": "#/definitions/main.Fun.grandChild" + } + } + }, + "main.Fun.grandChild": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "main.Fun.request": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "main.Fun.response": { + "type": "object", + "properties": { + "children": { + "type": "array", + "items": { + "$ref": "#/definitions/main.Fun.child" + } + }, + "publicChild": { + "$ref": "#/definitions/main.PublicChild" + } + } + }, + "main.PublicChild": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + } +}` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestPackagesDefinitions_CollectAstFileInit(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Router /test [get] +func Fun() { + +} +` + pkgs := NewPackagesDefinitions() + + // unset the .files and .packages and check that they're re-initialized by collectAstFile + pkgs.packages = nil + pkgs.files = nil + + _ = pkgs.ParseFile("api", "api/api.go", src, ParseAll) + assert.NotNil(t, pkgs.packages) + assert.NotNil(t, pkgs.files) +} + +func TestCollectAstFileMultipleTimes(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Router /test [get] +func Fun() { + +} +` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.Equal(t, 1, len(p.packages.files)) + var path string + var file *ast.File + for path, file = range p.packages.packages["api"].Files { + break + } + assert.NotNil(t, file) + assert.NotNil(t, p.packages.files[file]) + + // if we collect the same again nothing should happen + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.Equal(t, 1, len(p.packages.files)) + assert.Equal(t, file, p.packages.packages["api"].Files[path]) + assert.NotNil(t, p.packages.files[file]) +} + +func TestParseJSONFieldString(t *testing.T) { + t.Parallel() + + expected := `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server.", + "title": "Swagger Example API", + "contact": {}, + "version": "1.0" + }, + "host": "localhost:4000", + "basePath": "/", + "paths": { + "/do-something": { + "post": { + "description": "Does something", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Call DoSomething", + "parameters": [ + { + "description": "My Struct", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/main.MyStruct" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/main.MyStruct" + } + }, + "500": { + "description": "Internal Server Error" + } + } + } + } + }, + "definitions": { + "main.MyStruct": { + "type": "object", + "properties": { + "boolvar": { + "description": "boolean as a string", + "type": "string", + "example": "false" + }, + "floatvar": { + "description": "float as a string", + "type": "string", + "example": "0" + }, + "id": { + "type": "integer", + "format": "int64", + "example": 1 + }, + "myint": { + "description": "integer as string", + "type": "string", + "example": "0" + }, + "name": { + "type": "string", + "example": "poti" + }, + "truebool": { + "description": "boolean as a string", + "type": "string", + "example": "true" + }, + "uuids": { + "description": "string array with format", + "type": "array", + "items": { + "type": "string", + "format": "uuid" + } + } + } + } + } +}` + + searchDir := "testdata/json_field_string" + p := New() + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + +func TestParseSwaggerignoreForEmbedded(t *testing.T) { + t.Parallel() + + src := ` +package main + +type Child struct { + ChildName string +}//@name Student + +type Parent struct { + Name string + Child ` + "`swaggerignore:\"true\"`" + ` +}//@name Teacher + +// @Param request body Parent true "query params" +// @Success 200 {object} Parent +// @Router /test [get] +func Fun() { + +} +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + _, _ = p.packages.ParseTypes() + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + teacher, ok := p.swagger.Definitions["Teacher"] + assert.True(t, ok) + + name, ok := teacher.Properties["name"] + assert.True(t, ok) + assert.Len(t, name.Type, 1) + assert.Equal(t, "string", name.Type[0]) + + childName, ok := teacher.Properties["childName"] + assert.False(t, ok) + assert.Empty(t, childName) +} + +func TestDefineTypeOfExample(t *testing.T) { + t.Run("String type", func(t *testing.T) { + t.Parallel() + + example, err := defineTypeOfExample("string", "", "example") + assert.NoError(t, err) + assert.Equal(t, example.(string), "example") + }) + + t.Run("Number type", func(t *testing.T) { + t.Parallel() + + example, err := defineTypeOfExample("number", "", "12.34") + assert.NoError(t, err) + assert.Equal(t, example.(float64), 12.34) + + _, err = defineTypeOfExample("number", "", "two") + assert.Error(t, err) + }) + + t.Run("Integer type", func(t *testing.T) { + t.Parallel() + + example, err := defineTypeOfExample("integer", "", "12") + assert.NoError(t, err) + assert.Equal(t, example.(int), 12) + + _, err = defineTypeOfExample("integer", "", "two") + assert.Error(t, err) + }) + + t.Run("Boolean type", func(t *testing.T) { + t.Parallel() + + example, err := defineTypeOfExample("boolean", "", "true") + assert.NoError(t, err) + assert.Equal(t, example.(bool), true) + + _, err = defineTypeOfExample("boolean", "", "!true") + assert.Error(t, err) + }) + + t.Run("Array type", func(t *testing.T) { + t.Parallel() + + example, err := defineTypeOfExample("array", "", "one,two,three") + assert.Error(t, err) + assert.Nil(t, example) + + example, err = defineTypeOfExample("array", "string", "one,two,three") + assert.NoError(t, err) + + var arr []string + + for _, v := range example.([]interface{}) { + arr = append(arr, v.(string)) + } + + assert.Equal(t, arr, []string{"one", "two", "three"}) + }) + + t.Run("Object type", func(t *testing.T) { + t.Parallel() + + example, err := defineTypeOfExample("object", "", "key_one:one,key_two:two,key_three:three") + assert.Error(t, err) + assert.Nil(t, example) + + example, err = defineTypeOfExample("object", "string", "key_one,key_two,key_three") + assert.Error(t, err) + assert.Nil(t, example) + + example, err = defineTypeOfExample("object", "oops", "key_one:one,key_two:two,key_three:three") + assert.Error(t, err) + assert.Nil(t, example) + + example, err = defineTypeOfExample("object", "string", "key_one:one,key_two:two,key_three:three") + assert.NoError(t, err) + obj := map[string]string{} + + for k, v := range example.(map[string]interface{}) { + obj[k] = v.(string) + } + + assert.Equal(t, obj, map[string]string{"key_one": "one", "key_two": "two", "key_three": "three"}) + }) + + t.Run("Invalid type", func(t *testing.T) { + t.Parallel() + + example, err := defineTypeOfExample("oops", "", "") + assert.Error(t, err) + assert.Nil(t, example) + }) +} + +type mockFS struct { + os.FileInfo + FileName string + IsDirectory bool +} + +func (fs *mockFS) Name() string { + return fs.FileName +} + +func (fs *mockFS) IsDir() bool { + return fs.IsDirectory +} + +func TestParser_Skip(t *testing.T) { + t.Parallel() + + parser := New() + parser.ParseVendor = true + + assert.NoError(t, parser.Skip("", &mockFS{FileName: "vendor"})) + assert.NoError(t, parser.Skip("", &mockFS{FileName: "vendor", IsDirectory: true})) + + parser.ParseVendor = false + assert.NoError(t, parser.Skip("", &mockFS{FileName: "vendor"})) + assert.Error(t, parser.Skip("", &mockFS{FileName: "vendor", IsDirectory: true})) + + assert.NoError(t, parser.Skip("", &mockFS{FileName: "models", IsDirectory: true})) + assert.NoError(t, parser.Skip("", &mockFS{FileName: "admin", IsDirectory: true})) + assert.NoError(t, parser.Skip("", &mockFS{FileName: "release", IsDirectory: true})) + assert.NoError(t, parser.Skip("", &mockFS{FileName: "..", IsDirectory: true})) + + parser = New(SetExcludedDirsAndFiles("admin/release,admin/models")) + assert.NoError(t, parser.Skip("admin", &mockFS{IsDirectory: true})) + assert.NoError(t, parser.Skip(filepath.Clean("admin/service"), &mockFS{IsDirectory: true})) + assert.Error(t, parser.Skip(filepath.Clean("admin/models"), &mockFS{IsDirectory: true})) + assert.Error(t, parser.Skip(filepath.Clean("admin/release"), &mockFS{IsDirectory: true})) +} + +func TestGetFieldType(t *testing.T) { + t.Parallel() + + field, err := getFieldType(&ast.File{}, &ast.Ident{Name: "User"}, nil) + assert.NoError(t, err) + assert.Equal(t, "User", field) + + _, err = getFieldType(&ast.File{}, &ast.FuncType{}, nil) + assert.Error(t, err) + + field, err = getFieldType(&ast.File{}, &ast.SelectorExpr{X: &ast.Ident{Name: "models"}, Sel: &ast.Ident{Name: "User"}}, nil) + assert.NoError(t, err) + assert.Equal(t, "models.User", field) + + _, err = getFieldType(&ast.File{}, &ast.SelectorExpr{X: &ast.FuncType{}, Sel: &ast.Ident{Name: "User"}}, nil) + assert.Error(t, err) + + field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.Ident{Name: "User"}}, nil) + assert.NoError(t, err) + assert.Equal(t, "User", field) + + field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.FuncType{}}, nil) + assert.Error(t, err) + + field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "models"}, Sel: &ast.Ident{Name: "User"}}}, nil) + assert.NoError(t, err) + assert.Equal(t, "models.User", field) +} + +func TestTryAddDescription(t *testing.T) { + type args struct { + spec *spec.SecurityScheme + extensions map[string]interface{} + } + tests := []struct { + name string + lines []string + args args + want *spec.SecurityScheme + }{ + { + name: "added description", + lines: []string{ + "\t@securitydefinitions.apikey test", + "\t@in header", + "\t@name x-api-key", + "\t@description some description", + }, + want: &spec.SecurityScheme{ + SecuritySchemeProps: spec.SecuritySchemeProps{ + Name: "x-api-key", + Type: "apiKey", + In: "header", + Description: "some description", + }, + }, + }, + { + name: "added description with multiline", + lines: []string{ + "\t@securitydefinitions.apikey test", + "\t@in header", + "\t@name x-api-key", + "\t@description line1", + "\t@description line2", + }, + want: &spec.SecurityScheme{ + SecuritySchemeProps: spec.SecuritySchemeProps{ + Name: "x-api-key", + Type: "apiKey", + In: "header", + Description: "line1\nline2", + }, + }, + }, + { + name: "no description", + lines: []string{ + " @securitydefinitions.oauth2.application swagger", + " @tokenurl https://example.com/oauth/token", + " @not-description some description", + }, + want: &spec.SecurityScheme{ + SecuritySchemeProps: spec.SecuritySchemeProps{ + Type: "oauth2", + Flow: "application", + TokenURL: "https://example.com/oauth/token", + Description: "", + }, + }, + }, + + { + name: "description has invalid format", + lines: []string{ + "@securitydefinitions.oauth2.implicit swagger", + "@authorizationurl https://example.com/oauth/token", + "@description 12345", + }, + + want: &spec.SecurityScheme{ + SecuritySchemeProps: spec.SecuritySchemeProps{ + Type: "oauth2", + Flow: "implicit", + AuthorizationURL: "https://example.com/oauth/token", + Description: "12345", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + swag := spec.Swagger{ + SwaggerProps: spec.SwaggerProps{ + SecurityDefinitions: make(map[string]*spec.SecurityScheme), + }, + } + line := 0 + commentLine := tt.lines[line] + attribute := strings.Split(commentLine, " ")[0] + value := strings.TrimSpace(commentLine[len(attribute):]) + secAttr, _ := parseSecAttributes(attribute, tt.lines, &line) + if !reflect.DeepEqual(secAttr, tt.want) { + t.Errorf("setSwaggerSecurity() = %#v, want %#v", swag.SecurityDefinitions[value], tt.want) + } + }) + } +} + +func Test_getTagsFromComment(t *testing.T) { + type args struct { + comment string + } + tests := []struct { + name string + args args + wantTags []string + }{ + { + name: "no tags comment", + args: args{ + comment: "//@name Student", + }, + wantTags: nil, + }, + { + name: "empty comment", + args: args{ + comment: "//", + }, + wantTags: nil, + }, + { + name: "tags comment", + args: args{ + comment: "//@Tags tag1,tag2,tag3", + }, + wantTags: []string{"tag1", "tag2", "tag3"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotTags := getTagsFromComment(tt.args.comment); !reflect.DeepEqual(gotTags, tt.wantTags) { + t.Errorf("getTagsFromComment() = %v, want %v", gotTags, tt.wantTags) + } + }) + } +} + +func TestParser_matchTags(t *testing.T) { + type args struct { + comments []*ast.Comment + } + tests := []struct { + name string + parser *Parser + args args + wantMatch bool + }{ + { + name: "no tags filter", + parser: New(), + args: args{comments: []*ast.Comment{{Text: "//@Tags tag1,tag2,tag3"}}}, + wantMatch: true, + }, + { + name: "with tags filter but no match", + parser: New(SetTags("tag4,tag5,!tag1")), + args: args{comments: []*ast.Comment{{Text: "//@Tags tag1,tag2,tag3"}}}, + wantMatch: false, + }, + { + name: "with tags filter but match", + parser: New(SetTags("tag4,tag5,tag1")), + args: args{comments: []*ast.Comment{{Text: "//@Tags tag1,tag2,tag3"}}}, + wantMatch: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotMatch := tt.parser.matchTags(tt.args.comments); gotMatch != tt.wantMatch { + t.Errorf("Parser.matchTags() = %v, want %v", gotMatch, tt.wantMatch) + } + }) + } +} + +func TestParser_parseExtension(t *testing.T) { + packagePath := "testdata/parseExtension" + filePath := packagePath + "/parseExtension.go" + src, err := os.ReadFile(filePath) + assert.NoError(t, err) + + fileSet := token.NewFileSet() + f, err := goparser.ParseFile(fileSet, "", src, goparser.ParseComments) + assert.NoError(t, err) + + tests := []struct { + name string + parser *Parser + expectedPaths map[string]bool + }{ + { + name: "when no flag is set, everything is exported", + parser: New(), + expectedPaths: map[string]bool{"/without-extension": true, "/with-another-extension": true, "/with-correct-extension": true, "/with-empty-comment-line": true}, + }, + { + name: "when nonexistent flag is set, nothing is exported", + parser: New(SetParseExtension("nonexistent-extension-filter")), + expectedPaths: map[string]bool{"/without-extension": false, "/with-another-extension": false, "/with-correct-extension": false, "/with-empty-comment-line": false}, + }, + { + name: "when correct flag is set, only that Path is exported", + parser: New(SetParseExtension("google-backend")), + expectedPaths: map[string]bool{"/without-extension": false, "/with-another-extension": false, "/with-correct-extension": true, "/with-empty-comment-line": false}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = tt.parser.ParseRouterAPIInfo(&AstFileInfo{ + FileSet: fileSet, + File: f, + Path: filePath, + PackagePath: packagePath, + ParseFlag: ParseAll, + }) + assert.NoError(t, err) + for p, isExpected := range tt.expectedPaths { + _, ok := tt.parser.swagger.Paths.Paths[p] + assert.Equal(t, isExpected, ok) + } + + for p := range tt.parser.swagger.Paths.Paths { + _, isExpected := tt.expectedPaths[p] + assert.Equal(t, isExpected, true) + } + }) + } +} + +func TestParser_collectionFormat(t *testing.T) { + tests := []struct { + name string + parser *Parser + format string + }{ + { + name: "no collectionFormat", + parser: New(), + format: "", + }, + { + name: "multi collectionFormat", + parser: New(SetCollectionFormat("multi")), + format: "multi", + }, + { + name: "ssv collectionFormat", + parser: New(SetCollectionFormat("ssv")), + format: "ssv", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.parser.collectionFormatInQuery != tt.format { + t.Errorf("Parser.collectionFormatInQuery = %s, want %s", tt.parser.collectionFormatInQuery, tt.format) + } + }) + } +} + +func TestParser_skipPackageByPrefix(t *testing.T) { + t.Parallel() + + parser := New() + + assert.False(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag")) + assert.False(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/cmd")) + assert.False(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/gen")) + + parser = New(SetPackagePrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/cmd")) + + assert.True(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag")) + assert.False(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/cmd")) + assert.True(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/gen")) + + parser = New(SetPackagePrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/cmd,git.ipao.vip/rogeecn/atomctl/pkg/swag/gen")) + + assert.True(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag")) + assert.False(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/cmd")) + assert.False(t, parser.skipPackageByPrefix("git.ipao.vip/rogeecn/atomctl/pkg/swag/gen")) +} + +func TestParser_ParseRouterApiInFuncBody(t *testing.T) { + t.Parallel() + + src := ` +package test + +func Test(){ + // @Router /api/{id} [get] + _ = func() { + } +} +` + p := New() + p.ParseFuncBody = true + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val, ok := ps["/api/{id}"] + + assert.True(t, ok) + assert.NotNil(t, val.Get) +} + +func TestParser_ParseRouterApiInfoInAndOutFuncBody(t *testing.T) { + t.Parallel() + + src := ` +package test + +// @Router /api/outside [get] +func otherRoute(){ +} + +func Test(){ + // @Router /api/inside [get] + _ = func() { + } +} +` + p := New() + p.ParseFuncBody = true + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + ps := p.swagger.Paths.Paths + + val1, ok := ps["/api/outside"] + assert.True(t, ok) + assert.NotNil(t, val1.Get) + + val2, ok := ps["/api/inside"] + assert.True(t, ok) + assert.NotNil(t, val2.Get) +} + +func TestParser_EmbeddedStructAsOtherAliasGoListNested(t *testing.T) { + t.Parallel() + + p := New(SetParseDependency(1), ParseUsingGoList(true)) + + p.parseGoList = true + + searchDir := "testdata/alias_nested" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + err = p.ParseAPI(searchDir, "cmd/main/main.go", 0) + assert.NoError(t, err) + + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + +func TestParser_genVarDefinedFuncDoc(t *testing.T) { + t.Parallel() + + src := ` +package main +func f() {} +// @Summary generate var-defined functions' doc +// @Router /test [get] +var Func = f +// @Summary generate indirectly pointing +// @Router /test2 [get] +var Func2 = Func +` + p := New() + err := p.packages.ParseFile("api", "api/api.go", src, ParseAll) + assert.NoError(t, err) + _, _ = p.packages.ParseTypes() + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + val, ok := p.swagger.Paths.Paths["/test"] + assert.True(t, ok) + assert.NotNil(t, val.Get) + assert.Equal(t, val.Get.OperationProps.Summary, "generate var-defined functions' doc") + + val2, ok := p.swagger.Paths.Paths["/test2"] + assert.True(t, ok) + assert.NotNil(t, val2.Get) + assert.Equal(t, val2.Get.OperationProps.Summary, "generate indirectly pointing") +} diff --git a/pkg/swag/schema.go b/pkg/swag/schema.go new file mode 100644 index 0000000..234eb3f --- /dev/null +++ b/pkg/swag/schema.go @@ -0,0 +1,331 @@ +package swag + +import ( + "errors" + "fmt" + "github.com/go-openapi/spec" +) + +const ( + // ARRAY represent a array value. + ARRAY = "array" + // OBJECT represent a object value. + OBJECT = "object" + // PRIMITIVE represent a primitive value. + PRIMITIVE = "primitive" + // BOOLEAN represent a boolean value. + BOOLEAN = "boolean" + // INTEGER represent a integer value. + INTEGER = "integer" + // NUMBER represent a number value. + NUMBER = "number" + // STRING represent a string value. + STRING = "string" + // FUNC represent a function value. + FUNC = "func" + // ERROR represent a error value. + ERROR = "error" + // INTERFACE represent a interface value. + INTERFACE = "interface{}" + // ANY represent a any value. + ANY = "any" + // NIL represent a empty value. + NIL = "nil" + + // IgnoreNameOverridePrefix Prepend to model to avoid renaming based on comment. + IgnoreNameOverridePrefix = '$' +) + +// CheckSchemaType checks if typeName is not a name of primitive type. +func CheckSchemaType(typeName string) error { + if !IsPrimitiveType(typeName) { + return fmt.Errorf("%s is not basic types", typeName) + } + + return nil +} + +// IsSimplePrimitiveType determine whether the type name is a simple primitive type. +func IsSimplePrimitiveType(typeName string) bool { + switch typeName { + case STRING, NUMBER, INTEGER, BOOLEAN: + return true + } + + return false +} + +// IsPrimitiveType determine whether the type name is a primitive type. +func IsPrimitiveType(typeName string) bool { + switch typeName { + case STRING, NUMBER, INTEGER, BOOLEAN, ARRAY, OBJECT, FUNC: + return true + } + + return false +} + +// IsInterfaceLike determines whether the swagger type name is an go named interface type like error type. +func IsInterfaceLike(typeName string) bool { + return typeName == ERROR || typeName == ANY +} + +// IsNumericType determines whether the swagger type name is a numeric type. +func IsNumericType(typeName string) bool { + return typeName == INTEGER || typeName == NUMBER +} + +// TransToValidPrimitiveSchema transfer golang basic type to swagger schema with format considered. +func TransToValidPrimitiveSchema(typeName string) *spec.Schema { + switch typeName { + case "int", "uint": + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{INTEGER}}} + case "uint8", "int8", "uint16", "int16", "byte", "int32", "uint32", "rune": + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{INTEGER}, Format: "int32"}} + case "uint64", "int64": + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{INTEGER}, Format: "int64"}} + case "float32", "float64": + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{NUMBER}, Format: typeName}} + case "bool": + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{BOOLEAN}}} + case "string": + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{STRING}}} + } + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{typeName}}} +} + +// TransToValidSchemeTypeWithFormat indicates type will transfer golang basic type to swagger supported type with format. +func TransToValidSchemeTypeWithFormat(typeName string) (string, string) { + switch typeName { + case "int", "uint": + return INTEGER, "" + case "uint8", "int8", "uint16", "int16", "byte", "int32", "uint32", "rune": + return INTEGER, "int32" + case "uint64", "int64": + return INTEGER, "int64" + case "float32", "float64": + return NUMBER, typeName + case "bool": + return BOOLEAN, "" + case "string": + return STRING, "" + } + return typeName, "" +} + +// TransToValidSchemeType indicates type will transfer golang basic type to swagger supported type. +func TransToValidSchemeType(typeName string) string { + switch typeName { + case "uint", "int", "uint8", "int8", "uint16", "int16", "byte": + return INTEGER + case "uint32", "int32", "rune": + return INTEGER + case "uint64", "int64": + return INTEGER + case "float32", "float64": + return NUMBER + case "bool": + return BOOLEAN + case "string": + return STRING + } + + return typeName +} + +// IsGolangPrimitiveType determine whether the type name is a golang primitive type. +func IsGolangPrimitiveType(typeName string) bool { + switch typeName { + case "uint", + "int", + "uint8", + "int8", + "uint16", + "int16", + "byte", + "uint32", + "int32", + "rune", + "uint64", + "int64", + "float32", + "float64", + "bool", + "string": + return true + } + + return false +} + +// TransToValidCollectionFormat determine valid collection format. +func TransToValidCollectionFormat(format string) string { + switch format { + case "csv", "multi", "pipes", "tsv", "ssv": + return format + } + + return "" +} + +func ignoreNameOverride(name string) bool { + return len(name) != 0 && name[0] == IgnoreNameOverridePrefix +} + +// IsComplexSchema whether a schema is complex and should be a ref schema +func IsComplexSchema(schema *spec.Schema) bool { + // a enum type should be complex + if len(schema.Enum) > 0 { + return true + } + + // a deep array type is complex, how to determine deep? here more than 2 ,for example: [][]object,[][][]int + if len(schema.Type) > 2 { + return true + } + + //Object included, such as Object or []Object + for _, st := range schema.Type { + if st == OBJECT { + return true + } + } + return false +} + +// IsRefSchema whether a schema is a reference schema. +func IsRefSchema(schema *spec.Schema) bool { + return schema.Ref.Ref.GetURL() != nil +} + +// RefSchema build a reference schema. +func RefSchema(refType string) *spec.Schema { + return spec.RefSchema("#/definitions/" + refType) +} + +// PrimitiveSchema build a primitive schema. +func PrimitiveSchema(refType string) *spec.Schema { + return &spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{refType}}} +} + +// BuildCustomSchema build custom schema specified by tag swaggertype. +func BuildCustomSchema(types []string) (*spec.Schema, error) { + if len(types) == 0 { + return nil, nil + } + + switch types[0] { + case PRIMITIVE: + if len(types) == 1 { + return nil, errors.New("need primitive type after primitive") + } + + return BuildCustomSchema(types[1:]) + case ARRAY: + if len(types) == 1 { + return nil, errors.New("need array item type after array") + } + + schema, err := BuildCustomSchema(types[1:]) + if err != nil { + return nil, err + } + + return spec.ArrayProperty(schema), nil + case OBJECT: + if len(types) == 1 { + return PrimitiveSchema(types[0]), nil + } + + schema, err := BuildCustomSchema(types[1:]) + if err != nil { + return nil, err + } + + return spec.MapProperty(schema), nil + default: + err := CheckSchemaType(types[0]) + if err != nil { + return nil, err + } + + return PrimitiveSchema(types[0]), nil + } +} + +// MergeSchema merge schemas +func MergeSchema(dst *spec.Schema, src *spec.Schema) *spec.Schema { + if len(src.Type) > 0 { + dst.Type = src.Type + } + if len(src.Properties) > 0 { + dst.Properties = src.Properties + } + if src.Items != nil { + dst.Items = src.Items + } + if src.AdditionalProperties != nil { + dst.AdditionalProperties = src.AdditionalProperties + } + if len(src.Description) > 0 { + dst.Description = src.Description + } + if src.Nullable { + dst.Nullable = src.Nullable + } + if len(src.Format) > 0 { + dst.Format = src.Format + } + if src.Default != nil { + dst.Default = src.Default + } + if src.Example != nil { + dst.Example = src.Example + } + if len(src.Extensions) > 0 { + dst.Extensions = src.Extensions + } + if src.Maximum != nil { + dst.Maximum = src.Maximum + } + if src.Minimum != nil { + dst.Minimum = src.Minimum + } + if src.ExclusiveMaximum { + dst.ExclusiveMaximum = src.ExclusiveMaximum + } + if src.ExclusiveMinimum { + dst.ExclusiveMinimum = src.ExclusiveMinimum + } + if src.MaxLength != nil { + dst.MaxLength = src.MaxLength + } + if src.MinLength != nil { + dst.MinLength = src.MinLength + } + if len(src.Pattern) > 0 { + dst.Pattern = src.Pattern + } + if src.MaxItems != nil { + dst.MaxItems = src.MaxItems + } + if src.MinItems != nil { + dst.MinItems = src.MinItems + } + if src.UniqueItems { + dst.UniqueItems = src.UniqueItems + } + if src.MultipleOf != nil { + dst.MultipleOf = src.MultipleOf + } + if len(src.Enum) > 0 { + dst.Enum = src.Enum + } + if len(src.Extensions) > 0 { + dst.Extensions = src.Extensions + } + if len(src.ExtraProps) > 0 { + dst.ExtraProps = src.ExtraProps + } + return dst +} diff --git a/pkg/swag/schema_test.go b/pkg/swag/schema_test.go new file mode 100644 index 0000000..6589e2e --- /dev/null +++ b/pkg/swag/schema_test.go @@ -0,0 +1,151 @@ +package swag + +import ( + "testing" + + "github.com/go-openapi/spec" + "github.com/stretchr/testify/assert" +) + +func TestValidDataType(t *testing.T) { + t.Parallel() + + assert.NoError(t, CheckSchemaType(STRING)) + assert.NoError(t, CheckSchemaType(NUMBER)) + assert.NoError(t, CheckSchemaType(INTEGER)) + assert.NoError(t, CheckSchemaType(BOOLEAN)) + assert.NoError(t, CheckSchemaType(ARRAY)) + assert.NoError(t, CheckSchemaType(OBJECT)) + + assert.Error(t, CheckSchemaType("oops")) +} + +func TestTransToValidSchemeType(t *testing.T) { + t.Parallel() + + assert.Equal(t, TransToValidSchemeType("uint"), INTEGER) + assert.Equal(t, TransToValidSchemeType("uint32"), INTEGER) + assert.Equal(t, TransToValidSchemeType("uint64"), INTEGER) + assert.Equal(t, TransToValidSchemeType("float32"), NUMBER) + assert.Equal(t, TransToValidSchemeType("bool"), BOOLEAN) + assert.Equal(t, TransToValidSchemeType("string"), STRING) + + // should accept any type, due to user defined types + other := "oops" + assert.Equal(t, TransToValidSchemeType(other), other) +} + +func TestTransToValidCollectionFormat(t *testing.T) { + t.Parallel() + + assert.Equal(t, TransToValidCollectionFormat("csv"), "csv") + assert.Equal(t, TransToValidCollectionFormat("multi"), "multi") + assert.Equal(t, TransToValidCollectionFormat("pipes"), "pipes") + assert.Equal(t, TransToValidCollectionFormat("tsv"), "tsv") + assert.Equal(t, TransToValidSchemeType("string"), STRING) + + // should accept any type, due to user defined types + assert.Equal(t, TransToValidCollectionFormat("oops"), "") +} + +func TestIsGolangPrimitiveType(t *testing.T) { + t.Parallel() + + assert.Equal(t, IsGolangPrimitiveType("uint"), true) + assert.Equal(t, IsGolangPrimitiveType("int"), true) + assert.Equal(t, IsGolangPrimitiveType("uint8"), true) + assert.Equal(t, IsGolangPrimitiveType("uint16"), true) + assert.Equal(t, IsGolangPrimitiveType("int16"), true) + assert.Equal(t, IsGolangPrimitiveType("byte"), true) + assert.Equal(t, IsGolangPrimitiveType("uint32"), true) + assert.Equal(t, IsGolangPrimitiveType("int32"), true) + assert.Equal(t, IsGolangPrimitiveType("rune"), true) + assert.Equal(t, IsGolangPrimitiveType("uint64"), true) + assert.Equal(t, IsGolangPrimitiveType("int64"), true) + assert.Equal(t, IsGolangPrimitiveType("float32"), true) + assert.Equal(t, IsGolangPrimitiveType("float64"), true) + assert.Equal(t, IsGolangPrimitiveType("bool"), true) + assert.Equal(t, IsGolangPrimitiveType("string"), true) + + assert.Equal(t, IsGolangPrimitiveType("oops"), false) +} + +func TestIsSimplePrimitiveType(t *testing.T) { + t.Parallel() + + assert.Equal(t, IsSimplePrimitiveType("string"), true) + assert.Equal(t, IsSimplePrimitiveType("number"), true) + assert.Equal(t, IsSimplePrimitiveType("integer"), true) + assert.Equal(t, IsSimplePrimitiveType("boolean"), true) + + assert.Equal(t, IsSimplePrimitiveType("oops"), false) +} + +func TestBuildCustomSchema(t *testing.T) { + t.Parallel() + + var ( + schema *spec.Schema + err error + ) + + schema, err = BuildCustomSchema([]string{}) + assert.NoError(t, err) + assert.Nil(t, schema) + + schema, err = BuildCustomSchema([]string{"primitive"}) + assert.Error(t, err) + assert.Nil(t, schema) + + schema, err = BuildCustomSchema([]string{"primitive", "oops"}) + assert.Error(t, err) + assert.Nil(t, schema) + + schema, err = BuildCustomSchema([]string{"primitive", "string"}) + assert.NoError(t, err) + assert.Equal(t, schema.SchemaProps.Type, spec.StringOrArray{"string"}) + + schema, err = BuildCustomSchema([]string{"array"}) + assert.Error(t, err) + assert.Nil(t, schema) + + schema, err = BuildCustomSchema([]string{"array", "oops"}) + assert.Error(t, err) + assert.Nil(t, schema) + + schema, err = BuildCustomSchema([]string{"array", "string"}) + assert.NoError(t, err) + assert.Equal(t, schema.SchemaProps.Type, spec.StringOrArray{"array"}) + assert.Equal(t, schema.SchemaProps.Items.Schema.SchemaProps.Type, spec.StringOrArray{"string"}) + + schema, err = BuildCustomSchema([]string{"object"}) + assert.NoError(t, err) + assert.Equal(t, schema.SchemaProps.Type, spec.StringOrArray{"object"}) + + schema, err = BuildCustomSchema([]string{"object", "oops"}) + assert.Error(t, err) + assert.Nil(t, schema) + + schema, err = BuildCustomSchema([]string{"object", "string"}) + assert.NoError(t, err) + assert.Equal(t, schema.SchemaProps.Type, spec.StringOrArray{"object"}) + assert.Equal(t, schema.SchemaProps.AdditionalProperties.Schema.Type, spec.StringOrArray{"string"}) +} + +func TestIsNumericType(t *testing.T) { + t.Parallel() + + assert.Equal(t, IsNumericType(INTEGER), true) + assert.Equal(t, IsNumericType(NUMBER), true) + + assert.Equal(t, IsNumericType(STRING), false) +} + +func TestIsInterfaceLike(t *testing.T) { + t.Parallel() + + assert.Equal(t, IsInterfaceLike(ERROR), true) + assert.Equal(t, IsInterfaceLike(ANY), true) + + assert.Equal(t, IsInterfaceLike(STRING), false) +} diff --git a/pkg/swag/spec.go b/pkg/swag/spec.go new file mode 100644 index 0000000..c18a365 --- /dev/null +++ b/pkg/swag/spec.go @@ -0,0 +1,64 @@ +package swag + +import ( + "bytes" + "encoding/json" + "strings" + "text/template" +) + +// Spec holds exported Swagger Info so clients can modify it. +type Spec struct { + Version string + Host string + BasePath string + Schemes []string + Title string + Description string + InfoInstanceName string + SwaggerTemplate string + LeftDelim string + RightDelim string +} + +// ReadDoc parses SwaggerTemplate into swagger document. +func (i *Spec) ReadDoc() string { + i.Description = strings.ReplaceAll(i.Description, "\n", "\\n") + + tpl := template.New("swagger_info").Funcs(template.FuncMap{ + "marshal": func(v interface{}) string { + a, _ := json.Marshal(v) + + return string(a) + }, + "escape": func(v interface{}) string { + // escape tabs + var str = strings.ReplaceAll(v.(string), "\t", "\\t") + // replace " with \", and if that results in \\", replace that with \\\" + str = strings.ReplaceAll(str, "\"", "\\\"") + + return strings.ReplaceAll(str, "\\\\\"", "\\\\\\\"") + }, + }) + + if i.LeftDelim != "" && i.RightDelim != "" { + tpl = tpl.Delims(i.LeftDelim, i.RightDelim) + } + + parsed, err := tpl.Parse(i.SwaggerTemplate) + if err != nil { + return i.SwaggerTemplate + } + + var doc bytes.Buffer + if err = parsed.Execute(&doc, i); err != nil { + return i.SwaggerTemplate + } + + return doc.String() +} + +// InstanceName returns Spec instance name. +func (i *Spec) InstanceName() string { + return i.InfoInstanceName +} diff --git a/pkg/swag/spec_test.go b/pkg/swag/spec_test.go new file mode 100644 index 0000000..f20d70c --- /dev/null +++ b/pkg/swag/spec_test.go @@ -0,0 +1,188 @@ +package swag + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSpec_InstanceName(t *testing.T) { + type fields struct { + Version string + Host string + BasePath string + Schemes []string + Title string + Description string + InfoInstanceName string + SwaggerTemplate string + } + + tests := []struct { + name string + fields fields + want string + }{ + { + name: "TestInstanceNameCorrect", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName1", + }, + want: "TestInstanceName1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doc := Spec{ + Version: tt.fields.Version, + Host: tt.fields.Host, + BasePath: tt.fields.BasePath, + Schemes: tt.fields.Schemes, + Title: tt.fields.Title, + Description: tt.fields.Description, + InfoInstanceName: tt.fields.InfoInstanceName, + SwaggerTemplate: tt.fields.SwaggerTemplate, + } + + assert.Equal(t, tt.want, doc.InstanceName()) + }) + } +} + +func TestSpec_ReadDoc(t *testing.T) { + type fields struct { + Version string + Host string + BasePath string + Schemes []string + Title string + Description string + InfoInstanceName string + SwaggerTemplate string + LeftDelim string + RightDelim string + } + + tests := []struct { + name string + fields fields + want string + }{ + { + name: "TestReadDocCorrect", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: `{ + "swagger": "2.0", + "info": { + "description": "{{escape .Description}}", + "title": "{{.Title}}", + "version": "{{.Version}}" + }, + "host": "{{.Host}}", + "basePath": "{{.BasePath}}", + }`, + }, + want: "{" + + "\n\t\t\t\"swagger\": \"2.0\"," + + "\n\t\t\t\"info\": {" + + "\n\t\t\t\t\"description\": \"\",\n\t\t\t\t\"" + + "title\": \"\"," + + "\n\t\t\t\t\"version\": \"1.0\"" + + "\n\t\t\t}," + + "\n\t\t\t\"host\": \"localhost:8080\"," + + "\n\t\t\t\"basePath\": \"/\"," + + "\n\t\t}", + }, + { + name: "TestReadDocMarshalTrigger", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: "{{ marshal .Version }}", + }, + want: "\"1.0\"", + }, + { + name: "TestReadDocParseError", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: "{{ ..Version }}", + }, + want: "{{ ..Version }}", + }, + { + name: "TestReadDocExecuteError", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: "{{ .Schemesa }}", + }, + want: "{{ .Schemesa }}", + }, + { + name: "TestReadDocCustomDelims", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: `{ + "swagger": "2.0", + "info": { + "description": "{%escape .Description%}", + "title": "{%.Title%}", + "version": "{%.Version%}" + }, + "host": "{%.Host%}", + "basePath": "{%.BasePath%}", + }`, + LeftDelim: "{%", + RightDelim: "%}", + }, + want: "{" + + "\n\t\t\t\"swagger\": \"2.0\"," + + "\n\t\t\t\"info\": {" + + "\n\t\t\t\t\"description\": \"\",\n\t\t\t\t\"" + + "title\": \"\"," + + "\n\t\t\t\t\"version\": \"1.0\"" + + "\n\t\t\t}," + + "\n\t\t\t\"host\": \"localhost:8080\"," + + "\n\t\t\t\"basePath\": \"/\"," + + "\n\t\t}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doc := Spec{ + Version: tt.fields.Version, + Host: tt.fields.Host, + BasePath: tt.fields.BasePath, + Schemes: tt.fields.Schemes, + Title: tt.fields.Title, + Description: tt.fields.Description, + InfoInstanceName: tt.fields.InfoInstanceName, + SwaggerTemplate: tt.fields.SwaggerTemplate, + LeftDelim: tt.fields.LeftDelim, + RightDelim: tt.fields.RightDelim, + } + + assert.Equal(t, tt.want, doc.ReadDoc()) + }) + } +} diff --git a/pkg/swag/swagger.go b/pkg/swag/swagger.go new file mode 100644 index 0000000..74c162c --- /dev/null +++ b/pkg/swag/swagger.go @@ -0,0 +1,72 @@ +package swag + +import ( + "errors" + "fmt" + "sync" +) + +// Name is a unique name be used to register swag instance. +const Name = "swagger" + +var ( + swaggerMu sync.RWMutex + swags map[string]Swagger +) + +// Swagger is an interface to read swagger document. +type Swagger interface { + ReadDoc() string +} + +// Register registers swagger for given name. +func Register(name string, swagger Swagger) { + swaggerMu.Lock() + defer swaggerMu.Unlock() + + if swagger == nil { + panic("swagger is nil") + } + + if swags == nil { + swags = make(map[string]Swagger) + } + + if _, ok := swags[name]; ok { + panic("Register called twice for swag: " + name) + } + + swags[name] = swagger +} + +// GetSwagger returns the swagger instance for given name. +// If not found, returns nil. +func GetSwagger(name string) Swagger { + swaggerMu.RLock() + defer swaggerMu.RUnlock() + + return swags[name] +} + +// ReadDoc reads swagger document. An optional name parameter can be passed to read a specific document. +// The default name is "swagger". +func ReadDoc(optionalName ...string) (string, error) { + swaggerMu.RLock() + defer swaggerMu.RUnlock() + + if swags == nil { + return "", errors.New("no swag has yet been registered") + } + + name := Name + if len(optionalName) != 0 && optionalName[0] != "" { + name = optionalName[0] + } + + swag, ok := swags[name] + if !ok { + return "", fmt.Errorf("no swag named \"%s\" was registered", name) + } + + return swag.ReadDoc(), nil +} diff --git a/pkg/swag/swagger_test.go b/pkg/swag/swagger_test.go new file mode 100644 index 0000000..0431905 --- /dev/null +++ b/pkg/swag/swagger_test.go @@ -0,0 +1,232 @@ +package swag + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var doc = `{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server.", + "title": "Swagger Example API", + "termsOfService": "http://swagger.io/terms/", + "contact": { + "name": "API Support", + "url": "http://www.swagger.io/support", + "email": "support@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + }, + "version": "1.0" + }, + "host": "petstore.swagger.io", + "basePath": "/v2", + "paths": { + "/testapi/get-string-by-int/{some_id}": { + "get": { + "description": "get string by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "Add a new pet to the store", + "parameters": [ + { + "description": "Some ID", + "name": "some_id", + "in": "path", + "required": true, + "schema": { + "type": "int" + } + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "type": "object", + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "type": "object", + "$ref": "#/definitions/web.APIError" + } + } + } + } + }, + "/testapi/get-struct-array-by-string/{some_id}": { + "get": { + "description": "get struct array by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "description": "Some ID", + "name": "some_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "description": "Offset", + "name": "offset", + "in": "query", + "required": true, + "schema": { + "type": "int" + } + }, + { + "description": "Offset", + "name": "limit", + "in": "query", + "required": true, + "schema": { + "type": "int" + } + } + ], + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "We need ID!!", + "schema": { + "type": "object", + "$ref": "#/definitions/web.APIError" + } + }, + "404": { + "description": "Can not find ID", + "schema": { + "type": "object", + "$ref": "#/definitions/web.APIError" + } + } + } + } + } + }, + "definitions": { + "web.APIError": { + "type": "object", + "properties": { + "ErrorCode": { + "type": "int" + }, + "ErrorMessage": { + "type": "string" + } + } + } + }, + "securityDefinitions": { + "ApiKey": { + "description: "some", + "type": "apiKey", + "name": "X-API-KEY", + "in": "header" + } + } +}` + +type s struct{} + +func (s *s) ReadDoc() string { + return doc +} + +func TestRegister(t *testing.T) { + setup() + Register(Name, &s{}) + d, _ := ReadDoc() + assert.Equal(t, doc, d) +} + +func TestRegisterByName(t *testing.T) { + setup() + Register("another_name", &s{}) + d, _ := ReadDoc("another_name") + assert.Equal(t, doc, d) +} + +func TestRegisterMultiple(t *testing.T) { + setup() + Register(Name, &s{}) + Register("another_name", &s{}) + d1, _ := ReadDoc(Name) + d2, _ := ReadDoc("another_name") + assert.Equal(t, doc, d1) + assert.Equal(t, doc, d2) +} + +func TestReadDocBeforeRegistered(t *testing.T) { + setup() + _, err := ReadDoc() + assert.Error(t, err) +} + +func TestReadDocWithInvalidName(t *testing.T) { + setup() + Register(Name, &s{}) + _, err := ReadDoc("invalid") + assert.Error(t, err) +} + +func TestNilRegister(t *testing.T) { + setup() + var swagger Swagger + assert.Panics(t, func() { + Register(Name, swagger) + }) +} + +func TestCalledTwicelRegister(t *testing.T) { + setup() + assert.Panics(t, func() { + Register(Name, &s{}) + Register(Name, &s{}) + }) +} + +func setup() { + swags = nil +} + +func TestGetSwagger(t *testing.T) { + setup() + instance := &s{} + Register(Name, instance) + swagger := GetSwagger(Name) + assert.Equal(t, instance, swagger) + + swagger = GetSwagger("invalid") + assert.Nil(t, swagger) +} diff --git a/pkg/swag/types.go b/pkg/swag/types.go new file mode 100644 index 0000000..5f3031e --- /dev/null +++ b/pkg/swag/types.go @@ -0,0 +1,123 @@ +package swag + +import ( + "go/ast" + "go/token" + "regexp" + "strings" + + "github.com/go-openapi/spec" +) + +// Schema parsed schema. +type Schema struct { + *spec.Schema // + PkgPath string // package import path used to rename Name of a definition int case of conflict + Name string // Name in definitions +} + +// TypeSpecDef the whole information of a typeSpec. +type TypeSpecDef struct { + // ast file where TypeSpec is + File *ast.File + + // the TypeSpec of this type definition + TypeSpec *ast.TypeSpec + + Enums []EnumValue + + // path of package starting from under ${GOPATH}/src or from module path in go.mod + PkgPath string + ParentSpec ast.Decl + + SchemaName string + + NotUnique bool +} + +// Name the name of the typeSpec. +func (t *TypeSpecDef) Name() string { + if t.TypeSpec != nil && t.TypeSpec.Name != nil { + return t.TypeSpec.Name.Name + } + + return "" +} + +// TypeName the type name of the typeSpec. +func (t *TypeSpecDef) TypeName() string { + if ignoreNameOverride(t.TypeSpec.Name.Name) { + return t.TypeSpec.Name.Name[1:] + } + + var names []string + if t.NotUnique { + pkgPath := strings.Map(func(r rune) rune { + if r == '\\' || r == '/' || r == '.' { + return '_' + } + return r + }, t.PkgPath) + names = append(names, pkgPath) + } else if t.File != nil { + names = append(names, t.File.Name.Name) + } + if parentFun, ok := (t.ParentSpec).(*ast.FuncDecl); ok && parentFun != nil { + names = append(names, parentFun.Name.Name) + } + names = append(names, t.TypeSpec.Name.Name) + return fullTypeName(names...) +} + +// FullPath return the full path of the typeSpec. +func (t *TypeSpecDef) FullPath() string { + return t.PkgPath + "." + t.Name() +} + +const regexCaseInsensitive = "(?i)" + +var reTypeName = regexp.MustCompile(regexCaseInsensitive + `^@name\s+(\S+)`) + +func (t *TypeSpecDef) Alias() string { + if t.TypeSpec.Comment == nil { + return "" + } + + // get alias from comment '// @name ' + for _, comment := range t.TypeSpec.Comment.List { + trimmedComment := strings.TrimSpace(strings.TrimLeft(comment.Text, "/")) + texts := reTypeName.FindStringSubmatch(trimmedComment) + if len(texts) > 1 { + return texts[1] + } + } + + return "" +} + +func (t *TypeSpecDef) SetSchemaName() { + if alias := t.Alias(); alias != "" { + t.SchemaName = alias + return + } + + t.SchemaName = t.TypeName() +} + +// AstFileInfo information of an ast.File. +type AstFileInfo struct { + //FileSet the FileSet object which is used to parse this go source file + FileSet *token.FileSet + + // File ast.File + File *ast.File + + // Path the path of the ast.File + Path string + + // PackagePath package import path of the ast.File + PackagePath string + + // ParseFlag determine what to parse + ParseFlag ParseFlag +} diff --git a/pkg/swag/utils.go b/pkg/swag/utils.go new file mode 100644 index 0000000..df31ff2 --- /dev/null +++ b/pkg/swag/utils.go @@ -0,0 +1,55 @@ +package swag + +import "unicode" + +// FieldsFunc split a string s by a func splitter into max n parts +func FieldsFunc(s string, f func(rune2 rune) bool, n int) []string { + // A span is used to record a slice of s of the form s[start:end]. + // The start index is inclusive and the end index is exclusive. + type span struct { + start int + end int + } + spans := make([]span, 0, 32) + + // Find the field start and end indices. + // Doing this in a separate pass (rather than slicing the string s + // and collecting the result substrings right away) is significantly + // more efficient, possibly due to cache effects. + start := -1 // valid span start if >= 0 + for end, rune := range s { + if f(rune) { + if start >= 0 { + spans = append(spans, span{start, end}) + // Set start to a negative value. + // Note: using -1 here consistently and reproducibly + // slows down this code by a several percent on amd64. + start = ^start + } + } else { + if start < 0 { + start = end + if n > 0 && len(spans)+1 >= n { + break + } + } + } + } + + // Last field might end at EOF. + if start >= 0 { + spans = append(spans, span{start, len(s)}) + } + + // Create strings from recorded field indices. + a := make([]string, len(spans)) + for i, span := range spans { + a[i] = s[span.start:span.end] + } + return a +} + +// FieldsByAnySpace split a string s by any space character into max n parts +func FieldsByAnySpace(s string, n int) []string { + return FieldsFunc(s, unicode.IsSpace, n) +} diff --git a/pkg/swag/utils_test.go b/pkg/swag/utils_test.go new file mode 100644 index 0000000..1c4d995 --- /dev/null +++ b/pkg/swag/utils_test.go @@ -0,0 +1,38 @@ +package swag + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFieldsByAnySpace(t *testing.T) { + type args struct { + s string + n int + } + tests := []struct { + name string + args args + want []string + }{ + {"test1", + args{ + " aa bb cc dd ff", + 2, + }, + []string{"aa", "bb\tcc dd \t\tff"}, + }, + {"test2", + args{ + ` aa "bb cc dd ff"`, + 2, + }, + []string{"aa", `"bb cc dd ff"`}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, FieldsByAnySpace(tt.args.s, tt.args.n), "FieldsByAnySpace(%v, %v)", tt.args.s, tt.args.n) + }) + } +} diff --git a/pkg/swag/version.go b/pkg/swag/version.go new file mode 100644 index 0000000..cf52320 --- /dev/null +++ b/pkg/swag/version.go @@ -0,0 +1,4 @@ +package swag + +// Version of swag. +const Version = "v1.16.4" diff --git a/templates/module/controller.go.tpl b/templates/module/controller.go.tpl index 39512d2..b00115e 100644 --- a/templates/module/controller.go.tpl +++ b/templates/module/controller.go.tpl @@ -27,4 +27,4 @@ func (c *Controller) Prepare() error { // @Param pageFilter query common.PageQueryFilter true "PageQueryFilter" // @Param sortFilter query common.SortQueryFilter true "SortQueryFilter" // @Success 200 {object} common.PageDataResponse{list=dto.AlarmItem} -// @Router /v1/test//{id} [get] \ No newline at end of file +// @Router /v1/test/:id [get] \ No newline at end of file diff --git a/templates/project/docs/ember.go.tpl b/templates/project/docs/ember.go.tpl index 8404904..101b3a8 100755 --- a/templates/project/docs/ember.go.tpl +++ b/templates/project/docs/ember.go.tpl @@ -3,7 +3,7 @@ package docs import ( _ "embed" - _ "github.com/swaggo/swag" + _ "git.ipao.vip/rogeecn/atomctl/pkg/swag" ) //go:embed swagger.json diff --git a/templates/project/providers/http/swagger/swagger.go.tpl b/templates/project/providers/http/swagger/swagger.go.tpl index 9372b03..2eaf9ed 100644 --- a/templates/project/providers/http/swagger/swagger.go.tpl +++ b/templates/project/providers/http/swagger/swagger.go.tpl @@ -7,11 +7,11 @@ import ( "strings" "sync" + "git.ipao.vip/rogeecn/atomctl/pkg/swag" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/static" "github.com/gofiber/utils/v2" swaggerFiles "github.com/swaggo/files/v2" - "github.com/swaggo/swag" ) const (