package cmd import ( "fmt" "go/ast" "go/parser" "go/token" "io/fs" "log" "os" "path/filepath" "regexp" "strings" "git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod" "github.com/iancoleman/strcase" "github.com/pkg/errors" "github.com/samber/lo" "github.com/spf13/cobra" ) func CommandGenRoute(root *cobra.Command) { cmd := &cobra.Command{ Use: "route", Short: "generate routes", RunE: commandGenRouteE, } root.AddCommand(cmd) } func commandGenRouteE(cmd *cobra.Command, args []string) error { var err error var path string if len(args) > 0 { path = args[0] } else { path, err = os.Getwd() if err != nil { return err } } path, _ = filepath.Abs(path) err = gomod.Parse(filepath.Join(path, "go.mod")) if err != nil { return err } routes := []RouteDefinition{} modulePath := filepath.Join(path, "modules") if _, err := os.Stat(modulePath); os.IsNotExist(err) { log.Fatal("modules dir not exist, ", modulePath) } filepath.WalkDir(modulePath, func(path string, d fs.DirEntry, err error) error { if d.IsDir() { return nil } if !strings.HasSuffix(path, "_controller.go") { return nil } if !strings.HasSuffix(path, ".go") { return nil } routes = append(routes, astParseRoutes(path)...) return nil }) routeGroups := lo.GroupBy(routes, func(item RouteDefinition) string { return filepath.Dir(item.Path) }) for key, routes := range routeGroups { routePath := filepath.Join(key, "routes.gen.go") imports := lo.FlatMap(routes, func(item RouteDefinition, index int) []string { return lo.Map(item.Imports, func(item string, index int) string { return fmt.Sprintf("\t%s", item) }) }) pkgName := filepath.Base(key) f, err := os.OpenFile(routePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) if err != nil { fmt.Printf("ERR: %s %s\n", err, routePath) continue } defer f.Close() _, _ = f.WriteString("// Code generated by the atomctl; DO NOT EDIT.\n\n") _, _ = f.WriteString("package " + pkgName + "\n\n") _, _ = f.WriteString("import (\n") _, _ = f.WriteString(strings.Join(imports, "\n")) _, _ = f.WriteString("\n") _, _ = f.WriteString("\n") _, _ = f.WriteString("\t\"github.com/atom-providers/log\"\n") _, _ = f.WriteString("\t\"git.ipao.vip/rogeecn/atom/contracts\"\n") _, _ = f.WriteString("\t\"github.com/gofiber/fiber/v3\"\n") _, _ = f.WriteString("\t. \"github.com/pkg/f\"\n") _, _ = f.WriteString(")\n\n") _, _ = f.WriteString("// @provider contracts.HttpRoute atom.GroupRoutes\n") _, _ = f.WriteString("type Routes struct {\n") _, _ = f.WriteString("\tengine\t*fiber.App `inject:\"false\"`\n") _, _ = f.WriteString("\tsvc\tcontracts.HttpService\n") // inject controllers for _, route := range routes { _, _ = f.WriteString(fmt.Sprintf("\t%s *%s", strcase.ToLowerCamel(route.Name), route.Name)) } _, _ = f.WriteString("\n}\n\n") _, _ = f.WriteString("func (r *Routes) Prepare() error {\n") _, _ = f.WriteString("\tlog.Infof(\"register route group: " + pkgName + "\")\n") _, _ = f.WriteString("\tr.engine = r.svc.GetEngine().(*fiber.App)\n") for _, route := range routes { funcName := fmt.Sprintf(`register%sRoutes()`, route.Name) _, _ = f.WriteString(fmt.Sprintf("\tr.%s\n", funcName)) } _, _ = f.WriteString("\treturn nil\n") _, _ = f.WriteString("}\n\n") for _, route := range routes { funcName := fmt.Sprintf(`func (r *Routes)register%sRoutes()`, route.Name) _, _ = f.WriteString(fmt.Sprintf(`%s{`, funcName)) _, _ = f.WriteString("\n") for _, action := range route.Actions { _, _ = f.WriteString("\t") _, _ = f.WriteString(fmt.Sprintf("r.engine.%s(%q, ", strcase.ToCamel(strings.ToLower(action.Method)), formatRoute(action.Route))) if action.HasData { _, _ = f.WriteString("Data") } _, _ = f.WriteString("Func") paramsSize := len(action.Params) if paramsSize >= 1 { _, _ = f.WriteString(fmt.Sprintf("%d", paramsSize)) } paramsStrings := []string{fmt.Sprintf("r.%s.%s", strcase.ToLowerCamel(route.Name), action.Name)} for _, p := range action.Params { var paramString string switch p.Position { case PositionPath: paramString = fmt.Sprintf(`Path[%s](%q)`, p.Type, p.Name) case PositionURI: paramString = fmt.Sprintf(`URI[%s]()`, p.Type) case PositionQuery: paramString = fmt.Sprintf(`Query[%s]()`, p.Type) case PositionBody: paramString = fmt.Sprintf(`Body[%s]()`, p.Type) case PositionHeader: paramString = fmt.Sprintf(`Header[%s]()`, p.Type) case PositionCookie: paramString = fmt.Sprintf(`Cookie[%s]()`, p.Type) } paramsStrings = append(paramsStrings, paramString) } _, _ = f.WriteString("(" + strings.Join(paramsStrings, ", ") + ")") _, _ = f.WriteString(")\n") } _, _ = f.WriteString("}\n") } } return nil } type RouteDefinition struct { Path string Name string Imports []string Actions []ActionDefinition } type ParamDefinition struct { Name string Type string Position Position } type ActionDefinition struct { Route string Method string Name string HasData bool Params []ParamDefinition } type Position string const ( PositionPath Position = "Path" PositionURI Position = "Uri" PositionQuery Position = "Query" PositionBody Position = "Body" PositionHeader Position = "Header" PositionCookie Position = "Cookie" ) func astParseRoutes(source string) []RouteDefinition { if strings.HasSuffix(source, "_test.go") { return []RouteDefinition{} } if strings.HasSuffix(source, "/provider.go") { return []RouteDefinition{} } fset := token.NewFileSet() node, err := parser.ParseFile(fset, source, nil, parser.ParseComments) if err != nil { log.Println("ERR: ", err) return nil } imports := make(map[string]string) for _, imp := range node.Imports { paths := strings.Split(strings.Trim(imp.Path.Value, "\""), "/") name := paths[len(paths)-1] pkg := imp.Path.Value if imp.Name != nil { name = imp.Name.Name pkg = fmt.Sprintf("%s %q", name, imp.Path.Value) } imports[name] = pkg } routes := make(map[string]RouteDefinition) actions := make(map[string][]ActionDefinition) usedImports := make(map[string][]string) // 再去遍历 struct 的方法去 for _, decl := range node.Decls { decl, ok := decl.(*ast.FuncDecl) if !ok { continue } // 普通方法不要 if decl.Recv == nil { continue } recvType := decl.Recv.List[0].Type.(*ast.StarExpr).X.(*ast.Ident).Name if _, ok := routes[recvType]; !ok { routes[recvType] = RouteDefinition{ Name: recvType, Path: source, Actions: []ActionDefinition{}, } actions[recvType] = []ActionDefinition{} } // Doc 中把 @Router 的定义拿出来, Route 格式为 /user/:id [get] 两部分,表示路径和请求方法 var path, method string var err error if decl.Doc != nil { for _, line := range decl.Doc.List { lineText := strings.TrimSpace(line.Text) lineText = strings.TrimLeft(lineText, "/ \t") if !strings.HasPrefix(lineText, "@Route") { continue } path, method, err = parseRouteComment(lineText) if err != nil { log.Fatal(errors.Wrapf(err, "file: %s, action: %s", source, decl.Name.Name)) } break } } if path == "" || method == "" { log.Printf("[WARN] failed to get router ,file: %s, action: %s", source, decl.Name.Name) continue } // 拿参数列表去, 忽略 context.Context 参数 params := []ParamDefinition{} for _, param := range decl.Type.Params.List { // paramsType, ok := param.Type.(*ast.SelectorExpr) var typ string switch param.Type.(type) { case *ast.Ident: typ = param.Type.(*ast.Ident).Name case *ast.StarExpr: paramsType := param.Type.(*ast.StarExpr) switch paramsType.X.(type) { case *ast.SelectorExpr: X := paramsType.X.(*ast.SelectorExpr) typ = fmt.Sprintf("*%s.%s", X.X.(*ast.Ident).Name, X.Sel.Name) default: typ = fmt.Sprintf("*%s", paramsType.X.(*ast.Ident).Name) } case *ast.SelectorExpr: typ = fmt.Sprintf("%s.%s", param.Type.(*ast.SelectorExpr).X.(*ast.Ident).Name, param.Type.(*ast.SelectorExpr).Sel.Name) } if strings.HasSuffix(typ, "Context") || strings.HasSuffix(typ, "Ctx") { continue } pkgName := strings.Split(strings.Trim(typ, "*"), ".") if len(pkgName) == 2 { usedImports[recvType] = append(usedImports[recvType], imports[pkgName[0]]) } position := PositionPath if strings.HasSuffix(typ, string(PositionQuery)) || strings.HasSuffix(typ, "QueryFilter") { position = PositionQuery } if strings.HasSuffix(typ, string(PositionBody)) || strings.HasSuffix(typ, "Form") { position = PositionBody } if strings.HasSuffix(typ, string(PositionHeader)) { position = PositionHeader } typ = strings.TrimPrefix(typ, "*") for _, name := range param.Names { params = append(params, ParamDefinition{ Name: name.Name, Type: typ, Position: position, }) } } actions[recvType] = append(actions[recvType], ActionDefinition{ Route: path, Method: strings.ToUpper(method), Name: decl.Name.Name, HasData: len(decl.Type.Results.List) > 1, Params: params, }) } var items []RouteDefinition for k, item := range routes { a, ok := actions[k] if !ok { continue } item.Actions = a item.Imports = []string{} if im, ok := usedImports[k]; ok { item.Imports = lo.Uniq(im) } items = append(items, item) } return items } func parseRouteComment(line string) (string, string, error) { pattern := regexp.MustCompile(`(?mi)@router\s+(.*?)\s+\[(.*?)\]`) submatch := pattern.FindStringSubmatch(line) if len(submatch) != 3 { return "", "", errors.New("invalid route definition") } return submatch[1], submatch[2], nil } func getPackageRoute(mod, path string) string { paths := strings.SplitN(path, "modules", 2) pkg := paths[1] // path可能值为 // /test/user_controller.go // /test/modules/user_controller.go return strings.TrimLeft(filepath.Dir(pkg), "/") } func formatRoute(route string) string { pattern := regexp.MustCompile(`(?mi)\{(.*?)\}`) if !pattern.MatchString(route) { return route } items := pattern.FindAllStringSubmatch(route, -1) for _, item := range items { param := strcase.ToLowerCamel(item[1]) route = strings.ReplaceAll(route, item[0], fmt.Sprintf("{%s}", param)) } route = pattern.ReplaceAllString(route, ":$1") route = strings.ReplaceAll(route, "/:id", "/:id") route = strings.ReplaceAll(route, "Id/", "Id/") return route }