diff --git a/.vscode/launch.json b/.vscode/launch.json index bb9aa51..6d2140a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,8 +12,8 @@ "program": "${workspaceFolder}", "args": [ "gen", - "provider", - "/projects/mp-qvyun/backend", + "route", + "/projects/atom_starter", ] } ] diff --git a/cmd/gen.go b/cmd/gen.go index 633a652..06184ce 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -11,6 +11,7 @@ func CommandGen(root *cobra.Command) { cmds := []func(*cobra.Command){ CommandGenProvider, + CommandGenRoute, CommandGenModel, CommandGenEnum, } diff --git a/cmd/gen_route.go b/cmd/gen_route.go index 03816fa..0f81a4f 100644 --- a/cmd/gen_route.go +++ b/cmd/gen_route.go @@ -1,21 +1,15 @@ package cmd import ( - "fmt" - "go/ast" - "go/parser" - "go/token" "io/fs" - "log" "os" "path/filepath" - "regexp" "strings" + "git.ipao.vip/rogeecn/atomctl/pkg/ast/route" "git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod" - "github.com/iancoleman/strcase" - "github.com/pkg/errors" "github.com/samber/lo" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -29,6 +23,7 @@ func CommandGenRoute(root *cobra.Command) { root.AddCommand(cmd) } +// https://github.com/swaggo/swag?tab=readme-ov-file#api-operation func commandGenRouteE(cmd *cobra.Command, args []string) error { var err error var path string @@ -48,354 +43,37 @@ func commandGenRouteE(cmd *cobra.Command, args []string) error { return err } - routes := []RouteDefinition{} + routes := []route.RouteDefinition{} modulePath := filepath.Join(path, "modules") if _, err := os.Stat(modulePath); os.IsNotExist(err) { log.Fatal("modules dir not exist, ", modulePath) } - filepath.WalkDir(modulePath, func(path string, d fs.DirEntry, err error) error { + err = filepath.WalkDir(modulePath, func(path string, d fs.DirEntry, err error) error { if d.IsDir() { return nil } - if !strings.HasSuffix(path, "_controller.go") { - return nil - } - if !strings.HasSuffix(path, ".go") { + if !strings.HasSuffix(path, "controller.go") { return nil } - routes = append(routes, astParseRoutes(path)...) + routes = append(routes, route.ParseFile(path)...) return nil }) + if err != nil { + return err + } - routeGroups := lo.GroupBy(routes, func(item RouteDefinition) string { + routeGroups := lo.GroupBy(routes, func(item route.RouteDefinition) string { return filepath.Dir(item.Path) }) - for key, routes := range routeGroups { - routePath := filepath.Join(key, "routes.gen.go") - - imports := lo.FlatMap(routes, func(item RouteDefinition, index int) []string { - return lo.Map(item.Imports, func(item string, index int) string { - return fmt.Sprintf("\t%s", item) - }) - }) - - pkgName := filepath.Base(key) - - f, err := os.OpenFile(routePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) - if err != nil { - fmt.Printf("ERR: %s %s\n", err, routePath) - continue - } - defer f.Close() - - _, _ = f.WriteString("// Code generated by the atomctl; DO NOT EDIT.\n\n") - _, _ = f.WriteString("package " + pkgName + "\n\n") - _, _ = f.WriteString("import (\n") - _, _ = f.WriteString(strings.Join(imports, "\n")) - _, _ = f.WriteString("\n") - _, _ = f.WriteString("\n") - - _, _ = f.WriteString("\t\"github.com/atom-providers/log\"\n") - _, _ = f.WriteString("\t\"git.ipao.vip/rogeecn/atom/contracts\"\n") - _, _ = f.WriteString("\t\"github.com/gofiber/fiber/v3\"\n") - _, _ = f.WriteString("\t. \"github.com/pkg/f\"\n") - _, _ = f.WriteString(")\n\n") - - _, _ = f.WriteString("// @provider contracts.HttpRoute atom.GroupRoutes\n") - _, _ = f.WriteString("type Routes struct {\n") - _, _ = f.WriteString("\tengine\t*fiber.App `inject:\"false\"`\n") - _, _ = f.WriteString("\tsvc\tcontracts.HttpService\n") - // inject controllers - for _, route := range routes { - _, _ = f.WriteString(fmt.Sprintf("\t%s *%s", strcase.ToLowerCamel(route.Name), route.Name)) - } - _, _ = f.WriteString("\n}\n\n") - - _, _ = f.WriteString("func (r *Routes) Prepare() error {\n") - _, _ = f.WriteString("\tlog.Infof(\"register route group: " + pkgName + "\")\n") - _, _ = f.WriteString("\tr.engine = r.svc.GetEngine().(*fiber.App)\n") - - for _, route := range routes { - funcName := fmt.Sprintf(`register%sRoutes()`, route.Name) - _, _ = f.WriteString(fmt.Sprintf("\tr.%s\n", funcName)) - } - _, _ = f.WriteString("\treturn nil\n") - _, _ = f.WriteString("}\n\n") - - for _, route := range routes { - funcName := fmt.Sprintf(`func (r *Routes)register%sRoutes()`, route.Name) - _, _ = f.WriteString(fmt.Sprintf(`%s{`, funcName)) - _, _ = f.WriteString("\n") - - for _, action := range route.Actions { - _, _ = f.WriteString("\t") - _, _ = f.WriteString(fmt.Sprintf("r.engine.%s(%q, ", strcase.ToCamel(strings.ToLower(action.Method)), formatRoute(action.Route))) - - if action.HasData { - _, _ = f.WriteString("Data") - } - _, _ = f.WriteString("Func") - - paramsSize := len(action.Params) - if paramsSize >= 1 { - _, _ = f.WriteString(fmt.Sprintf("%d", paramsSize)) - } - - paramsStrings := []string{fmt.Sprintf("r.%s.%s", strcase.ToLowerCamel(route.Name), action.Name)} - for _, p := range action.Params { - var paramString string - switch p.Position { - case PositionPath: - paramString = fmt.Sprintf(`Path[%s](%q)`, p.Type, p.Name) - case PositionURI: - paramString = fmt.Sprintf(`URI[%s]()`, p.Type) - case PositionQuery: - paramString = fmt.Sprintf(`Query[%s]()`, p.Type) - case PositionBody: - paramString = fmt.Sprintf(`Body[%s]()`, p.Type) - case PositionHeader: - paramString = fmt.Sprintf(`Header[%s]()`, p.Type) - case PositionCookie: - paramString = fmt.Sprintf(`Cookie[%s]()`, p.Type) - } - - paramsStrings = append(paramsStrings, paramString) - } - _, _ = f.WriteString("(" + strings.Join(paramsStrings, ", ") + ")") - - _, _ = f.WriteString(")\n") - } - - _, _ = f.WriteString("}\n") + for path, routes := range routeGroups { + if err := route.Render(path, routes); err != nil { + log.WithError(err).WithField("path", path).Error("render route failed") } } return nil } - -type RouteDefinition struct { - Path string - Name string - Imports []string - Actions []ActionDefinition -} - -type ParamDefinition struct { - Name string - Type string - Position Position -} - -type ActionDefinition struct { - Route string - Method string - Name string - HasData bool - Params []ParamDefinition -} -type Position string - -const ( - PositionPath Position = "Path" - PositionURI Position = "Uri" - PositionQuery Position = "Query" - PositionBody Position = "Body" - PositionHeader Position = "Header" - PositionCookie Position = "Cookie" -) - -func astParseRoutes(source string) []RouteDefinition { - if strings.HasSuffix(source, "_test.go") { - return []RouteDefinition{} - } - - if strings.HasSuffix(source, "/provider.go") { - return []RouteDefinition{} - } - - fset := token.NewFileSet() - node, err := parser.ParseFile(fset, source, nil, parser.ParseComments) - if err != nil { - log.Println("ERR: ", err) - return nil - } - - imports := make(map[string]string) - for _, imp := range node.Imports { - paths := strings.Split(strings.Trim(imp.Path.Value, "\""), "/") - name := paths[len(paths)-1] - pkg := imp.Path.Value - if imp.Name != nil { - name = imp.Name.Name - pkg = fmt.Sprintf("%s %q", name, imp.Path.Value) - } - imports[name] = pkg - } - - routes := make(map[string]RouteDefinition) - actions := make(map[string][]ActionDefinition) - usedImports := make(map[string][]string) - - // 再去遍历 struct 的方法去 - for _, decl := range node.Decls { - decl, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - - // 普通方法不要 - if decl.Recv == nil { - continue - } - - recvType := decl.Recv.List[0].Type.(*ast.StarExpr).X.(*ast.Ident).Name - if _, ok := routes[recvType]; !ok { - routes[recvType] = RouteDefinition{ - Name: recvType, - Path: source, - Actions: []ActionDefinition{}, - } - actions[recvType] = []ActionDefinition{} - } - - // Doc 中把 @Router 的定义拿出来, Route 格式为 /user/:id [get] 两部分,表示路径和请求方法 - var path, method string - var err error - if decl.Doc != nil { - for _, line := range decl.Doc.List { - lineText := strings.TrimSpace(line.Text) - lineText = strings.TrimLeft(lineText, "/ \t") - if !strings.HasPrefix(lineText, "@Route") { - continue - } - - path, method, err = parseRouteComment(lineText) - if err != nil { - log.Fatal(errors.Wrapf(err, "file: %s, action: %s", source, decl.Name.Name)) - } - break - } - } - if path == "" || method == "" { - log.Printf("[WARN] failed to get router ,file: %s, action: %s", source, decl.Name.Name) - continue - } - - // 拿参数列表去, 忽略 context.Context 参数 - params := []ParamDefinition{} - for _, param := range decl.Type.Params.List { - // paramsType, ok := param.Type.(*ast.SelectorExpr) - - var typ string - switch param.Type.(type) { - case *ast.Ident: - typ = param.Type.(*ast.Ident).Name - case *ast.StarExpr: - paramsType := param.Type.(*ast.StarExpr) - switch paramsType.X.(type) { - case *ast.SelectorExpr: - X := paramsType.X.(*ast.SelectorExpr) - typ = fmt.Sprintf("*%s.%s", X.X.(*ast.Ident).Name, X.Sel.Name) - default: - typ = fmt.Sprintf("*%s", paramsType.X.(*ast.Ident).Name) - } - case *ast.SelectorExpr: - typ = fmt.Sprintf("%s.%s", param.Type.(*ast.SelectorExpr).X.(*ast.Ident).Name, param.Type.(*ast.SelectorExpr).Sel.Name) - } - - if strings.HasSuffix(typ, "Context") || strings.HasSuffix(typ, "Ctx") { - continue - } - pkgName := strings.Split(strings.Trim(typ, "*"), ".") - if len(pkgName) == 2 { - usedImports[recvType] = append(usedImports[recvType], imports[pkgName[0]]) - } - - position := PositionPath - if strings.HasSuffix(typ, string(PositionQuery)) || strings.HasSuffix(typ, "QueryFilter") { - position = PositionQuery - } - if strings.HasSuffix(typ, string(PositionBody)) || strings.HasSuffix(typ, "Form") { - position = PositionBody - } - if strings.HasSuffix(typ, string(PositionHeader)) { - position = PositionHeader - } - typ = strings.TrimPrefix(typ, "*") - - for _, name := range param.Names { - params = append(params, ParamDefinition{ - Name: name.Name, - Type: typ, - Position: position, - }) - } - } - - actions[recvType] = append(actions[recvType], ActionDefinition{ - Route: path, - Method: strings.ToUpper(method), - Name: decl.Name.Name, - HasData: len(decl.Type.Results.List) > 1, - Params: params, - }) - } - - var items []RouteDefinition - for k, item := range routes { - a, ok := actions[k] - if !ok { - continue - } - item.Actions = a - item.Imports = []string{} - if im, ok := usedImports[k]; ok { - item.Imports = lo.Uniq(im) - } - items = append(items, item) - } - return items -} - -func parseRouteComment(line string) (string, string, error) { - pattern := regexp.MustCompile(`(?mi)@router\s+(.*?)\s+\[(.*?)\]`) - submatch := pattern.FindStringSubmatch(line) - - if len(submatch) != 3 { - return "", "", errors.New("invalid route definition") - } - - return submatch[1], submatch[2], nil -} - -func getPackageRoute(mod, path string) string { - paths := strings.SplitN(path, "modules", 2) - pkg := paths[1] - // path可能值为 - // /test/user_controller.go - // /test/modules/user_controller.go - - return strings.TrimLeft(filepath.Dir(pkg), "/") -} - -func formatRoute(route string) string { - pattern := regexp.MustCompile(`(?mi)\{(.*?)\}`) - if !pattern.MatchString(route) { - return route - } - - items := pattern.FindAllStringSubmatch(route, -1) - for _, item := range items { - param := strcase.ToLowerCamel(item[1]) - route = strings.ReplaceAll(route, item[0], fmt.Sprintf("{%s}", param)) - } - - route = pattern.ReplaceAllString(route, ":$1") - route = strings.ReplaceAll(route, "/:id", "/:id") - route = strings.ReplaceAll(route, "Id/", "Id/") - return route -} diff --git a/go.mod b/go.mod index 2652632..6da4fec 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.23.2 require ( github.com/Masterminds/sprig/v3 v3.3.0 github.com/bradleyjkemp/cupaloy/v2 v2.8.0 + github.com/ettle/strcase v0.2.0 github.com/go-jet/jet/v2 v2.12.0 + github.com/iancoleman/strcase v0.3.0 github.com/lib/pq v1.10.9 github.com/pkg/errors v0.9.1 github.com/pressly/goose/v3 v3.23.1 diff --git a/go.sum b/go.sum index 2d3bc48..d5f3fa2 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/ettle/strcase v0.2.0 h1:fGNiVF21fHXpX1niBgk0aROov1LagYsOwV/xqKDKR/Q= +github.com/ettle/strcase v0.2.0/go.mod h1:DajmHElDSaX76ITe3/VHVyMin4LWSJN5Z909Wp+ED1A= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -45,6 +47,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= +github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..cd98cb9 --- /dev/null +++ b/main_test.go @@ -0,0 +1,64 @@ +package main + +import ( + "strings" + "testing" +) + +type ParamDefinition struct { + Name string + Type string + Key string + Table string + Model string + Position string +} + +func parseBind(bind string) ParamDefinition { + var param ParamDefinition + parts := strings.FieldsFunc(bind, func(r rune) bool { + return r == ' ' || r == '(' || r == ')' + }) + + // 过滤掉空的元素 + var newParts []string + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + newParts = append(newParts, part) + } + } + + for i, part := range parts { + switch part { + case "@Bind": + param.Name = parts[i+1] + param.Position = parts[i+2] + case "key": + param.Key = parts[i+1] + case "table": + param.Table = parts[i+1] + case "model": + param.Model = parts[i+1] + } + } + return param +} + +func Test_T(t *testing.T) { + // @Bind [Name] [Type] [Key] [Table] [Model] + suites := []string{ + `@Bind name query key("a") table(b) model("c")`, + `@Bind id uri key(a)`, + `@Bind id uri table(b)`, + `@Bind id uri key(b) model(c)`, + `@Bind id uri key(b) table(c)`, + `@Bind id uri table(b) key(c)`, + `@Bind id uri`, + } + + for _, suite := range suites { + param := parseBind(suite) + t.Logf("Parsed Param: %+v", param) + } +} diff --git a/pkg/ast/route/render.go b/pkg/ast/route/render.go new file mode 100644 index 0000000..cdea9ae --- /dev/null +++ b/pkg/ast/route/render.go @@ -0,0 +1,110 @@ +package route + +import ( + "bytes" + _ "embed" + "fmt" + "os" + "path/filepath" + "text/template" + + "git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod" + "github.com/Masterminds/sprig/v3" + "github.com/iancoleman/strcase" + "github.com/samber/lo" +) + +//go:embed router.go.tpl +var routeTpl string + +type RenderData struct { + PackageName string + ProjectPackage string + Imports []string + Controllers []string + Routes map[string][]Router +} + +type Router struct { + Method string + Route string + Controller string + Action string + Func string + Params []string +} + +func Render(path string, routes []RouteDefinition) error { + routePath := filepath.Join(path, "routes.gen.go") + + tmpl, err := template.New("route").Funcs(sprig.FuncMap()).Parse(routeTpl) + if err != nil { + return err + } + + renderData := RenderData{ + PackageName: filepath.Base(path), + ProjectPackage: gomod.GetModuleName(), + Routes: make(map[string][]Router), + } + + // collect imports + imports := []string{} + controllers := []string{} + for _, route := range routes { + imports = append(imports, route.Imports...) + controllers = append(controllers, fmt.Sprintf("%s *%s", strcase.ToLowerCamel(route.Name), route.Name)) + for _, action := range route.Actions { + funcName := fmt.Sprintf("Func%d", len(action.Params)) + if action.HasData { + funcName = "Data" + funcName + } + + renderData.Routes[route.Name] = append(renderData.Routes[route.Name], Router{ + Method: strcase.ToCamel(action.Method), + Route: action.Route, + Controller: strcase.ToLowerCamel(route.Name), + Action: action.Name, + Func: funcName, + Params: lo.FilterMap(action.Params, func(item ParamDefinition, _ int) (string, bool) { + switch item.Position { + case PositionURI: + return fmt.Sprintf(`URI[%s]("%s")`, item.Type, item.Name), true + case PositionQuery: + return fmt.Sprintf(`Query[%s]("%s")`, item.Type, item.Name), true + case PositionHeader: + return fmt.Sprintf(`Header[%s]("%s")`, item.Type, item.Name), true + case PositionCookie: + return fmt.Sprintf(`Cookie[%s]("%s")`, item.Type, item.Name), true + case PositionBody: + return fmt.Sprintf(`Body[%s]("%s")`, item.Type, item.Name), true + case PositionPath: + return fmt.Sprintf(`Path[%s]("%s")`, item.Type, item.Name), true + } + return "", false + }), + }) + } + } + + renderData.Imports = lo.Uniq(imports) + renderData.Controllers = lo.Uniq(controllers) + + var buf bytes.Buffer + err = tmpl.Execute(&buf, renderData) + if err != nil { + return err + } + + f, err := os.OpenFile(routePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return err + } + defer f.Close() + + _, err = f.Write(buf.Bytes()) + if err != nil { + return err + } + return nil +} diff --git a/pkg/ast/route/route.go b/pkg/ast/route/route.go new file mode 100644 index 0000000..10be40b --- /dev/null +++ b/pkg/ast/route/route.go @@ -0,0 +1,285 @@ +package route + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "path/filepath" + "regexp" + "strings" + + "git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod" + "github.com/iancoleman/strcase" + "github.com/pkg/errors" + "github.com/samber/lo" + log "github.com/sirupsen/logrus" +) + +type RouteDefinition struct { + Path string + Name string + Imports []string + Actions []ActionDefinition +} + +type ActionDefinition struct { + Route string + Method string + Name string + HasData bool + Params []ParamDefinition +} + +type ParamDefinition struct { + Name string + Type string + Key string + Table string + Model string + Position Position +} + +type Position string + +func positionFromString(v string) Position { + switch v { + case "path": + return PositionPath + case "uri": + return PositionURI + case "query": + return PositionQuery + case "body": + return PositionBody + case "header": + return PositionHeader + case "cookie": + return PositionCookie + } + panic("invalid position: " + v) +} + +const ( + PositionPath Position = "path" + PositionURI Position = "uri" + PositionQuery Position = "query" + PositionBody Position = "body" + PositionHeader Position = "header" + PositionCookie Position = "cookie" +) + +func ParseFile(file string) []RouteDefinition { + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, file, nil, parser.ParseComments) + if err != nil { + log.Println("ERR: ", err) + return nil + } + + imports := make(map[string]string) + for _, imp := range node.Imports { + pkg := strings.Trim(imp.Path.Value, "\"") + name := gomod.GetPackageModuleName(pkg) + if imp.Name != nil { + name = imp.Name.Name + pkg = fmt.Sprintf(`%s %q`, name, pkg) + imports[name] = pkg + continue + } + imports[name] = fmt.Sprintf("%q", pkg) + } + + routes := make(map[string]RouteDefinition) + actions := make(map[string][]ActionDefinition) + usedImports := make(map[string][]string) + + // 再去遍历 struct 的方法去 + for _, decl := range node.Decls { + decl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + + // 普通方法不要 + if decl.Recv == nil { + continue + } + + // 没有Doc不要 + if decl.Doc == nil { + continue + } + + recvType := decl.Recv.List[0].Type.(*ast.StarExpr).X.(*ast.Ident).Name + if _, ok := routes[recvType]; !ok { + routes[recvType] = RouteDefinition{ + Name: recvType, + Path: file, + Actions: []ActionDefinition{}, + } + actions[recvType] = []ActionDefinition{} + } + + bindParams := []ParamDefinition{} + + // Doc 中把 @Router 的定义拿出来, Route 格式为 /user/:id [get] 两部分,表示路径和请求方法 + var path, method string + var err error + for _, l := range decl.Doc.List { + line := strings.TrimLeft(l.Text, "/ \t") + line = strings.TrimSpace(line) + + // 路由需要一些切换 + if strings.HasPrefix(line, "@Router") { + path, method, err = parseRouteComment(line) + if err != nil { + log.Fatal(errors.Wrapf(err, "file: %s, action: %s", file, decl.Name.Name)) + } + } + + if strings.HasPrefix(line, "@Bind") { + //@Bind name query key() table() model() + //@Bind name query + bindParams = append(bindParams, parseRouteBind(line)) + } + } + + if path == "" || method == "" { + continue + } + log.WithField("file", file).WithField("action", decl.Name.Name).WithField("path", path).WithField("method", method).Info("get router") + + // 拿参数列表去, 忽略 context.Context 参数 + for _, param := range decl.Type.Params.List { + // paramsType, ok := param.Type.(*ast.SelectorExpr) + + var typ string + switch param.Type.(type) { + case *ast.Ident: + typ = param.Type.(*ast.Ident).Name + case *ast.StarExpr: + paramsType := param.Type.(*ast.StarExpr) + switch paramsType.X.(type) { + case *ast.SelectorExpr: + X := paramsType.X.(*ast.SelectorExpr) + typ = fmt.Sprintf("*%s.%s", X.X.(*ast.Ident).Name, X.Sel.Name) + default: + typ = fmt.Sprintf("*%s", paramsType.X.(*ast.Ident).Name) + } + case *ast.SelectorExpr: + typ = fmt.Sprintf("%s.%s", param.Type.(*ast.SelectorExpr).X.(*ast.Ident).Name, param.Type.(*ast.SelectorExpr).Sel.Name) + } + + if strings.HasSuffix(typ, "Context") || strings.HasSuffix(typ, "Ctx") { + continue + } + pkgName := strings.Split(strings.Trim(typ, "*"), ".") + if len(pkgName) == 2 { + usedImports[recvType] = append(usedImports[recvType], imports[pkgName[0]]) + } + + typ = strings.TrimPrefix(typ, "*") + + for _, name := range param.Names { + for i, bindParam := range bindParams { + if bindParam.Name == name.Name { + bindParams[i].Type = typ + break + } + } + } + } + + actions[recvType] = append(actions[recvType], ActionDefinition{ + Route: path, + Method: strings.ToUpper(method), + Name: decl.Name.Name, + HasData: len(decl.Type.Results.List) > 1, + Params: bindParams, + }) + } + + var items []RouteDefinition + for k, item := range routes { + a, ok := actions[k] + if !ok { + continue + } + item.Actions = a + item.Imports = []string{} + if im, ok := usedImports[k]; ok { + item.Imports = lo.Uniq(im) + } + items = append(items, item) + } + return items +} + +func parseRouteComment(line string) (string, string, error) { + parts := strings.FieldsFunc(line, func(r rune) bool { + return r == ' ' || r == '\t' || r == '[' || r == ']' + }) + parts = lo.Filter(parts, func(item string, idx int) bool { + return item != "" + }) + + if len(parts) != 3 { + return "", "", errors.New("invalid route definition") + } + + return parts[1], parts[2], nil +} + +func getPackageRoute(mod, path string) string { + paths := strings.SplitN(path, "modules", 2) + pkg := paths[1] + // path可能值为 + // /test/user_controller.go + // /test/modules/user_controller.go + + return strings.TrimLeft(filepath.Dir(pkg), "/") +} + +func formatRoute(route string) string { + pattern := regexp.MustCompile(`(?mi)\{(.*?)\}`) + if !pattern.MatchString(route) { + return route + } + + items := pattern.FindAllStringSubmatch(route, -1) + for _, item := range items { + param := strcase.ToLowerCamel(item[1]) + route = strings.ReplaceAll(route, item[0], fmt.Sprintf("{%s}", param)) + } + + route = pattern.ReplaceAllString(route, ":$1") + route = strings.ReplaceAll(route, "/:id", "/:id") + route = strings.ReplaceAll(route, "Id/", "Id/") + return route +} + +func parseRouteBind(bind string) ParamDefinition { + var param ParamDefinition + parts := strings.FieldsFunc(bind, func(r rune) bool { + return r == ' ' || r == '(' || r == ')' || r == '\t' + }) + parts = lo.Filter(parts, func(item string, idx int) bool { + return item != "" + }) + + for i, part := range parts { + switch part { + case "@Bind": + param.Name = parts[i+1] + param.Position = positionFromString(parts[i+2]) + case "key": + param.Key = parts[i+1] + case "table": + param.Table = parts[i+1] + case "model": + param.Model = parts[i+1] + } + } + return param +} diff --git a/pkg/ast/route/router.go.tpl b/pkg/ast/route/router.go.tpl new file mode 100644 index 0000000..b77c913 --- /dev/null +++ b/pkg/ast/route/router.go.tpl @@ -0,0 +1,42 @@ +// Code generated by the atomctl ; DO NOT EDIT. + +package {{.PackageName}} + +import ( +{{- range .Imports }} + {{.}} +{{- end }} + . "{{.ProjectPackage}}/pkg/f" + + _ "git.ipao.vip/rogeecn/atom" + _ "git.ipao.vip/rogeecn/atom/contracts" + "github.com/gofiber/fiber/v3" + log "github.com/sirupsen/logrus" +) + +// @provider contracts.HttpRoute atom.GroupRoutes +type Routes struct { + log *log.Entry `inject:"false"` +{{- range .Controllers }} + {{.}} +{{- end }} +} + +func (r *Routes) Prepare() error { + r.log = log.WithField("module", "routes.{{.PackageName}}") + return nil +} + +func (r *Routes) Register(router fiber.Router) { +{{- range $key, $value := .Routes }} + // 注册路由组: {{$key}} + {{- range $value }} + router.{{.Method}}("{{.Route}}", {{.Func}}( + r.{{.Controller}}.{{.Action}}, + {{- range .Params}} + {{.}}, + {{- end }} + )) + {{ end }} +{{- end }} +} diff --git a/templates/project/pkg/f/bind.go.tpl b/templates/project/pkg/f/bind.go.tpl index 81b5a71..c4b2f8a 100644 --- a/templates/project/pkg/f/bind.go.tpl +++ b/templates/project/pkg/f/bind.go.tpl @@ -2,9 +2,9 @@ package f import ( "github.com/gofiber/fiber/v3" + "github.com/pkg/errors" ) - func Path[T fiber.GenericType](key string) func(fiber.Ctx) (T, error) { return func(ctx fiber.Ctx) (T, error) { v := fiber.Params[T](ctx, key) @@ -12,58 +12,48 @@ func Path[T fiber.GenericType](key string) func(fiber.Ctx) (T, error) { } } -func URI[T any]() func(fiber.Ctx) (*T, error) { +func URI[T any](name string) func(fiber.Ctx) (*T, error) { return func(ctx fiber.Ctx) (*T, error) { p := new(T) if err := ctx.Bind().URI(p); err != nil { - return nil, err + return nil, errors.Wrapf(err, "uri: %s", name) } return p, nil } } -func Cookie[T any]() func(fiber.Ctx) (*T, error) { - return func(ctx fiber.Ctx) (*T, error) { - p := new(T) - if err := ctx.Bind().Cookie(p); err != nil { - return nil, err - } - - return p, nil - } -} - -func Body[T any]() func(fiber.Ctx) (*T, error) { +func Body[T any](name string) func(fiber.Ctx) (*T, error) { return func(ctx fiber.Ctx) (*T, error) { p := new(T) if err := ctx.Bind().Body(p); err != nil { - return nil, err + return nil, errors.Wrapf(err, "body: %s", name) } return p, nil } } -func Query[T any]() func(fiber.Ctx) (*T, error) { +func Query[T any](name string) func(fiber.Ctx) (*T, error) { return func(ctx fiber.Ctx) (*T, error) { p := new(T) if err := ctx.Bind().Query(p); err != nil { - return nil, err + return nil, errors.Wrapf(err, "query: %s", name) } return p, nil } } -func Header[T any]() func(fiber.Ctx) (*T, error) { +func Header[T any](name string) func(fiber.Ctx) (*T, error) { return func(ctx fiber.Ctx) (*T, error) { p := new(T) err := ctx.Bind().Header(p) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "header: %s", name) } return p, nil } } + diff --git a/templates/project/pkg/f/func.go.tpl b/templates/project/pkg/f/func.go.tpl index b6c1158..02a34b7 100644 --- a/templates/project/pkg/f/func.go.tpl +++ b/templates/project/pkg/f/func.go.tpl @@ -18,11 +18,7 @@ func Func1[P1 any]( return err } - err = f(ctx, p) - if err != nil { - return err - } - return nil + return f(ctx, p) } } @@ -42,11 +38,7 @@ func Func2[P1 any, P2 any]( return err } - err = f(ctx, p1, p2) - if err != nil { - return err - } - return nil + return f(ctx, p1, p2) } } @@ -70,11 +62,7 @@ func Func3[P1 any, P2 any, P3 any]( if err != nil { return nil } - err = f(ctx, p1, p2, p3) - if err != nil { - return nil - } - return nil + return f(ctx, p1, p2, p3) } } @@ -106,11 +94,7 @@ func Func4[P1 any, P2 any, P3 any, P4 any]( return nil } - err = f(ctx, p1, p2, p3, p4) - if err != nil { - return nil - } - return nil + return f(ctx, p1, p2, p3, p4) } } @@ -143,11 +127,7 @@ func Func5[P1 any, P2 any, P3 any, P4 any, P5 any]( if err != nil { return nil } - err = f(ctx, p1, p2, p3, p4, p5) - if err != nil { - return nil - } - return nil + return f(ctx, p1, p2, p3, p4, p5) } } @@ -185,10 +165,209 @@ func Func6[P1 any, P2 any, P3 any, P4 any, P5 any, P6 any]( if err != nil { return nil } - err = f(ctx, p1, p2, p3, p4, p5, p6) + return f(ctx, p1, p2, p3, p4, p5, p6) + } +} + +func Func7[P1 any, P2 any, P3 any, P4 any, P5 any, P6 any, P7 any]( + f func(fiber.Ctx, P1, P2, P3, P4, P5, P6, P7) error, + pf1 func(fiber.Ctx) (P1, error), + pf2 func(fiber.Ctx) (P2, error), + pf3 func(fiber.Ctx) (P3, error), + pf4 func(fiber.Ctx) (P4, error), + pf5 func(fiber.Ctx) (P5, error), + pf6 func(fiber.Ctx) (P6, error), + pf7 func(fiber.Ctx) (P7, error), +) fiber.Handler { + return func(ctx fiber.Ctx) error { + p1, err := pf1(ctx) if err != nil { return nil } - return nil + p2, err := pf2(ctx) + if err != nil { + return nil + } + p3, err := pf3(ctx) + if err != nil { + return nil + } + p4, err := pf4(ctx) + if err != nil { + return nil + } + p5, err := pf5(ctx) + if err != nil { + return nil + } + p6, err := pf6(ctx) + if err != nil { + return nil + } + p7, err := pf7(ctx) + if err != nil { + return nil + } + return f(ctx, p1, p2, p3, p4, p5, p6, p7) } } + +func Func8[P1 any, P2 any, P3 any, P4 any, P5 any, P6 any, P7 any, P8 any]( + f func(fiber.Ctx, P1, P2, P3, P4, P5, P6, P7, P8) error, + pf1 func(fiber.Ctx) (P1, error), + pf2 func(fiber.Ctx) (P2, error), + pf3 func(fiber.Ctx) (P3, error), + pf4 func(fiber.Ctx) (P4, error), + pf5 func(fiber.Ctx) (P5, error), + pf6 func(fiber.Ctx) (P6, error), + pf7 func(fiber.Ctx) (P7, error), + pf8 func(fiber.Ctx) (P8, error), +) fiber.Handler { + return func(ctx fiber.Ctx) error { + p1, err := pf1(ctx) + if err != nil { + return nil + } + p2, err := pf2(ctx) + if err != nil { + return nil + } + p3, err := pf3(ctx) + if err != nil { + return nil + } + p4, err := pf4(ctx) + if err != nil { + return nil + } + p5, err := pf5(ctx) + if err != nil { + return nil + } + p6, err := pf6(ctx) + if err != nil { + return nil + } + p7, err := pf7(ctx) + if err != nil { + return nil + } + p8, err := pf8(ctx) + if err != nil { + return nil + } + return f(ctx, p1, p2, p3, p4, p5, p6, p7, p8) + } +} + +func Func9[P1 any, P2 any, P3 any, P4 any, P5 any, P6 any, P7 any, P8 any, P9 any]( + f func(fiber.Ctx, P1, P2, P3, P4, P5, P6, P7, P8, P9) error, + pf1 func(fiber.Ctx) (P1, error), + pf2 func(fiber.Ctx) (P2, error), + pf3 func(fiber.Ctx) (P3, error), + pf4 func(fiber.Ctx) (P4, error), + pf5 func(fiber.Ctx) (P5, error), + pf6 func(fiber.Ctx) (P6, error), + pf7 func(fiber.Ctx) (P7, error), + pf8 func(fiber.Ctx) (P8, error), + pf9 func(fiber.Ctx) (P9, error), +) fiber.Handler { + return func(ctx fiber.Ctx) error { + p1, err := pf1(ctx) + if err != nil { + return nil + } + p2, err := pf2(ctx) + if err != nil { + return nil + } + p3, err := pf3(ctx) + if err != nil { + return nil + } + p4, err := pf4(ctx) + if err != nil { + return nil + } + p5, err := pf5(ctx) + if err != nil { + return nil + } + p6, err := pf6(ctx) + if err != nil { + return nil + } + p7, err := pf7(ctx) + if err != nil { + return nil + } + p8, err := pf8(ctx) + if err != nil { + return nil + } + p9, err := pf9(ctx) + if err != nil { + return nil + } + return f(ctx, p1, p2, p3, p4, p5, p6, p7, p8, p9) + } +} + +func Func10[P1 any, P2 any, P3 any, P4 any, P5 any, P6 any, P7 any, P8 any, P9 any, P10 any]( + f func(fiber.Ctx, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10) error, + pf1 func(fiber.Ctx) (P1, error), + pf2 func(fiber.Ctx) (P2, error), + pf3 func(fiber.Ctx) (P3, error), + pf4 func(fiber.Ctx) (P4, error), + pf5 func(fiber.Ctx) (P5, error), + pf6 func(fiber.Ctx) (P6, error), + pf7 func(fiber.Ctx) (P7, error), + pf8 func(fiber.Ctx) (P8, error), + pf9 func(fiber.Ctx) (P9, error), + pf10 func(fiber.Ctx) (P10, error), +) fiber.Handler { + return func(ctx fiber.Ctx) error { + p1, err := pf1(ctx) + if err != nil { + return nil + } + p2, err := pf2(ctx) + if err != nil { + return nil + } + p3, err := pf3(ctx) + if err != nil { + return nil + } + p4, err := pf4(ctx) + if err != nil { + return nil + } + p5, err := pf5(ctx) + if err != nil { + return nil + } + p6, err := pf6(ctx) + if err != nil { + return nil + } + p7, err := pf7(ctx) + if err != nil { + return nil + } + p8, err := pf8(ctx) + if err != nil { + return nil + } + p9, err := pf9(ctx) + if err != nil { + return nil + } + p10, err := pf10(ctx) + if err != nil { + return nil + } + return f(ctx, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10) + } +} +