402 lines
10 KiB
Go
402 lines
10 KiB
Go
package cmd
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"io/fs"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod"
|
|
"github.com/iancoleman/strcase"
|
|
"github.com/pkg/errors"
|
|
"github.com/samber/lo"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
func CommandGenRoute(root *cobra.Command) {
|
|
cmd := &cobra.Command{
|
|
Use: "route",
|
|
Short: "generate routes",
|
|
RunE: commandGenRouteE,
|
|
}
|
|
|
|
root.AddCommand(cmd)
|
|
}
|
|
|
|
func commandGenRouteE(cmd *cobra.Command, args []string) error {
|
|
var err error
|
|
var path string
|
|
if len(args) > 0 {
|
|
path = args[0]
|
|
} else {
|
|
path, err = os.Getwd()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
path, _ = filepath.Abs(path)
|
|
|
|
err = gomod.Parse(filepath.Join(path, "go.mod"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
routes := []RouteDefinition{}
|
|
|
|
modulePath := filepath.Join(path, "modules")
|
|
if _, err := os.Stat(modulePath); os.IsNotExist(err) {
|
|
log.Fatal("modules dir not exist, ", modulePath)
|
|
}
|
|
|
|
filepath.WalkDir(modulePath, func(path string, d fs.DirEntry, err error) error {
|
|
if d.IsDir() {
|
|
return nil
|
|
}
|
|
if !strings.HasSuffix(path, "_controller.go") {
|
|
return nil
|
|
}
|
|
if !strings.HasSuffix(path, ".go") {
|
|
return nil
|
|
}
|
|
|
|
routes = append(routes, astParseRoutes(path)...)
|
|
return nil
|
|
})
|
|
|
|
routeGroups := lo.GroupBy(routes, func(item RouteDefinition) string {
|
|
return filepath.Dir(item.Path)
|
|
})
|
|
|
|
for key, routes := range routeGroups {
|
|
routePath := filepath.Join(key, "routes.gen.go")
|
|
|
|
imports := lo.FlatMap(routes, func(item RouteDefinition, index int) []string {
|
|
return lo.Map(item.Imports, func(item string, index int) string {
|
|
return fmt.Sprintf("\t%s", item)
|
|
})
|
|
})
|
|
|
|
pkgName := filepath.Base(key)
|
|
|
|
f, err := os.OpenFile(routePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
|
|
if err != nil {
|
|
fmt.Printf("ERR: %s %s\n", err, routePath)
|
|
continue
|
|
}
|
|
defer f.Close()
|
|
|
|
_, _ = f.WriteString("// Code generated by the atomctl; DO NOT EDIT.\n\n")
|
|
_, _ = f.WriteString("package " + pkgName + "\n\n")
|
|
_, _ = f.WriteString("import (\n")
|
|
_, _ = f.WriteString(strings.Join(imports, "\n"))
|
|
_, _ = f.WriteString("\n")
|
|
_, _ = f.WriteString("\n")
|
|
|
|
_, _ = f.WriteString("\t\"github.com/atom-providers/log\"\n")
|
|
_, _ = f.WriteString("\t\"git.ipao.vip/rogeecn/atom/contracts\"\n")
|
|
_, _ = f.WriteString("\t\"github.com/gofiber/fiber/v3\"\n")
|
|
_, _ = f.WriteString("\t. \"github.com/pkg/f\"\n")
|
|
_, _ = f.WriteString(")\n\n")
|
|
|
|
_, _ = f.WriteString("// @provider contracts.HttpRoute atom.GroupRoutes\n")
|
|
_, _ = f.WriteString("type Routes struct {\n")
|
|
_, _ = f.WriteString("\tengine\t*fiber.App `inject:\"false\"`\n")
|
|
_, _ = f.WriteString("\tsvc\tcontracts.HttpService\n")
|
|
// inject controllers
|
|
for _, route := range routes {
|
|
_, _ = f.WriteString(fmt.Sprintf("\t%s *%s", strcase.ToLowerCamel(route.Name), route.Name))
|
|
}
|
|
_, _ = f.WriteString("\n}\n\n")
|
|
|
|
_, _ = f.WriteString("func (r *Routes) Prepare() error {\n")
|
|
_, _ = f.WriteString("\tlog.Infof(\"register route group: " + pkgName + "\")\n")
|
|
_, _ = f.WriteString("\tr.engine = r.svc.GetEngine().(*fiber.App)\n")
|
|
|
|
for _, route := range routes {
|
|
funcName := fmt.Sprintf(`register%sRoutes()`, route.Name)
|
|
_, _ = f.WriteString(fmt.Sprintf("\tr.%s\n", funcName))
|
|
}
|
|
_, _ = f.WriteString("\treturn nil\n")
|
|
_, _ = f.WriteString("}\n\n")
|
|
|
|
for _, route := range routes {
|
|
funcName := fmt.Sprintf(`func (r *Routes)register%sRoutes()`, route.Name)
|
|
_, _ = f.WriteString(fmt.Sprintf(`%s{`, funcName))
|
|
_, _ = f.WriteString("\n")
|
|
|
|
for _, action := range route.Actions {
|
|
_, _ = f.WriteString("\t")
|
|
_, _ = f.WriteString(fmt.Sprintf("r.engine.%s(%q, ", strcase.ToCamel(strings.ToLower(action.Method)), formatRoute(action.Route)))
|
|
|
|
if action.HasData {
|
|
_, _ = f.WriteString("Data")
|
|
}
|
|
_, _ = f.WriteString("Func")
|
|
|
|
paramsSize := len(action.Params)
|
|
if paramsSize >= 1 {
|
|
_, _ = f.WriteString(fmt.Sprintf("%d", paramsSize))
|
|
}
|
|
|
|
paramsStrings := []string{fmt.Sprintf("r.%s.%s", strcase.ToLowerCamel(route.Name), action.Name)}
|
|
for _, p := range action.Params {
|
|
var paramString string
|
|
switch p.Position {
|
|
case PositionPath:
|
|
paramString = fmt.Sprintf(`Path[%s](%q)`, p.Type, p.Name)
|
|
case PositionURI:
|
|
paramString = fmt.Sprintf(`URI[%s]()`, p.Type)
|
|
case PositionQuery:
|
|
paramString = fmt.Sprintf(`Query[%s]()`, p.Type)
|
|
case PositionBody:
|
|
paramString = fmt.Sprintf(`Body[%s]()`, p.Type)
|
|
case PositionHeader:
|
|
paramString = fmt.Sprintf(`Header[%s]()`, p.Type)
|
|
case PositionCookie:
|
|
paramString = fmt.Sprintf(`Cookie[%s]()`, p.Type)
|
|
}
|
|
|
|
paramsStrings = append(paramsStrings, paramString)
|
|
}
|
|
_, _ = f.WriteString("(" + strings.Join(paramsStrings, ", ") + ")")
|
|
|
|
_, _ = f.WriteString(")\n")
|
|
}
|
|
|
|
_, _ = f.WriteString("}\n")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type (
	// RouteDefinition is one controller type discovered by astParseRoutes.
	RouteDefinition struct {
		// Path is the source file that declares the controller's methods.
		Path string
		// Name is the controller's receiver type name.
		Name string
		// Imports holds the import specs required by the action parameter types.
		Imports []string
		// Actions lists the controller methods that carry a route annotation.
		Actions []ActionDefinition
	}

	// ParamDefinition is one bound parameter of a controller action.
	ParamDefinition struct {
		Name     string   // parameter name as declared in the method
		Type     string   // parameter type, leading "*" removed
		Position Position // request part the value is taken from
	}

	// ActionDefinition is one routed controller method.
	ActionDefinition struct {
		Route   string            // path taken from the @Router annotation
		Method  string            // HTTP method, upper-cased
		Name    string            // Go method name
		HasData bool              // the method declares more than one result
		Params  []ParamDefinition // parameters, context arguments excluded
	}

	// Position names the request part a parameter is bound from.
	Position string
)

// Recognised parameter positions. PositionQuery, PositionBody and
// PositionHeader double as the type-name suffixes that select them.
const (
	PositionPath   Position = "Path"
	PositionURI    Position = "Uri"
	PositionQuery  Position = "Query"
	PositionBody   Position = "Body"
	PositionHeader Position = "Header"
	PositionCookie Position = "Cookie"
)
|
|
|
|
func astParseRoutes(source string) []RouteDefinition {
|
|
if strings.HasSuffix(source, "_test.go") {
|
|
return []RouteDefinition{}
|
|
}
|
|
|
|
if strings.HasSuffix(source, "/provider.go") {
|
|
return []RouteDefinition{}
|
|
}
|
|
|
|
fset := token.NewFileSet()
|
|
node, err := parser.ParseFile(fset, source, nil, parser.ParseComments)
|
|
if err != nil {
|
|
log.Println("ERR: ", err)
|
|
return nil
|
|
}
|
|
|
|
imports := make(map[string]string)
|
|
for _, imp := range node.Imports {
|
|
paths := strings.Split(strings.Trim(imp.Path.Value, "\""), "/")
|
|
name := paths[len(paths)-1]
|
|
pkg := imp.Path.Value
|
|
if imp.Name != nil {
|
|
name = imp.Name.Name
|
|
pkg = fmt.Sprintf("%s %q", name, imp.Path.Value)
|
|
}
|
|
imports[name] = pkg
|
|
}
|
|
|
|
routes := make(map[string]RouteDefinition)
|
|
actions := make(map[string][]ActionDefinition)
|
|
usedImports := make(map[string][]string)
|
|
|
|
// 再去遍历 struct 的方法去
|
|
for _, decl := range node.Decls {
|
|
decl, ok := decl.(*ast.FuncDecl)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
// 普通方法不要
|
|
if decl.Recv == nil {
|
|
continue
|
|
}
|
|
|
|
recvType := decl.Recv.List[0].Type.(*ast.StarExpr).X.(*ast.Ident).Name
|
|
if _, ok := routes[recvType]; !ok {
|
|
routes[recvType] = RouteDefinition{
|
|
Name: recvType,
|
|
Path: source,
|
|
Actions: []ActionDefinition{},
|
|
}
|
|
actions[recvType] = []ActionDefinition{}
|
|
}
|
|
|
|
// Doc 中把 @Router 的定义拿出来, Route 格式为 /user/:id [get] 两部分,表示路径和请求方法
|
|
var path, method string
|
|
var err error
|
|
if decl.Doc != nil {
|
|
for _, line := range decl.Doc.List {
|
|
lineText := strings.TrimSpace(line.Text)
|
|
lineText = strings.TrimLeft(lineText, "/ \t")
|
|
if !strings.HasPrefix(lineText, "@Route") {
|
|
continue
|
|
}
|
|
|
|
path, method, err = parseRouteComment(lineText)
|
|
if err != nil {
|
|
log.Fatal(errors.Wrapf(err, "file: %s, action: %s", source, decl.Name.Name))
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if path == "" || method == "" {
|
|
log.Printf("[WARN] failed to get router ,file: %s, action: %s", source, decl.Name.Name)
|
|
continue
|
|
}
|
|
|
|
// 拿参数列表去, 忽略 context.Context 参数
|
|
params := []ParamDefinition{}
|
|
for _, param := range decl.Type.Params.List {
|
|
// paramsType, ok := param.Type.(*ast.SelectorExpr)
|
|
|
|
var typ string
|
|
switch param.Type.(type) {
|
|
case *ast.Ident:
|
|
typ = param.Type.(*ast.Ident).Name
|
|
case *ast.StarExpr:
|
|
paramsType := param.Type.(*ast.StarExpr)
|
|
switch paramsType.X.(type) {
|
|
case *ast.SelectorExpr:
|
|
X := paramsType.X.(*ast.SelectorExpr)
|
|
typ = fmt.Sprintf("*%s.%s", X.X.(*ast.Ident).Name, X.Sel.Name)
|
|
default:
|
|
typ = fmt.Sprintf("*%s", paramsType.X.(*ast.Ident).Name)
|
|
}
|
|
case *ast.SelectorExpr:
|
|
typ = fmt.Sprintf("%s.%s", param.Type.(*ast.SelectorExpr).X.(*ast.Ident).Name, param.Type.(*ast.SelectorExpr).Sel.Name)
|
|
}
|
|
|
|
if strings.HasSuffix(typ, "Context") || strings.HasSuffix(typ, "Ctx") {
|
|
continue
|
|
}
|
|
pkgName := strings.Split(strings.Trim(typ, "*"), ".")
|
|
if len(pkgName) == 2 {
|
|
usedImports[recvType] = append(usedImports[recvType], imports[pkgName[0]])
|
|
}
|
|
|
|
position := PositionPath
|
|
if strings.HasSuffix(typ, string(PositionQuery)) || strings.HasSuffix(typ, "QueryFilter") {
|
|
position = PositionQuery
|
|
}
|
|
if strings.HasSuffix(typ, string(PositionBody)) || strings.HasSuffix(typ, "Form") {
|
|
position = PositionBody
|
|
}
|
|
if strings.HasSuffix(typ, string(PositionHeader)) {
|
|
position = PositionHeader
|
|
}
|
|
typ = strings.TrimPrefix(typ, "*")
|
|
|
|
for _, name := range param.Names {
|
|
params = append(params, ParamDefinition{
|
|
Name: name.Name,
|
|
Type: typ,
|
|
Position: position,
|
|
})
|
|
}
|
|
}
|
|
|
|
actions[recvType] = append(actions[recvType], ActionDefinition{
|
|
Route: path,
|
|
Method: strings.ToUpper(method),
|
|
Name: decl.Name.Name,
|
|
HasData: len(decl.Type.Results.List) > 1,
|
|
Params: params,
|
|
})
|
|
}
|
|
|
|
var items []RouteDefinition
|
|
for k, item := range routes {
|
|
a, ok := actions[k]
|
|
if !ok {
|
|
continue
|
|
}
|
|
item.Actions = a
|
|
item.Imports = []string{}
|
|
if im, ok := usedImports[k]; ok {
|
|
item.Imports = lo.Uniq(im)
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items
|
|
}
|
|
|
|
// routeCommentPattern matches an `@Router <path> [<method>]` annotation
// (case-insensitive). It is compiled once instead of on every call.
var routeCommentPattern = regexp.MustCompile(`(?mi)@router\s+(.*?)\s+\[(.*?)\]`)

// parseRouteComment extracts the path and the HTTP method from an
// `@Router <path> [<method>]` comment line, e.g. "@Router /user/{id} [get]"
// yields ("/user/{id}", "get"). Whitespace inside the brackets is trimmed.
// An error is returned when the line does not match the pattern.
func parseRouteComment(line string) (string, string, error) {
	submatch := routeCommentPattern.FindStringSubmatch(line)
	if len(submatch) != 3 {
		return "", "", errors.New("invalid route definition")
	}
	return submatch[1], strings.TrimSpace(submatch[2]), nil
}
|
|
|
|
// getPackageRoute returns the directory of path relative to the first
// occurrence of "modules", using forward slashes and no leading slash.
//
//	/app/modules/test/user_controller.go -> "test"
//	/test/modules/user_controller.go     -> ""
//
// A path that does not contain "modules" yields "" (it previously panicked
// on an out-of-range index). mod is unused and kept for caller compatibility.
func getPackageRoute(mod, path string) string {
	_, pkg, found := strings.Cut(path, "modules")
	if !found {
		return ""
	}
	return strings.TrimLeft(filepath.ToSlash(filepath.Dir(pkg)), "/")
}
|
|
|
|
func formatRoute(route string) string {
|
|
pattern := regexp.MustCompile(`(?mi)\{(.*?)\}`)
|
|
if !pattern.MatchString(route) {
|
|
return route
|
|
}
|
|
|
|
items := pattern.FindAllStringSubmatch(route, -1)
|
|
for _, item := range items {
|
|
param := strcase.ToLowerCamel(item[1])
|
|
route = strings.ReplaceAll(route, item[0], fmt.Sprintf("{%s}", param))
|
|
}
|
|
|
|
route = pattern.ReplaceAllString(route, ":$1")
|
|
route = strings.ReplaceAll(route, "/:id", "/:id<int>")
|
|
route = strings.ReplaceAll(route, "Id/", "Id<int>/")
|
|
return route
|
|
}
|