diff --git a/pkg/ast/route/builder.go b/pkg/ast/route/builder.go index 7a01d32..acf0619 100644 --- a/pkg/ast/route/builder.go +++ b/pkg/ast/route/builder.go @@ -3,6 +3,7 @@ package route import ( "fmt" "sort" + "strings" "github.com/iancoleman/strcase" "github.com/samber/lo" @@ -85,41 +86,51 @@ func buildRenderData(opts RenderBuildOpts) (RenderData, error) { } func buildParamToken(item ParamDefinition) string { - key := item.Name - if item.Key != "" { - key = item.Key - } + key := item.Name + if item.Key != "" { + key = item.Key + } - switch item.Position { - case PositionQuery: - return fmt.Sprintf(`Query%s[%s]("%s")`, scalarSuffix(item.Type), item.Type, key) - case PositionHeader: - return fmt.Sprintf(`Header[%s]("%s")`, item.Type, key) - case PositionFile: - return fmt.Sprintf(`File[multipart.FileHeader]("%s")`, key) - case PositionCookie: - if item.Type == "string" { - return fmt.Sprintf(`CookieParam("%s")`, key) - } - return fmt.Sprintf(`Cookie[%s]("%s")`, item.Type, key) - case PositionBody: - return fmt.Sprintf(`Body[%s]("%s")`, item.Type, key) - case PositionPath: - // If a model field is specified, generate a model-lookup binder from path value. - if item.ModelField != "" || item.Model != "" { - field := item.ModelField - if field == "" { - field = "id" - } - // PathModel is expected to resolve the path param to the specified model by field. - // Example: PathModel[models.User]("id", "user_id") - return fmt.Sprintf(`PathModel[%s]("%s", "%s")`, item.Type, field, key) - } - return fmt.Sprintf(`Path%s[%s]("%s")`, scalarSuffix(item.Type), item.Type, key) - case PositionLocal: - return fmt.Sprintf(`Local[%s]("%s")`, item.Type, key) - } - return "" + switch item.Position { + case PositionQuery: + return fmt.Sprintf(`Query%s[%s]("%s")`, scalarSuffix(item.Type), item.Type, key) + case PositionHeader: + return fmt.Sprintf(`Header[%s]("%s")`, item.Type, key) + case PositionFile: + return fmt.Sprintf(`File[multipart.FileHeader]("%s")`, key) + case PositionCookie: + if item.Type == "string" { + return fmt.Sprintf(`CookieParam("%s")`, key) + } + return fmt.Sprintf(`Cookie[%s]("%s")`, item.Type, key) + case PositionBody: + return fmt.Sprintf(`Body[%s]("%s")`, item.Type, key) + case PositionPath: + // If a model field is specified, generate a model-lookup binder from path value. + if item.Model != "" { + field := "id" + fieldType := "int" + if strings.Contains(item.Model, ":") { + parts := strings.SplitN(item.Model, ":", 2) + if len(parts) == 2 { + field = parts[0] + fieldType = parts[1] + } + } else { + field = item.Model + } + + tpl := `func(ctx fiber.Ctx) (*%s, error) { + v := fiber.Params[%s](ctx, "%s") + return %sQuery.WithContext(ctx).Where(field.NewUnsafeFieldRaw("%s = ?", v)).First() + }` + return fmt.Sprintf(tpl, item.Type, fieldType, key, item.Type, field) + } + return fmt.Sprintf(`Path%s[%s]("%s")`, scalarSuffix(item.Type), item.Type, key) + case PositionLocal: + return fmt.Sprintf(`Local[%s]("%s")`, item.Type, key) + } + return "" } func scalarSuffix(t string) string { diff --git a/pkg/ast/route/route.go b/pkg/ast/route/route.go index 3d2927e..18e4925 100644 --- a/pkg/ast/route/route.go +++ b/pkg/ast/route/route.go @@ -29,14 +29,11 @@ type ActionDefinition struct { } type ParamDefinition struct { - Name string - Type string - Key string - Model string - // ModelField is the field/column name used to lookup the model when Model is set. - // Example: `@Bind user path key(id) model(database/models.User:id)` -> Model=database/models.User, ModelField=id - ModelField string - Position Position + Name string + Type string + Key string + Model string + Position Position } type Position string @@ -141,15 +138,19 @@ func ParseFile(file string) []RouteDefinition { } if strings.HasPrefix(line, "@Bind") { - //@Bind name [uri|query|path|body|header|cookie] [key()] [table()] [model(.[:])] - bindParams = append(bindParams, parseRouteBind(line)) - } + //@Bind name [uri|query|path|body|header|cookie] [key()] [table()] [model(.[:])] + bindParams = append(bindParams, parseRouteBind(line)) + } } if path == "" || method == "" { continue } - log.WithField("file", file).WithField("action", decl.Name.Name).WithField("path", path).WithField("method", method).Info("get router") + log.WithField("file", file). + WithField("action", decl.Name.Name). + WithField("path", path). + WithField("method", method). + Info("get router") // 拿参数列表去, 忽略 context.Context 参数 orderBindParams := []ParamDefinition{} @@ -239,44 +240,32 @@ func parseRouteComment(line string) (string, string, error) { } func parseRouteBind(bind string) ParamDefinition { - var param ParamDefinition - parts := strings.FieldsFunc(bind, func(r rune) bool { - return r == ' ' || r == '(' || r == ')' || r == '\t' - }) - parts = lo.Filter(parts, func(item string, idx int) bool { - return item != "" - }) + var param ParamDefinition + parts := strings.FieldsFunc(bind, func(r rune) bool { + return r == ' ' || r == '(' || r == ')' || r == '\t' + }) + parts = lo.Filter(parts, func(item string, idx int) bool { + return item != "" + }) - for i, part := range parts { - switch part { - case "@Bind": - param.Name = parts[i+1] - param.Position = positionFromString(parts[i+2]) - case "key": - param.Key = parts[i+1] - case "model": - // Supported formats: - // - model(field) -> only specify model field/column; model type inferred from parameter - // - model(pkg/path.Type) -> type hint (optional); default field will be used later - // - model(pkg/path.Type:id) or model(pkg/path.Type#id) -> type + field - mv := parts[i+1] - // if mv contains no dot, treat as field name directly - if !strings.Contains(mv, ".") && !strings.Contains(mv, "/") { - param.ModelField = mv - break - } - // otherwise try type[:field] - fieldSep := ":" - if strings.Contains(mv, "#") { - fieldSep = "#" - } - if idx := strings.LastIndex(mv, fieldSep); idx > 0 && idx < len(mv)-1 { - param.Model = mv[:idx] - param.ModelField = mv[idx+1:] - } else { - param.Model = mv - } - } - } - return param + for i, part := range parts { + switch part { + case "@Bind": + param.Name = parts[i+1] + param.Position = positionFromString(parts[i+2]) + case "key": + param.Key = parts[i+1] + case "model": + // Supported formats: + // - model(field:field_type) -> only specify model field/column; + mv := parts[i+1] + // if mv contains no dot, treat as field name directly + if mv == "" { + param.Model = "id" + break + } + param.Model = mv + } + } + return param } diff --git a/pkg/ast/route/route_model_bind_test.go b/pkg/ast/route/route_model_bind_test.go deleted file mode 100644 index ec41fac..0000000 --- a/pkg/ast/route/route_model_bind_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package route - -import ( - "os" - "path/filepath" - "strings" - "testing" - - "go.ipao.vip/atomctl/v2/pkg/utils/gomod" -) - -// Test that @Bind with model(field) on a path parameter generates PathModel[T](field, key) -func Test_PathModelBind_FromRouteComments(t *testing.T) { - dir := t.TempDir() - src := `package v1 - -import ( - "context" -) - -type User struct{} - -type Demo struct{} - -// @Router /users/:id [get] -// @Bind user path key(id) model(id) -func (d *Demo) Show(ctx context.Context, user *User) (*User, error) { - return nil, nil -}` - - // minimal go.mod so gomod.GetPackageModuleName works without panic - gomodPath := filepath.Join(dir, "go.mod") - goModContent := "module example.com/test\n\ngo 1.23\n" - if err := os.WriteFile(gomodPath, []byte(goModContent), 0o644); err != nil { - t.Fatalf("write go.mod: %v", err) - } - - if err := gomod.Parse(gomodPath); err != nil { - t.Fatalf("gomod.Parse error: %v", err) - } - - file := filepath.Join(dir, "demo.go") - if err := os.WriteFile(file, []byte(src), 0o644); err != nil { - t.Fatalf("write file: %v", err) - } - - defs := ParseFile(file) - if len(defs) != 1 { - t.Fatalf("expected 1 route definition, got %d", len(defs)) - } - if len(defs[0].Actions) != 1 { - t.Fatalf("expected 1 action, got %d", len(defs[0].Actions)) - } - act := defs[0].Actions[0] - if len(act.Params) != 1 { - t.Fatalf("expected 1 param, got %d", len(act.Params)) - } - p := act.Params[0] - if p.Position != PositionPath { - t.Fatalf("expected path position, got %s", p.Position) - } - if p.Key != "id" { - t.Fatalf("expected key=id, got %s", p.Key) - } - if p.ModelField != "id" { - t.Fatalf("expected ModelField=id, got %s", p.ModelField) - } - if p.Type != "User" { // pointer should be trimmed for non-local - t.Fatalf("expected Type=User, got %s", p.Type) - } - - // Build render data and check binder token - rd, err := buildRenderData(RenderBuildOpts{ - PackageName: "v1", - ProjectPackage: "example.com/test", - Routes: defs, - }) - if err != nil { - t.Fatalf("buildRenderData error: %v", err) - } - // Render to text and assert PathModel usage - out, err := renderTemplate(rd) - if err != nil { - t.Fatalf("renderTemplate error: %v", err) - } - got := string(out) - if !strings.Contains(got, "PathModel[User](\"id\", \"id\")") { - t.Fatalf("expected generated code to contain PathModel[User](\"id\", \"id\"), got:\n%s", got) - } -}