diff --git a/.golangci.yml b/.golangci.yml
new file mode 100644
index 0000000..d183e92
--- /dev/null
+++ b/.golangci.yml
@@ -0,0 +1,134 @@
+# GolangCI-Lint configuration for pkg/ast/provider package
+run:
+  timeout: 5m
+  modules-download-mode: readonly
+  go: '1.24'
+  build-tags:
+    - go1.24
+
+linters:
+  disable-all: true
+  enable:
+    # Default linters
+    - errcheck
+    - gosimple
+    - govet
+    - ineffassign
+    - staticcheck
+    - typecheck
+    - unused
+
+    # Additional recommended linters
+    - gofmt
+    - goimports
+    - misspell
+    - unconvert
+    - unparam
+    - nakedret
+    - prealloc
+    - gocritic
+    - bodyclose
+    - depguard
+    - dogsled
+    - dupl
+    - errname
+    - errorlint
+    - exportloopref
+    - forbidigo
+    - forcetypeassert
+    - gci
+    - gocyclo
+    - godot
+    - goerr113
+    - gomnd
+    - gomodguard
+    - goprintffuncname
+    - gosec
+    - importas
+    - makezero
+    - nolintlint
+    - paralleltest
+    - predeclared
+    - revive
+    - tenv
+    - testpackage
+    - thelper
+    - tparallel
+    - wastedassign
+    - whitespace
+
+linters-settings:
+  gofmt:
+    simplify: true
+
+  goimports:
+    local-prefixes: go.ipao.vip/atomctl/v2
+
+  gocyclo:
+    min-complexity: 15
+
+  gomnd:
+    settings:
+      mnd:
+        # Kinds of expressions checked for magic numbers
+        checks: argument,case,condition,operation,return,assign
+
+  gci:
+    sections:
+      - standard
+      - default
+      - prefix(go.ipao.vip/atomctl/v2)
+
+  revive:
+    rules:
+      - name: exported
+        disabled: false
+
+  unparam:
+    check-exported: false
+
+  dupl:
+    threshold: 100
+
+  gosec:
+    excludes:
+      - G115 # Potential integer overflow when converting between integer types
+      - G404 # Use of weak random number generator (math/rand instead of crypto/rand)
+      - G501 # Blocklisted import crypto/md5: weak cryptographic primitive
+      - G502 # Blocklisted import crypto/sha1: weak cryptographic primitive
+      - G505 # Blocklisted import crypto/des: weak cryptographic primitive
+      - G506 # Blocklisted import crypto/rc4: weak cryptographic primitive
+
+issues:
+  exclude-use-default: false
+  exclude-rules:
+    - path: _test\.go
+      linters:
+        - gomnd
+        - unparam
+        - dupl
+        - gosec
+        - bodyclose
+        - forcetypeassert
+        - makezero
+    - path: tests/
+      linters:
+        - gomnd
+        - unparam
+        - dupl
+        - gosec
+        - bodyclose
+        - forcetypeassert
+        - makezero
+    - path: '.*\.pb\.go$'
+      linters:
+        - all
+    - text: "weak cryptographic primitive"
+      linters:
+        - gosec
+
+output:
+  format: colored-line-number
+  print-issued-lines: true
+  print-linter-name: true
+  uniq-by-line: true
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..af51e99
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,70 @@
+# Makefile for pkg/ast/provider development
+
+.PHONY: help test benchmark lint clean setup install-tools
+
+# Default target
+help:
+	@echo "Available targets:"
+	@echo " setup - Install required tools"
+	@echo " test - Run all tests"
+	@echo " test-coverage - Run tests with coverage"
+	@echo " benchmark - Run benchmark tests"
+	@echo " lint - Run linter"
+	@echo " clean - Clean build artifacts"
+	@echo " install-tools - Install development tools"
+
+# Install required tools
+install-tools:
+	@echo "Installing development tools..."
+	go install github.com/golang/mock/mockgen@latest
+	go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
+	go install github.com/axw/gocov/gocov@latest
+	go install github.com/AlekSi/gocov-xml@latest
+
+# Setup development environment
+setup: install-tools
+	@echo "Setting up development environment..."
+	go mod download
+	go mod tidy
+
+# Run all tests
+test:
+	@echo "Running tests..."
+	go test -v -race -parallel=4 ./tests/... ./pkg/ast/provider/...
+
+# Run tests with coverage
+test-coverage:
+	@echo "Running tests with coverage..."
+	go test -v -coverprofile=coverage.out -covermode=atomic ./tests/... ./pkg/ast/provider/...
+	go tool cover -html=coverage.out -o coverage.html
+	@echo "Coverage report generated: coverage.html"
+
+# Run benchmark tests
+benchmark:
+	@echo "Running benchmark tests..."
+	go test -bench=. -benchmem -count=3 ./pkg/ast/provider/...
+
+# Run linter
+lint:
+	@echo "Running linter..."
+	golangci-lint run ./pkg/ast/provider/... ./tests/...
+
+# Clean build artifacts
+clean:
+	@echo "Cleaning build artifacts..."
+	rm -f coverage.out coverage.html
+	rm -rf vendor/
+	go clean -cache -testcache
+
+# Run performance validation
+validate-performance:
+	@echo "Validating performance requirements..."
+	@echo "Running single file parsing benchmark..."
+	@go test -bench=ParseFile -benchtime=100ms -count=5 ./pkg/ast/provider/... | grep -E "(ParseFile|MB/s|allocs/op)"
+	@echo "Running project parsing benchmark..."
+	@go test -bench=ParseProject -benchtime=1s -count=3 ./pkg/ast/provider/... | grep -E "(ParseProject|MB/s|allocs/op)"
+
+# Check test coverage meets requirements
+check-coverage: test-coverage
+	@echo "Checking coverage meets 90% requirement..."
+	@go tool cover -func=coverage.out | grep total | awk '{if ($$3 < 90.0) {print "Coverage " $$3 " is below 90% requirement"; exit 1} else {print "Coverage " $$3 " meets requirement"}}'
\ No newline at end of file
diff --git a/cmd/buf.go b/cmd/buf.go index 105a71c..e2e39e1 100644 --- a/cmd/buf.go +++ b/cmd/buf.go @@ -11,10 +11,10 @@ import ( ) func CommandBuf(root *cobra.Command) { - cmd := &cobra.Command{ - Use: "buf", - Short: "run buf commands", - Long: `在指定目录执行 buf generate。若本机未安装 buf,将自动 go install github.com/bufbuild/buf/cmd/buf@v1.48.0。 + cmd := &cobra.Command{ + Use: "buf", + Short: "run buf commands", + Long: `在指定目录执行 buf generate。若本机未安装 buf,将自动 go install github.com/bufbuild/buf/cmd/buf@v1.48.0。 Flags: - --dir 执行目录(默认 .)
@@ -23,8 +23,8 @@ Flags: 说明: - 运行前会检查 buf.yaml 是否存在,如不存在会给出提示但仍尝试执行 - 成功后输出生成结果日志`, - RunE: commandBufE, - } + RunE: commandBufE, + } cmd.Flags().String("dir", ".", "Directory to run buf from") cmd.Flags().Bool("dry-run", false, "Preview buf command without executing") diff --git a/cmd/fmt.go b/cmd/fmt.go index 2ae6e50..c6d7b65 100644 --- a/cmd/fmt.go +++ b/cmd/fmt.go @@ -22,17 +22,17 @@ Flags: 说明: - 正常格式化等价于:gofumpt -l -extra -w - 检查模式等价于:gofumpt -l -extra `, - RunE: commandFmtE, + RunE: commandFmtE, } - cmd.Flags().Bool("check", false, "Check formatting without writing changes") - cmd.Flags().String("path", ".", "Path to format (default .)") + cmd.Flags().Bool("check", false, "Check formatting without writing changes") + cmd.Flags().String("path", ".", "Path to format (default .)") root.AddCommand(cmd) } func commandFmtE(cmd *cobra.Command, args []string) error { - log.Info("开始格式化代码") + log.Info("开始格式化代码") if _, err := exec.LookPath("gofumpt"); err != nil { log.Info("gofumpt 不存在,正在安装...") installCmd := exec.Command("go", "install", "mvdan.cc/gofumpt@latest") @@ -46,31 +46,31 @@ func commandFmtE(cmd *cobra.Command, args []string) error { } } - check, _ := cmd.Flags().GetBool("check") - path, _ := cmd.Flags().GetString("path") + check, _ := cmd.Flags().GetBool("check") + path, _ := cmd.Flags().GetString("path") - if check { - log.Info("运行 gofumpt 检查模式...") - out, err := exec.Command("gofumpt", "-l", "-extra", path).CombinedOutput() - if err != nil { - return fmt.Errorf("运行 gofumpt 失败: %v", err) - } - if len(out) > 0 { - fmt.Fprintln(os.Stdout, string(out)) - return fmt.Errorf("发现未格式化文件,请运行: gofumpt -l -extra -w %s", path) - } - log.Info("代码格式良好") - return nil - } + if check { + log.Info("运行 gofumpt 检查模式...") + out, err := exec.Command("gofumpt", "-l", "-extra", path).CombinedOutput() + if err != nil { + return fmt.Errorf("运行 gofumpt 失败: %v", err) + } + if len(out) > 0 { + fmt.Fprintln(os.Stdout, string(out)) + return fmt.Errorf("发现未格式化文件,请运行: gofumpt -l -extra -w %s", path) + } + log.Info("代码格式良好") + return nil + } - log.Info("运行 gofumpt...") - gofumptCmd := exec.Command("gofumpt", "-l", "-extra", "-w", path) - gofumptCmd.Stdout = os.Stdout - gofumptCmd.Stderr = os.Stderr - if err := gofumptCmd.Run(); err != nil { - return fmt.Errorf("运行 gofumpt 失败: %v", err) - } + log.Info("运行 gofumpt...") + gofumptCmd := exec.Command("gofumpt", "-l", "-extra", "-w", path) + gofumptCmd.Stdout = os.Stdout + gofumptCmd.Stderr = os.Stderr + if err := gofumptCmd.Run(); err != nil { + return fmt.Errorf("运行 gofumpt 失败: %v", err) + } - log.Info("格式化代码完成") - return nil + log.Info("格式化代码完成") + return nil } diff --git a/cmd/gen.go b/cmd/gen.go index 0bee3ca..b99b6b3 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -3,18 +3,18 @@ package cmd import "github.com/spf13/cobra" func CommandGen(root *cobra.Command) { - cmd := &cobra.Command{ - Use: "gen", - Short: "Generate code", - Long: `代码生成命令组:包含 route、provider、model、enum、service 等。 + cmd := &cobra.Command{ + Use: "gen", + Short: "Generate code", + Long: `代码生成命令组:包含 route、provider、model、enum、service 等。 持久化参数: - -c, --config 数据库配置文件(默认 config.toml),供 gen model 使用 说明: - 子命令执行完成后会自动运行 atomctl fmt 进行格式化`, - PersistentPostRunE: commandFmtE, - } + PersistentPostRunE: commandFmtE, + } cmd.PersistentFlags().StringP("config", "c", "config.toml", "database config file") cmds := []func(*cobra.Command){ diff --git a/cmd/gen_model.go b/cmd/gen_model.go index e54d23f..d97f8b9 100644 --- a/cmd/gen_model.go +++ b/cmd/gen_model.go @@ -1,14 +1,14 @@ package cmd import ( - 
"github.com/pkg/errors" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - apg "go.ipao.vip/atomctl/v2/pkg/postgres" - "go.ipao.vip/atomctl/v2/pkg/utils/gomod" - "go.ipao.vip/gen" - "gorm.io/driver/postgres" - "gorm.io/gorm" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + apg "go.ipao.vip/atomctl/v2/pkg/postgres" + "go.ipao.vip/atomctl/v2/pkg/utils/gomod" + "go.ipao.vip/gen" + "gorm.io/driver/postgres" + "gorm.io/gorm" ) func CommandGenModel(root *cobra.Command) { @@ -29,38 +29,38 @@ func CommandGenModel(root *cobra.Command) { 示例: atomctl gen -c config.toml model`, - RunE: commandGenModelE, + RunE: commandGenModelE, } root.AddCommand(cmd) } func commandGenModelE(cmd *cobra.Command, args []string) error { - if err := gomod.Parse("go.mod"); err != nil { - return errors.Wrap(err, "parse go.mod") - } + if err := gomod.Parse("go.mod"); err != nil { + return errors.Wrap(err, "parse go.mod") + } - cfgFile := cmd.Flag("config").Value.String() - if cfgFile == "" { - cfgFile = "config.toml" - } + cfgFile := cmd.Flag("config").Value.String() + if cfgFile == "" { + cfgFile = "config.toml" + } - sqlDB, conf, err := apg.GetDB(cfgFile) - if err != nil { - return errors.Wrap(err, "load database config") - } - defer sqlDB.Close() + sqlDB, conf, err := apg.GetDB(cfgFile) + if err != nil { + return errors.Wrap(err, "load database config") + } + defer sqlDB.Close() - dsn := conf.DSN() - log.Infof("parsed DSN: %s (schema=%s)", dsn, conf.Schema) + dsn := conf.DSN() + log.Infof("parsed DSN: %s (schema=%s)", dsn, conf.Schema) - db, err := gorm.Open(postgres.New(postgres.Config{DSN: dsn})) - if err != nil { - return errors.Wrapf(err, "open database with dsn: %s", dsn) - } + db, err := gorm.Open(postgres.New(postgres.Config{DSN: dsn})) + if err != nil { + return errors.Wrapf(err, "open database with dsn: %s", dsn) + } // 默认同包同目录生成到 ./database gen.GenerateWithDefault(db, "./database/.transform.yaml") - return nil + return nil } diff --git a/cmd/gen_route.go b/cmd/gen_route.go index 64d4cf6..1b7dd4d 100644 --- a/cmd/gen_route.go +++ b/cmd/gen_route.go @@ -15,10 +15,10 @@ import ( ) func CommandGenRoute(root *cobra.Command) { - cmd := &cobra.Command{ - Use: "route", - Short: "generate routes", - Long: `扫描项目控制器,解析注释生成 routes.gen.go。 + cmd := &cobra.Command{ + Use: "route", + Short: "generate routes", + Long: `扫描项目控制器,解析注释生成 routes.gen.go。 用法与规则: - 扫描根目录通过 --path 指定(默认 CWD),会在 /app/http 下递归搜索。 @@ -37,9 +37,9 @@ func CommandGenRoute(root *cobra.Command) { - local:任意类型(上下文本地值) 说明:生成完成后会自动运行 gen provider 以补全依赖注入。`, - RunE: commandGenRouteE, - PostRunE: commandGenProviderE, - } + RunE: commandGenRouteE, + PostRunE: commandGenProviderE, + } cmd.Flags().String("path", ".", "Base path to scan (defaults to CWD)") diff --git a/cmd/gen_service.go b/cmd/gen_service.go index da743f3..aebcc9e 100644 --- a/cmd/gen_service.go +++ b/cmd/gen_service.go @@ -14,8 +14,8 @@ import ( func CommandGenService(root *cobra.Command) { cmd := &cobra.Command{ - Use: "service", - Short: "generate services", + Use: "service", + Short: "generate services", Long: `扫描 --path 指定目录(默认 ./app/services)下的 Go 文件,汇总服务名并渲染生成 services.gen.go。 规则: diff --git a/cmd/new.go b/cmd/new.go index 4d4b146..50d16da 100644 --- a/cmd/new.go +++ b/cmd/new.go @@ -5,10 +5,10 @@ import ( ) func CommandInit(root *cobra.Command) { - cmd := &cobra.Command{ - Use: "new [project|module]", - Short: "new project/module", - Long: `脚手架命令组:创建项目与常用组件模板。 + cmd := &cobra.Command{ + Use: "new [project|module]", + Short: "new 
project/module", + Long: `脚手架命令组:创建项目与常用组件模板。 持久化参数(所有子命令通用): - --force, -f 覆盖已存在文件/目录 @@ -16,19 +16,19 @@ func CommandInit(root *cobra.Command) { - --dir 指定输出基目录(默认 .) 子命令:project、provider、event、job(module 已弃用)`, - } + } - cmd.PersistentFlags().BoolP("force", "f", false, "Force overwrite existing files or directories") - cmd.PersistentFlags().Bool("dry-run", false, "Preview actions without writing files") - cmd.PersistentFlags().String("dir", ".", "Base directory for outputs") + cmd.PersistentFlags().BoolP("force", "f", false, "Force overwrite existing files or directories") + cmd.PersistentFlags().Bool("dry-run", false, "Preview actions without writing files") + cmd.PersistentFlags().String("dir", ".", "Base directory for outputs") - cmds := []func(*cobra.Command){ - CommandNewProject, - // deprecate CommandNewModule, - CommandNewProvider, - CommandNewEvent, - CommandNewJob, - } + cmds := []func(*cobra.Command){ + CommandNewProject, + // deprecate CommandNewModule, + CommandNewProvider, + CommandNewEvent, + CommandNewJob, + } for _, c := range cmds { c(cmd) diff --git a/cmd/new_event.go b/cmd/new_event.go index d72702e..3aba599 100644 --- a/cmd/new_event.go +++ b/cmd/new_event.go @@ -31,8 +31,8 @@ func CommandNewEvent(root *cobra.Command) { 示例: atomctl new event UserCreated atomctl new event UserCreated --only=publisher`, - Args: cobra.ExactArgs(1), - RunE: commandNewEventE, + Args: cobra.ExactArgs(1), + RunE: commandNewEventE, } cmd.Flags().String("only", "", "仅生成: publisher 或 subscriber") diff --git a/cmd/new_project.go b/cmd/new_project.go index d4d31d5..00652fd 100644 --- a/cmd/new_project.go +++ b/cmd/new_project.go @@ -42,7 +42,7 @@ func CommandNewProject(root *cobra.Command) { atomctl new project github.com/acme/demo atomctl new -f --dir ./playground project github.com/acme/demo atomctl new project # 在已有 go.mod 的项目中就地初始化`, - RunE: commandNewProjectE, + RunE: commandNewProjectE, } root.AddCommand(cmd) diff --git a/cmd/new_provider.go b/cmd/new_provider.go index ce6db9e..718a291 100644 --- a/cmd/new_provider.go +++ b/cmd/new_provider.go @@ -1,26 +1,26 @@ package cmd import ( - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - "sort" - "strings" - "text/template" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "text/template" - "github.com/iancoleman/strcase" - "github.com/spf13/cobra" - "go.ipao.vip/atomctl/v2/templates" + "github.com/iancoleman/strcase" + "github.com/spf13/cobra" + "go.ipao.vip/atomctl/v2/templates" ) // CommandNewProvider 注册 new_provider 命令 func CommandNewProvider(root *cobra.Command) { - cmd := &cobra.Command{ - Use: "provider", - Short: "创建新的 provider", - Long: `在 providers/ 目录下渲染创建 Provider 模板。 + cmd := &cobra.Command{ + Use: "provider", + Short: "创建新的 provider", + Long: `在 providers/ 目录下渲染创建 Provider 模板。 行为: - 当 name 与内置预置目录同名时,渲染该目录;否则回退渲染 providers/default @@ -33,35 +33,35 @@ func CommandNewProvider(root *cobra.Command) { 示例: atomctl new provider email atomctl new --dry-run --dir ./demo provider cache`, - Args: cobra.MaximumNArgs(1), - RunE: commandNewProviderE, - } + Args: cobra.MaximumNArgs(1), + RunE: commandNewProviderE, + } root.AddCommand(cmd) } func commandNewProviderE(cmd *cobra.Command, args []string) error { - // no-arg: list available preset providers - if len(args) == 0 { - entries, err := templates.Providers.ReadDir("providers") - if err != nil { - return err - } - var names []string - for _, e := range entries { - if e.IsDir() { - names = append(names, e.Name()) - } - } - sort.Strings(names) - fmt.Println("可用预置 
providers:") - for _, n := range names { - fmt.Printf(" - %s\n", n) - } - return nil - } + // no-arg: list available preset providers + if len(args) == 0 { + entries, err := templates.Providers.ReadDir("providers") + if err != nil { + return err + } + var names []string + for _, e := range entries { + if e.IsDir() { + names = append(names, e.Name()) + } + } + sort.Strings(names) + fmt.Println("可用预置 providers:") + for _, n := range names { + fmt.Printf(" - %s\n", n) + } + return nil + } - providerName := args[0] + providerName := args[0] // shared flags dryRun, _ := cmd.Flags().GetBool("dry-run") baseDir, _ := cmd.Flags().GetString("dir") @@ -80,38 +80,38 @@ func commandNewProviderE(cmd *cobra.Command, args []string) error { } } - // choose template source: providers/ or providers/default - srcDir := filepath.Join("providers", providerName) - if _, err := templates.Providers.ReadDir(srcDir); err != nil { - srcDir = filepath.Join("providers", "default") - } + // choose template source: providers/ or providers/default + srcDir := filepath.Join("providers", providerName) + if _, err := templates.Providers.ReadDir(srcDir); err != nil { + srcDir = filepath.Join("providers", "default") + } - err := fs.WalkDir(templates.Providers, srcDir, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - return nil - } + err := fs.WalkDir(templates.Providers, srcDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } - relPath, err := filepath.Rel(srcDir, path) - if err != nil { - return err - } + relPath, err := filepath.Rel(srcDir, path) + if err != nil { + return err + } - destPath := filepath.Join(targetPath, strings.TrimSuffix(relPath, ".tpl")) - if dryRun { - fmt.Printf("[dry-run] mkdir -p %s\n", filepath.Dir(destPath)) - } else { - if err := os.MkdirAll(filepath.Dir(destPath), os.ModePerm); err != nil { - return err - } - } + destPath := filepath.Join(targetPath, strings.TrimSuffix(relPath, ".tpl")) + if dryRun { + fmt.Printf("[dry-run] mkdir -p %s\n", filepath.Dir(destPath)) + } else { + if err := os.MkdirAll(filepath.Dir(destPath), os.ModePerm); err != nil { + return err + } + } - tmpl, err := template.ParseFS(templates.Providers, path) - if err != nil { - return err - } + tmpl, err := template.ParseFS(templates.Providers, path) + if err != nil { + return err + } if dryRun { fmt.Printf("[dry-run] render > %s\n", destPath) diff --git a/cmd/swag_fmt.go b/cmd/swag_fmt.go index d464bad..8bf340f 100644 --- a/cmd/swag_fmt.go +++ b/cmd/swag_fmt.go @@ -15,21 +15,21 @@ func CommandSwagFmt(root *cobra.Command) { 参数: - --dir 扫描目录(默认 ./app/http) - --main 主入口文件(默认 main.go)`, - RunE: commandSwagFmtE, + RunE: commandSwagFmtE, } - cmd.Flags().String("dir", "./app/http", "SearchDir for swag format") - cmd.Flags().String("main", "main.go", "MainFile for swag format") + cmd.Flags().String("dir", "./app/http", "SearchDir for swag format") + cmd.Flags().String("main", "main.go", "MainFile for swag format") root.AddCommand(cmd) } func commandSwagFmtE(cmd *cobra.Command, args []string) error { - dir := cmd.Flag("dir").Value.String() - main := cmd.Flag("main").Value.String() - return format.New().Build(&format.Config{ - SearchDir: dir, - Excludes: "", - MainFile: main, - }) + dir := cmd.Flag("dir").Value.String() + main := cmd.Flag("main").Value.String() + return format.New().Build(&format.Config{ + SearchDir: dir, + Excludes: "", + MainFile: main, + }) } diff --git a/cmd/swag_init.go b/cmd/swag_init.go 
index 403a279..338beda 100644 --- a/cmd/swag_init.go +++ b/cmd/swag_init.go @@ -23,41 +23,43 @@ func CommandSwagInit(root *cobra.Command) { - --main 主入口文件(默认 main.go) 说明:基于 rogeecn/swag 的 gen 构建器,支持模板分隔符定制、依赖解析等配置。`, - RunE: commandSwagInitE, + RunE: commandSwagInitE, } - cmd.Flags().String("dir", ".", "SearchDir (project root)") - cmd.Flags().String("out", "docs", "Output dir for generated docs") - cmd.Flags().String("main", "main.go", "Main API file path") + cmd.Flags().String("dir", ".", "SearchDir (project root)") + cmd.Flags().String("out", "docs", "Output dir for generated docs") + cmd.Flags().String("main", "main.go", "Main API file path") root.AddCommand(cmd) } func commandSwagInitE(cmd *cobra.Command, args []string) error { - root := cmd.Flag("dir").Value.String() - if root == "" { - var err error - root, err = os.Getwd() - if err != nil { return err } - } + root := cmd.Flag("dir").Value.String() + if root == "" { + var err error + root, err = os.Getwd() + if err != nil { + return err + } + } leftDelim, rightDelim := "{{", "}}" - outDir := cmd.Flag("out").Value.String() - mainFile := cmd.Flag("main").Value.String() + outDir := cmd.Flag("out").Value.String() + mainFile := cmd.Flag("main").Value.String() - return gen.New().Build(&gen.Config{ - SearchDir: root, - Excludes: "", - ParseExtension: "", - MainAPIFile: mainFile, - PropNamingStrategy: swag.CamelCase, - OutputDir: filepath.Join(root, outDir), - OutputTypes: []string{"go", "json", "yaml"}, - ParseVendor: false, - ParseDependency: 0, - MarkdownFilesDir: "", - ParseInternal: false, + return gen.New().Build(&gen.Config{ + SearchDir: root, + Excludes: "", + ParseExtension: "", + MainAPIFile: mainFile, + PropNamingStrategy: swag.CamelCase, + OutputDir: filepath.Join(root, outDir), + OutputTypes: []string{"go", "json", "yaml"}, + ParseVendor: false, + ParseDependency: 0, + MarkdownFilesDir: "", + ParseInternal: false, Strict: false, GeneratedTime: false, RequiredByDefault: false, diff --git a/go.mod b/go.mod index 1178803..1beda3f 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/go-openapi/swag/typeutils v0.24.0 // indirect github.com/go-openapi/swag/yamlutils v0.24.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/golang/mock v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect diff --git a/go.sum b/go.sum index 12e7cae..53165b3 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -221,6 +223,7 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 
h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.ipao.vip/gen v0.0.0-20250909113008-7e6ae4534ada h1:suAdnZAD6BZpgQ6/pK6wnH49T9x/52WCzGk+lf+oy7g= go.ipao.vip/gen v0.0.0-20250909113008-7e6ae4534ada/go.mod h1:ip5X9ioxR9hvM/mrsA77KWXFsrMm5oki5rfY5MSkssM= @@ -231,6 +234,7 @@ go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= @@ -240,17 +244,21 @@ golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88p golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= @@ -258,7 +266,10 @@ golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync 
v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -282,6 +293,7 @@ golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= @@ -289,6 +301,8 @@ golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588 golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pkg/ast/provider/ast_walker.go b/pkg/ast/provider/ast_walker.go new file mode 100644 index 0000000..8d97d87 --- /dev/null +++ b/pkg/ast/provider/ast_walker.go @@ -0,0 +1,373 @@ +package provider + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" +) + +// ASTWalker handles traversal of Go AST nodes to find provider-related structures +type ASTWalker struct { + fileSet *token.FileSet + commentParser *CommentParser + config *WalkerConfig + visitors []NodeVisitor +} + +// WalkerConfig configures the AST walker behavior +type WalkerConfig struct { + IncludeTestFiles bool + IncludeGeneratedFiles bool + MaxFileSize int64 + StrictMode bool +} + +// NodeVisitor defines the interface for visiting AST nodes +type NodeVisitor interface { + // VisitFile is called when a new file is processed + VisitFile(filePath string, node *ast.File) error + + // VisitGenDecl is called for each generic declaration (type, var, const) + VisitGenDecl(filePath string, 
decl *ast.GenDecl) error + + // VisitTypeSpec is called for each type specification + VisitTypeSpec(filePath string, typeSpec *ast.TypeSpec, decl *ast.GenDecl) error + + // VisitStructType is called for each struct type + VisitStructType(filePath string, structType *ast.StructType, typeSpec *ast.TypeSpec, decl *ast.GenDecl) error + + // VisitStructField is called for each field in a struct + VisitStructField(filePath string, field *ast.Field, structType *ast.StructType) error + + // Complete is called when file processing is complete + Complete(filePath string) error +} + +// NewASTWalker creates a new ASTWalker with default configuration +func NewASTWalker() *ASTWalker { + return &ASTWalker{ + fileSet: token.NewFileSet(), + commentParser: NewCommentParser(), + config: &WalkerConfig{ + IncludeTestFiles: false, + IncludeGeneratedFiles: false, + MaxFileSize: 10 * 1024 * 1024, // 10MB + StrictMode: false, + }, + visitors: make([]NodeVisitor, 0), + } +} + +// NewASTWalkerWithConfig creates a new ASTWalker with custom configuration +func NewASTWalkerWithConfig(config *WalkerConfig) *ASTWalker { + if config == nil { + return NewASTWalker() + } + + return &ASTWalker{ + fileSet: token.NewFileSet(), + commentParser: NewCommentParserWithStrictMode(config.StrictMode), + config: config, + visitors: make([]NodeVisitor, 0), + } +} + +// AddVisitor adds a node visitor to the walker +func (aw *ASTWalker) AddVisitor(visitor NodeVisitor) { + aw.visitors = append(aw.visitors, visitor) +} + +// RemoveVisitor removes a node visitor from the walker +func (aw *ASTWalker) RemoveVisitor(visitor NodeVisitor) { + for i, v := range aw.visitors { + if v == visitor { + aw.visitors = append(aw.visitors[:i], aw.visitors[i+1:]...) + break + } + } +} + +// WalkFile traverses a single Go file +func (aw *ASTWalker) WalkFile(filePath string) error { + // Check if file should be processed + if !aw.shouldProcessFile(filePath) { + return nil + } + + // Parse the file + node, err := parser.ParseFile(aw.fileSet, filePath, nil, parser.ParseComments) + if err != nil { + return fmt.Errorf("failed to parse file %s: %w", filePath, err) + } + + // Notify visitors of file start + for _, visitor := range aw.visitors { + if err := visitor.VisitFile(filePath, node); err != nil { + return err + } + } + + // Traverse the AST + if err := aw.traverseFile(filePath, node); err != nil { + return err + } + + // Notify visitors of file completion + for _, visitor := range aw.visitors { + if err := visitor.Complete(filePath); err != nil { + return err + } + } + + return nil +} + +// WalkDir traverses all Go files in a directory +func (aw *ASTWalker) WalkDir(dirPath string) error { + return filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Skip directories + if info.IsDir() { + // Skip hidden directories and common build/dependency directories + if strings.HasPrefix(info.Name(), ".") || + info.Name() == "node_modules" || + info.Name() == "vendor" || + info.Name() == "testdata" { + return filepath.SkipDir + } + return nil + } + + // Process Go files + if filepath.Ext(path) == ".go" && aw.shouldProcessFile(path) { + if err := aw.WalkFile(path); err != nil { + // Continue with other files, but log the error + fmt.Printf("Warning: failed to process file %s: %v\n", path, err) + } + } + + return nil + }) +} + +// traverseFile traverses the AST of a parsed file +func (aw *ASTWalker) traverseFile(filePath string, node *ast.File) error { + // Traverse all declarations + for _, decl := range 
node.Decls { + if err := aw.traverseDeclaration(filePath, decl); err != nil { + return err + } + } + + return nil +} + +// traverseDeclaration traverses a single declaration +func (aw *ASTWalker) traverseDeclaration(filePath string, decl ast.Decl) error { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + // Skip function declarations and other non-generic declarations + return nil + } + + // Notify visitors of generic declaration + for _, visitor := range aw.visitors { + if err := visitor.VisitGenDecl(filePath, genDecl); err != nil { + return err + } + } + + // Traverse specs within the declaration + for _, spec := range genDecl.Specs { + if err := aw.traverseSpec(filePath, spec, genDecl); err != nil { + return err + } + } + + return nil +} + +// traverseSpec traverses a specification within a declaration +func (aw *ASTWalker) traverseSpec(filePath string, spec ast.Spec, decl *ast.GenDecl) error { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + // Skip non-type specifications + return nil + } + + // Notify visitors of type specification + for _, visitor := range aw.visitors { + if err := visitor.VisitTypeSpec(filePath, typeSpec, decl); err != nil { + return err + } + } + + // Check if it's a struct type + structType, ok := typeSpec.Type.(*ast.StructType) + if ok { + // Notify visitors of struct type + for _, visitor := range aw.visitors { + if err := visitor.VisitStructType(filePath, structType, typeSpec, decl); err != nil { + return err + } + } + + // Traverse struct fields + if err := aw.traverseStructFields(filePath, structType); err != nil { + return err + } + } + + return nil +} + +// traverseStructFields traverses fields within a struct type +func (aw *ASTWalker) traverseStructFields(filePath string, structType *ast.StructType) error { + if structType.Fields == nil { + return nil + } + + for _, field := range structType.Fields.List { + // Notify visitors of struct field + for _, visitor := range aw.visitors { + if err := visitor.VisitStructField(filePath, field, structType); err != nil { + return err + } + } + } + + return nil +} + +// shouldProcessFile determines if a file should be processed +func (aw *ASTWalker) shouldProcessFile(filePath string) bool { + // Check file extension + if filepath.Ext(filePath) != ".go" { + return false + } + + // Skip test files if not allowed + if !aw.config.IncludeTestFiles && strings.HasSuffix(filePath, "_test.go") { + return false + } + + // Skip generated files if not allowed + if !aw.config.IncludeGeneratedFiles && strings.HasSuffix(filePath, ".gen.go") { + return false + } + + // TODO: Check file size if needed (requires os.Stat) + + return true +} + +// GetFileSet returns the file set used by the walker +func (aw *ASTWalker) GetFileSet() *token.FileSet { + return aw.fileSet +} + +// GetCommentParser returns the comment parser used by the walker +func (aw *ASTWalker) GetCommentParser() *CommentParser { + return aw.commentParser +} + +// GetConfig returns the walker configuration +func (aw *ASTWalker) GetConfig() *WalkerConfig { + return aw.config +} + +// ProviderDiscoveryVisitor implements NodeVisitor for discovering provider annotations +type ProviderDiscoveryVisitor struct { + commentParser *CommentParser + providers []Provider + currentFile string +} + +// NewProviderDiscoveryVisitor creates a new ProviderDiscoveryVisitor +func NewProviderDiscoveryVisitor(commentParser *CommentParser) *ProviderDiscoveryVisitor { + return &ProviderDiscoveryVisitor{ + commentParser: commentParser, + providers: make([]Provider, 0), + } +} + +// VisitFile 
implements NodeVisitor.VisitFile +func (pdv *ProviderDiscoveryVisitor) VisitFile(filePath string, node *ast.File) error { + pdv.currentFile = filePath + return nil +} + +// VisitGenDecl implements NodeVisitor.VisitGenDecl +func (pdv *ProviderDiscoveryVisitor) VisitGenDecl(filePath string, decl *ast.GenDecl) error { + return nil +} + +// VisitTypeSpec implements NodeVisitor.VisitTypeSpec +func (pdv *ProviderDiscoveryVisitor) VisitTypeSpec(filePath string, typeSpec *ast.TypeSpec, decl *ast.GenDecl) error { + return nil +} + +// VisitStructType implements NodeVisitor.VisitStructType +func (pdv *ProviderDiscoveryVisitor) VisitStructType(filePath string, structType *ast.StructType, typeSpec *ast.TypeSpec, decl *ast.GenDecl) error { + // Check if the struct has a provider annotation + if decl.Doc != nil && len(decl.Doc.List) > 0 { + // Extract comment lines + commentLines := make([]string, len(decl.Doc.List)) + for i, comment := range decl.Doc.List { + commentLines[i] = comment.Text + } + + // Parse provider annotation + providerComment, err := pdv.commentParser.ParseCommentBlock(commentLines) + if err == nil && providerComment != nil { + // Create provider structure + provider := Provider{ + StructName: typeSpec.Name.Name, + Mode: providerComment.Mode, + ProviderGroup: providerComment.Group, + ReturnType: providerComment.ReturnType, + InjectParams: make(map[string]InjectParam), + Imports: make(map[string]string), + } + + // Set default return type if not specified + if provider.ReturnType == "" { + provider.ReturnType = "*" + provider.StructName + } + + pdv.providers = append(pdv.providers, provider) + } + } + + return nil +} + +// VisitStructField implements NodeVisitor.VisitStructField +func (pdv *ProviderDiscoveryVisitor) VisitStructField(filePath string, field *ast.Field, structType *ast.StructType) error { + // This is where field-level processing would happen + // For example, extracting inject tags and field types + return nil +} + +// Complete implements NodeVisitor.Complete +func (pdv *ProviderDiscoveryVisitor) Complete(filePath string) error { + return nil +} + +// GetProviders returns the discovered providers +func (pdv *ProviderDiscoveryVisitor) GetProviders() []Provider { + return pdv.providers +} + +// Reset clears the discovered providers +func (pdv *ProviderDiscoveryVisitor) Reset() { + pdv.providers = make([]Provider, 0) + pdv.currentFile = "" +} diff --git a/pkg/ast/provider/builder.go b/pkg/ast/provider/builder.go new file mode 100644 index 0000000..8e78bef --- /dev/null +++ b/pkg/ast/provider/builder.go @@ -0,0 +1,506 @@ +package provider + +import ( + "fmt" + "go/ast" + "strings" +) + +// ProviderBuilder handles the construction of Provider objects from parsed AST components +type ProviderBuilder struct { + config *BuilderConfig + commentParser *CommentParser + importResolver *ImportResolver + astWalker *ASTWalker +} + +// BuilderConfig configures the provider builder behavior +type BuilderConfig struct { + EnableValidation bool + StrictMode bool + DefaultProviderMode ProviderMode + DefaultInjectionMode InjectionMode + AutoGenerateReturnTypes bool + ResolveImportDependencies bool +} + +// BuilderContext maintains context during provider building +type BuilderContext struct { + FilePath string + PackageName string + ImportContext *ImportContext + ASTFile *ast.File + ProcessedTypes map[string]bool + Errors []error + Warnings []string +} + +// NewProviderBuilder creates a new ProviderBuilder with default configuration +func NewProviderBuilder() *ProviderBuilder { + return 
&ProviderBuilder{ + config: &BuilderConfig{ + EnableValidation: true, + StrictMode: false, + DefaultProviderMode: ProviderModeBasic, + DefaultInjectionMode: InjectionModeAuto, + AutoGenerateReturnTypes: true, + ResolveImportDependencies: true, + }, + commentParser: NewCommentParser(), + importResolver: NewImportResolver(), + astWalker: NewASTWalker(), + } +} + +// NewProviderBuilderWithConfig creates a new ProviderBuilder with custom configuration +func NewProviderBuilderWithConfig(config *BuilderConfig) *ProviderBuilder { + if config == nil { + return NewProviderBuilder() + } + + return &ProviderBuilder{ + config: config, + commentParser: NewCommentParser(), + importResolver: NewImportResolver(), + astWalker: NewASTWalkerWithConfig(&WalkerConfig{ + StrictMode: config.StrictMode, + }), + } +} + +// BuildFromTypeSpec builds a Provider from a type specification and its declaration +func (pb *ProviderBuilder) BuildFromTypeSpec(typeSpec *ast.TypeSpec, decl *ast.GenDecl, context *BuilderContext) (Provider, error) { + if typeSpec == nil { + return Provider{}, fmt.Errorf("type specification cannot be nil") + } + + // Initialize builder context if not provided + if context == nil { + context = &BuilderContext{ + ProcessedTypes: make(map[string]bool), + Errors: make([]error, 0), + Warnings: make([]string, 0), + } + } + + // Check if type has already been processed + if context.ProcessedTypes[typeSpec.Name.Name] { + return Provider{}, fmt.Errorf("type %s has already been processed", typeSpec.Name.Name) + } + + // Parse provider comment + providerComment, err := pb.parseProviderComment(decl) + if err != nil { + return Provider{}, fmt.Errorf("failed to parse provider comment: %w", err) + } + + // Create basic provider structure + provider := Provider{ + StructName: typeSpec.Name.Name, + Mode: pb.determineProviderMode(providerComment), + ProviderGroup: pb.determineProviderGroup(providerComment), + InjectParams: make(map[string]InjectParam), + Imports: make(map[string]string), + PkgName: context.PackageName, + ProviderFile: context.FilePath, + } + + // Set return type + if err := pb.setReturnType(&provider, providerComment, typeSpec); err != nil { + return Provider{}, err + } + + // Process struct fields if it's a struct type + if structType, ok := typeSpec.Type.(*ast.StructType); ok { + if err := pb.processStructFields(&provider, structType, context); err != nil { + return Provider{}, err + } + } + + // Resolve import dependencies + if pb.config.ResolveImportDependencies { + if err := pb.resolveImportDependencies(&provider, context); err != nil { + return Provider{}, err + } + } + + // Apply mode-specific configurations + if err := pb.applyModeSpecificConfig(&provider, providerComment); err != nil { + return Provider{}, err + } + + // Validate the built provider + if pb.config.EnableValidation { + if err := pb.validateProvider(&provider); err != nil { + return Provider{}, err + } + } + + // Mark type as processed + context.ProcessedTypes[typeSpec.Name.Name] = true + + return provider, nil +} + +// BuildFromComment builds a Provider from a provider comment string +func (pb *ProviderBuilder) BuildFromComment(comment string, context *BuilderContext) (Provider, error) { + // Parse the provider comment + providerComment, err := pb.commentParser.ParseProviderComment(comment) + if err != nil { + return Provider{}, fmt.Errorf("failed to parse provider comment: %w", err) + } + + // Create basic provider structure + provider := Provider{ + Mode: pb.determineProviderMode(providerComment), + ProviderGroup: 
pb.determineProviderGroup(providerComment), + InjectParams: make(map[string]InjectParam), + Imports: make(map[string]string), + } + + // Set return type from comment + if providerComment.ReturnType != "" { + provider.ReturnType = providerComment.ReturnType + } else if pb.config.AutoGenerateReturnTypes { + // Generate a default return type based on mode + provider.ReturnType = pb.generateDefaultReturnType(provider.Mode) + } + + // Apply mode-specific configurations + if err := pb.applyModeSpecificConfig(&provider, providerComment); err != nil { + return Provider{}, err + } + + // Validate the built provider + if pb.config.EnableValidation { + if err := pb.validateProvider(&provider); err != nil { + return Provider{}, err + } + } + + return provider, nil +} + +// parseProviderComment parses the provider comment from a declaration +func (pb *ProviderBuilder) parseProviderComment(decl *ast.GenDecl) (*ProviderComment, error) { + if decl.Doc == nil || len(decl.Doc.List) == 0 { + return nil, fmt.Errorf("no documentation found for declaration") + } + + // Extract comment lines + commentLines := make([]string, len(decl.Doc.List)) + for i, comment := range decl.Doc.List { + commentLines[i] = comment.Text + } + + // Parse provider annotation + return pb.commentParser.ParseCommentBlock(commentLines) +} + +// determineProviderMode determines the provider mode from the comment +func (pb *ProviderBuilder) determineProviderMode(comment *ProviderComment) ProviderMode { + if comment != nil && comment.Mode != "" { + return comment.Mode + } + return pb.config.DefaultProviderMode +} + +// determineProviderGroup determines the provider group from the comment +func (pb *ProviderBuilder) determineProviderGroup(comment *ProviderComment) string { + if comment != nil && comment.Group != "" { + return comment.Group + } + return "" +} + +// setReturnType sets the return type for the provider +func (pb *ProviderBuilder) setReturnType(provider *Provider, comment *ProviderComment, typeSpec *ast.TypeSpec) error { + if comment != nil && comment.ReturnType != "" { + provider.ReturnType = comment.ReturnType + return nil + } + + if pb.config.AutoGenerateReturnTypes { + provider.ReturnType = pb.generateDefaultReturnType(provider.Mode) + return nil + } + + // Default to pointer type + provider.ReturnType = "*" + typeSpec.Name.Name + return nil +} + +// generateDefaultReturnType generates a default return type based on provider mode +func (pb *ProviderBuilder) generateDefaultReturnType(mode ProviderMode) string { + switch mode { + case ProviderModeBasic: + return "interface{}" + case ProviderModeGrpc: + return "interface{}" + case ProviderModeEvent: + return "func() error" + case ProviderModeJob: + return "func(ctx context.Context) error" + case ProviderModeCronJob: + return "func(ctx context.Context) error" + case ProviderModeModel: + return "interface{}" + default: + return "interface{}" + } +} + +// processStructFields processes struct fields to extract injection parameters +func (pb *ProviderBuilder) processStructFields(provider *Provider, structType *ast.StructType, context *BuilderContext) error { + if structType.Fields == nil { + return nil + } + + for _, field := range structType.Fields.List { + if len(field.Names) == 0 { + // Skip anonymous fields + continue + } + + for _, fieldName := range field.Names { + injectParam, err := pb.createInjectParam(fieldName.Name, field, context) + if err != nil { + return fmt.Errorf("failed to create inject param for field %s: %w", fieldName.Name, err) + } + + 
provider.InjectParams[fieldName.Name] = *injectParam + } + } + + return nil +} + +// createInjectParam creates an InjectParam from a field +func (pb *ProviderBuilder) createInjectParam(fieldName string, field *ast.Field, context *BuilderContext) (*InjectParam, error) { + // Extract field type + typeStr := pb.extractFieldType(field) + if typeStr == "" { + return nil, fmt.Errorf("cannot determine type for field %s", fieldName) + } + + // Check for import dependencies + packagePath, packageAlias := pb.extractImportInfo(typeStr, context) + + return &InjectParam{ + Type: typeStr, + Package: packagePath, + PackageAlias: packageAlias, + }, nil +} + +// extractFieldType extracts the type string from a field +func (pb *ProviderBuilder) extractFieldType(field *ast.Field) string { + if field.Type == nil { + return "" + } + + // Handle different type representations + switch t := field.Type.(type) { + case *ast.Ident: + return t.Name + case *ast.SelectorExpr: + // Handle qualified identifiers like "package.Type" + if x, ok := t.X.(*ast.Ident); ok { + return x.Name + "." + t.Sel.Name + } + case *ast.StarExpr: + // Handle pointer types + if x, ok := t.X.(*ast.Ident); ok { + return "*" + x.Name + } + case *ast.ArrayType: + // Handle array/slice types + if x, ok := t.Elt.(*ast.Ident); ok { + return "[]" + x.Name + } + case *ast.MapType: + // Handle map types + keyType := pb.extractTypeFromExpr(t.Key) + valueType := pb.extractTypeFromExpr(t.Value) + if keyType != "" && valueType != "" { + return fmt.Sprintf("map[%s]%s", keyType, valueType) + } + } + + return "" +} + +// extractTypeFromExpr extracts type string from an expression +func (pb *ProviderBuilder) extractTypeFromExpr(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.SelectorExpr: + if x, ok := t.X.(*ast.Ident); ok { + return x.Name + "." 
+ t.Sel.Name + } + } + return "" +} + +// extractImportInfo extracts import information from a type string +func (pb *ProviderBuilder) extractImportInfo(typeStr string, context *BuilderContext) (packagePath, packageAlias string) { + if !strings.Contains(typeStr, ".") { + return "", "" + } + + // Extract the package part + parts := strings.Split(typeStr, ".") + if len(parts) < 2 { + return "", "" + } + + packageAlias = parts[0] + + // Look up the import path from the import context + if context.ImportContext != nil { + if path, exists := context.ImportContext.ImportPaths[packageAlias]; exists { + return path, packageAlias + } + } + + return "", packageAlias +} + +// resolveImportDependencies resolves import dependencies for the provider +func (pb *ProviderBuilder) resolveImportDependencies(provider *Provider, context *BuilderContext) error { + if context.ImportContext == nil { + return nil + } + + // Add imports from injection parameters + for _, param := range provider.InjectParams { + if param.Package != "" && param.PackageAlias != "" { + provider.Imports[param.PackageAlias] = param.Package + } + } + + // Add mode-specific imports + modeImports := pb.getModeSpecificImports(provider.Mode) + for alias, path := range modeImports { + // Check for conflicts + if existingPath, exists := provider.Imports[alias]; exists && existingPath != path { + // Handle conflict by generating unique alias + uniqueAlias := pb.generateUniqueAlias(alias, provider.Imports) + provider.Imports[uniqueAlias] = path + } else { + provider.Imports[alias] = path + } + } + + return nil +} + +// getModeSpecificImports returns mode-specific import requirements +func (pb *ProviderBuilder) getModeSpecificImports(mode ProviderMode) map[string]string { + imports := make(map[string]string) + + switch mode { + case ProviderModeGrpc: + imports["grpc"] = "google.golang.org/grpc" + case ProviderModeEvent: + imports["context"] = "context" + case ProviderModeJob: + imports["context"] = "context" + case ProviderModeCronJob: + imports["context"] = "context" + case ProviderModeModel: + imports["encoding/json"] = "encoding/json" + } + + return imports +} + +// generateUniqueAlias generates a unique alias to avoid conflicts +func (pb *ProviderBuilder) generateUniqueAlias(baseAlias string, existingImports map[string]string) string { + for i := 1; i < 1000; i++ { + candidate := fmt.Sprintf("%s%d", baseAlias, i) + if _, exists := existingImports[candidate]; !exists { + return candidate + } + } + return baseAlias +} + +// applyModeSpecificConfig applies mode-specific configurations to the provider +func (pb *ProviderBuilder) applyModeSpecificConfig(provider *Provider, comment *ProviderComment) error { + switch provider.Mode { + case ProviderModeGrpc: + provider.GrpcRegisterFunc = pb.generateGrpcRegisterFuncName(provider.StructName) + provider.NeedPrepareFunc = true + case ProviderModeEvent: + provider.NeedPrepareFunc = true + case ProviderModeJob: + provider.NeedPrepareFunc = true + case ProviderModeCronJob: + provider.NeedPrepareFunc = true + case ProviderModeModel: + provider.NeedPrepareFunc = true + default: + // Basic mode - no special configuration + } + + return nil +} + +// generateGrpcRegisterFuncName generates a gRPC register function name +func (pb *ProviderBuilder) generateGrpcRegisterFuncName(structName string) string { + // Convert struct name to register function name + // Example: UserService -> RegisterUserServiceServer + return "Register" + structName + "Server" +} + +// validateProvider validates the constructed provider 
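+// It checks that the struct name, return type, and provider mode are set and
+// valid, and that every injection parameter carries a non-empty type.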
+func (pb *ProviderBuilder) validateProvider(provider *Provider) error { + // Basic validation + if provider.StructName == "" { + return fmt.Errorf("provider struct name cannot be empty") + } + + if provider.ReturnType == "" { + return fmt.Errorf("provider return type cannot be empty") + } + + if !IsValidProviderMode(string(provider.Mode)) { + return fmt.Errorf("invalid provider mode: %s", provider.Mode) + } + + // Validate injection parameters + for name, param := range provider.InjectParams { + if param.Type == "" { + return fmt.Errorf("injection parameter type cannot be empty for %s", name) + } + } + + return nil +} + +// GetConfig returns the builder configuration +func (pb *ProviderBuilder) GetConfig() *BuilderConfig { + return pb.config +} + +// SetConfig updates the builder configuration +func (pb *ProviderBuilder) SetConfig(config *BuilderConfig) { + pb.config = config +} + +// GetCommentParser returns the comment parser used by the builder +func (pb *ProviderBuilder) GetCommentParser() *CommentParser { + return pb.commentParser +} + +// GetImportResolver returns the import resolver used by the builder +func (pb *ProviderBuilder) GetImportResolver() *ImportResolver { + return pb.importResolver +} + +// GetASTWalker returns the AST walker used by the builder +func (pb *ProviderBuilder) GetASTWalker() *ASTWalker { + return pb.astWalker +} diff --git a/pkg/ast/provider/comment_parser.go b/pkg/ast/provider/comment_parser.go new file mode 100644 index 0000000..0a49f47 --- /dev/null +++ b/pkg/ast/provider/comment_parser.go @@ -0,0 +1,251 @@ +package provider + +import ( + "fmt" + "strings" +) + +// CommentParser handles parsing of provider annotations from Go comments +type CommentParser struct { + strictMode bool +} + +// NewCommentParser creates a new CommentParser +func NewCommentParser() *CommentParser { + return &CommentParser{ + strictMode: false, + } +} + +// NewCommentParserWithStrictMode creates a new CommentParser with strict mode enabled +func NewCommentParserWithStrictMode(strictMode bool) *CommentParser { + return &CommentParser{ + strictMode: strictMode, + } +} + +// ParseProviderComment parses a provider annotation from a comment line +func (cp *CommentParser) ParseProviderComment(comment string) (*ProviderComment, error) { + // Trim the comment markers + comment = strings.TrimSpace(comment) + comment = strings.TrimPrefix(comment, "//") + comment = strings.TrimPrefix(comment, "/*") + comment = strings.TrimSuffix(comment, "*/") + comment = strings.TrimSpace(comment) + + // Check if it's a provider annotation + if !strings.HasPrefix(comment, "@provider") { + return nil, fmt.Errorf("not a provider annotation") + } + + // Parse the provider annotation + return cp.parseProviderAnnotation(comment) +} + +// parseProviderAnnotation parses the provider annotation structure +func (cp *CommentParser) parseProviderAnnotation(annotation string) (*ProviderComment, error) { + result := &ProviderComment{ + RawText: annotation, + IsValid: true, + Errors: make([]string, 0), + } + + // Remove @provider prefix + content := strings.TrimSpace(strings.TrimPrefix(annotation, "@provider")) + + // Handle empty case + if content == "" { + result.Mode = ProviderModeBasic + result.Injection = InjectionModeAuto + return result, nil + } + + // Parse the annotation components + return cp.parseAnnotationComponents(content, result) +} + +// parseAnnotationComponents parses the components of the provider annotation +func (cp *CommentParser) parseAnnotationComponents(content string, result *ProviderComment) 
(*ProviderComment, error) { + // Parse injection mode first (only/except) + injectionMode, remaining := cp.parseInjectionMode(content) + result.Injection = injectionMode + + // Parse provider mode (in parentheses) + providerMode, remaining := cp.parseProviderMode(remaining) + if providerMode != "" { + if IsValidProviderMode(providerMode) { + result.Mode = ProviderMode(providerMode) + } else { + result.IsValid = false + result.Errors = append(result.Errors, fmt.Sprintf("invalid provider mode: %s", providerMode)) + if cp.strictMode { + return result, fmt.Errorf("invalid provider mode: %s", providerMode) + } + } + } else { + result.Mode = ProviderModeBasic + } + + // Parse return type and group + returnType, group := cp.parseReturnTypeAndGroup(remaining) + result.ReturnType = returnType + result.Group = group + + return result, nil +} + +// parseInjectionMode parses the injection mode (only/except) +func (cp *CommentParser) parseInjectionMode(content string) (InjectionMode, string) { + if strings.Contains(content, ":only") { + return InjectionModeOnly, strings.Replace(content, ":only", "", 1) + } else if strings.Contains(content, ":except") { + return InjectionModeExcept, strings.Replace(content, ":except", "", 1) + } + return InjectionModeAuto, content +} + +// parseProviderMode parses the provider mode from parentheses +func (cp *CommentParser) parseProviderMode(content string) (string, string) { + start := strings.Index(content, "(") + end := strings.Index(content, ")") + + if start >= 0 && end > start { + mode := strings.TrimSpace(content[start+1 : end]) + remaining := content[:start] + strings.TrimSpace(content[end+1:]) + return mode, strings.TrimSpace(remaining) + } + return "", strings.TrimSpace(content) +} + +// parseReturnTypeAndGroup parses the return type and group from remaining content +func (cp *CommentParser) parseReturnTypeAndGroup(content string) (string, string) { + parts := strings.Fields(content) + + if len(parts) == 0 { + return "", "" + } + + if len(parts) == 1 { + return parts[0], "" + } + + return parts[0], parts[1] +} + +// IsProviderAnnotation checks if a comment line is a provider annotation +func (cp *CommentParser) IsProviderAnnotation(comment string) bool { + comment = strings.TrimSpace(comment) + comment = strings.TrimPrefix(comment, "//") + comment = strings.TrimPrefix(comment, "/*") + comment = strings.TrimSpace(comment) + return strings.HasPrefix(comment, "@provider") +} + +// ParseCommentBlock parses a block of comments to find provider annotations +func (cp *CommentParser) ParseCommentBlock(comments []string) (*ProviderComment, error) { + if len(comments) == 0 { + return nil, fmt.Errorf("empty comment block") + } + + // Check each comment line for provider annotation (from bottom to top) + for i := len(comments) - 1; i >= 0; i-- { + comment := comments[i] + if cp.IsProviderAnnotation(comment) { + return cp.ParseProviderComment(comment) + } + } + + return nil, fmt.Errorf("no provider annotation found in comment block") +} + +// ValidateProviderComment validates a parsed provider comment +func (cp *CommentParser) ValidateProviderComment(comment *ProviderComment) []string { + var errors []string + + if comment == nil { + errors = append(errors, "comment is nil") + return errors + } + + // Validate provider mode + if comment.Mode != "" && !IsValidProviderMode(string(comment.Mode)) { + errors = append(errors, fmt.Sprintf("invalid provider mode: %s", comment.Mode)) + } + + // Validate injection mode + if comment.Injection != "" && 
!IsValidInjectionMode(string(comment.Injection)) { + errors = append(errors, fmt.Sprintf("invalid injection mode: %s", comment.Injection)) + } + + // Validate return type format + if comment.ReturnType != "" && !isValidGoType(comment.ReturnType) { + errors = append(errors, fmt.Sprintf("invalid return type format: %s", comment.ReturnType)) + } + + // Validate group format + if comment.Group != "" && !isValidGoIdentifier(comment.Group) { + errors = append(errors, fmt.Sprintf("invalid group identifier: %s", comment.Group)) + } + + return errors +} + +// ProviderComment represents a parsed provider annotation comment +type ProviderComment struct { + RawText string // Original comment text + Mode ProviderMode // Provider mode + Injection InjectionMode // Injection mode (only/except/auto) + ReturnType string // Return type specification + Group string // Provider group + IsValid bool // Whether the comment is valid + Errors []string // Validation errors +} + +// IsOnlyMode returns true if this is an "only" injection mode +func (pc *ProviderComment) IsOnlyMode() bool { + return pc.Injection == InjectionModeOnly +} + +// IsExceptMode returns true if this is an "except" injection mode +func (pc *ProviderComment) IsExceptMode() bool { + return pc.Injection == InjectionModeExcept +} + +// IsAutoMode returns true if this is an "auto" injection mode +func (pc *ProviderComment) IsAutoMode() bool { + return pc.Injection == InjectionModeAuto +} + +// HasMode returns true if a specific provider mode is set +func (pc *ProviderComment) HasMode(mode ProviderMode) bool { + return pc.Mode == mode +} + +// String returns a string representation of the provider comment +func (pc *ProviderComment) String() string { + var builder strings.Builder + + builder.WriteString("@provider") + + if pc.Mode != ProviderModeBasic { + builder.WriteString(fmt.Sprintf("(%s)", pc.Mode)) + } + + if pc.Injection == InjectionModeOnly { + builder.WriteString(":only") + } else if pc.Injection == InjectionModeExcept { + builder.WriteString(":except") + } + + if pc.ReturnType != "" { + builder.WriteString(" ") + builder.WriteString(pc.ReturnType) + } + + if pc.Group != "" { + builder.WriteString(" ") + builder.WriteString(pc.Group) + } + + return builder.String() +} diff --git a/pkg/ast/provider/config.go b/pkg/ast/provider/config.go new file mode 100644 index 0000000..a609893 --- /dev/null +++ b/pkg/ast/provider/config.go @@ -0,0 +1,176 @@ +package provider + +import ( + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" +) + +// ParserConfig represents the configuration for the parser +type ParserConfig struct { + // File parsing options + ParseComments bool // Whether to parse comments (default: true) + Mode parser.Mode // Parser mode + FileSet *token.FileSet // File set for position information + + // Include/exclude options + IncludePatterns []string // Glob patterns for files to include + ExcludePatterns []string // Glob patterns for files to exclude + + // Provider parsing options + StrictMode bool // Whether to use strict validation (default: false) + DefaultMode string // Default provider mode for simple @provider annotations + AllowTestFiles bool // Whether to parse test files (default: false) + AllowGenFiles bool // Whether to parse generated files (default: false) + + // Performance options + MaxFileSize int64 // Maximum file size to parse (bytes, default: 10MB) + Concurrency int // Number of concurrent parsers (default: 1) + CacheEnabled bool // Whether to enable caching (default: true) + + // Output options + 
OutputDir string // Output directory for generated files + OutputFileName string // Output file name (default: provider.gen.go) + SourceLocations bool // Whether to include source location info (default: false) +} + +// ParserContext represents the context for parsing operations +type ParserContext struct { + // Configuration + Config *ParserConfig + + // Parsing state + FileSet *token.FileSet + WorkingDir string + ModuleName string + + // Import resolution + Imports map[string]string // Package alias -> package path + ModuleInfo map[string]string // Module path -> module name + + // Statistics and metrics + FilesProcessed int + FilesSkipped int + ProvidersFound int + ParseErrors []ParseError + + // Caching + Cache map[string]interface{} // File path -> parsed content +} + +// ParseError represents a parsing error with location information +type ParseError struct { + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Message string `json:"message"` + Severity string `json:"severity"` // "error", "warning", "info" +} + +// NewParserConfig creates a new ParserConfig with default values +func NewParserConfig() *ParserConfig { + return &ParserConfig{ + ParseComments: true, + Mode: parser.ParseComments, + StrictMode: false, + DefaultMode: "basic", + AllowTestFiles: false, + AllowGenFiles: false, + MaxFileSize: 10 * 1024 * 1024, // 10MB + Concurrency: 1, + CacheEnabled: true, + OutputFileName: "provider.gen.go", + SourceLocations: false, + } +} + +// NewParserContext creates a new ParserContext with the given configuration +func NewParserContext(config *ParserConfig) *ParserContext { + if config == nil { + config = NewParserConfig() + } + + return &ParserContext{ + Config: config, + FileSet: config.FileSet, + Imports: make(map[string]string), + ModuleInfo: make(map[string]string), + ParseErrors: make([]ParseError, 0), + Cache: make(map[string]interface{}), + } +} + +// ShouldIncludeFile determines if a file should be included in parsing +func (c *ParserContext) ShouldIncludeFile(filePath string) bool { + // Check file extension + if filepath.Ext(filePath) != ".go" { + return false + } + + // Skip test files if not allowed + if !c.Config.AllowTestFiles && strings.HasSuffix(filePath, "_test.go") { + return false + } + + // Skip generated files if not allowed + if !c.Config.AllowGenFiles && strings.HasSuffix(filePath, ".gen.go") { + return false + } + + // Check file size + if info, err := os.Stat(filePath); err == nil { + if info.Size() > c.Config.MaxFileSize { + c.AddError(filePath, 0, 0, "file exceeds maximum size", "warning") + return false + } + } + + // TODO: Implement include/exclude pattern matching + // For now, include all Go files that pass the basic checks + return true +} + +// AddError adds a parsing error to the context +func (c *ParserContext) AddError(file string, line, column int, message, severity string) { + c.ParseErrors = append(c.ParseErrors, ParseError{ + File: file, + Line: line, + Column: column, + Message: message, + Severity: severity, + }) +} + +// HasErrors returns true if there are any errors in the context +func (c *ParserContext) HasErrors() bool { + for _, err := range c.ParseErrors { + if err.Severity == "error" { + return true + } + } + return false +} + +// GetErrors returns all errors of a specific severity +func (c *ParserContext) GetErrors(severity string) []ParseError { + var errors []ParseError + for _, err := range c.ParseErrors { + if err.Severity == severity { + errors = append(errors, err) + } + } + return errors +} + +// 
AddImport adds an import to the context +func (c *ParserContext) AddImport(alias, path string) { + c.Imports[alias] = path +} + +// GetImportPath returns the import path for a given alias +func (c *ParserContext) GetImportPath(alias string) (string, bool) { + path, ok := c.Imports[alias] + return path, ok +} diff --git a/pkg/ast/provider/errors.go b/pkg/ast/provider/errors.go new file mode 100644 index 0000000..8432f4a --- /dev/null +++ b/pkg/ast/provider/errors.go @@ -0,0 +1,434 @@ +package provider + +import ( + "errors" + "fmt" + "io/fs" + "os" + "runtime" + "strings" +) + +// ParserError represents errors that occur during parsing +type ParserError struct { + Operation string `json:"operation"` + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Message string `json:"message"` + Code string `json:"code,omitempty"` + Severity string `json:"severity"` // "error", "warning", "info" + Cause error `json:"cause,omitempty"` + Stack string `json:"stack,omitempty"` +} + +// Error implements the error interface +func (e *ParserError) Error() string { + if e.File != "" { + return fmt.Sprintf("%s: %s at %s:%d:%d", e.Operation, e.Message, e.File, e.Line, e.Column) + } + return fmt.Sprintf("%s: %s", e.Operation, e.Message) +} + +// Unwrap implements the error unwrapping interface +func (e *ParserError) Unwrap() error { + return e.Cause +} + +// Is implements the error comparison interface +func (e *ParserError) Is(target error) bool { + var other *ParserError + if errors.As(target, &other) { + return e.Operation == other.Operation && e.Code == other.Code + } + return false +} + +// RendererError represents errors that occur during rendering +type RendererError struct { + Operation string `json:"operation"` + Template string `json:"template"` + Target string `json:"target,omitempty"` + Message string `json:"message"` + Code string `json:"code,omitempty"` + Cause error `json:"cause,omitempty"` + Stack string `json:"stack,omitempty"` +} + +// Error implements the error interface +func (e *RendererError) Error() string { + if e.Template != "" { + return fmt.Sprintf("renderer %s (template %s): %s", e.Operation, e.Template, e.Message) + } + return fmt.Sprintf("renderer %s: %s", e.Operation, e.Message) +} + +// Unwrap implements the error unwrapping interface +func (e *RendererError) Unwrap() error { + return e.Cause +} + +// Is implements the error comparison interface +func (e *RendererError) Is(target error) bool { + var other *RendererError + if errors.As(target, &other) { + return e.Operation == other.Operation && e.Code == other.Code + } + return false +} + +// FileSystemError represents file system related errors +type FileSystemError struct { + Operation string `json:"operation"` + Path string `json:"path"` + Message string `json:"message"` + Code string `json:"code,omitempty"` + Cause error `json:"cause,omitempty"` +} + +// Error implements the error interface +func (e *FileSystemError) Error() string { + return fmt.Sprintf("file system %s failed for %s: %s", e.Operation, e.Path, e.Message) +} + +// Unwrap implements the error unwrapping interface +func (e *FileSystemError) Unwrap() error { + return e.Cause +} + +// ConfigurationError represents configuration related errors +type ConfigurationError struct { + Field string `json:"field"` + Value string `json:"value,omitempty"` + Message string `json:"message"` + Code string `json:"code,omitempty"` + Cause error `json:"cause,omitempty"` +} + +// Error implements the error interface +func (e *ConfigurationError) 
Error() string { + if e.Value != "" { + return fmt.Sprintf("configuration error for field %s (value: %s): %s", e.Field, e.Value, e.Message) + } + return fmt.Sprintf("configuration error for field %s: %s", e.Field, e.Message) +} + +// Unwrap implements the error unwrapping interface +func (e *ConfigurationError) Unwrap() error { + return e.Cause +} + +// Error codes +const ( + ErrCodeFileNotFound = "FILE_NOT_FOUND" + ErrCodePermissionDenied = "PERMISSION_DENIED" + ErrCodeInvalidSyntax = "INVALID_SYNTAX" + ErrCodeInvalidAnnotation = "INVALID_ANNOTATION" + ErrCodeInvalidMode = "INVALID_MODE" + ErrCodeInvalidType = "INVALID_TYPE" + ErrCodeTemplateNotFound = "TEMPLATE_NOT_FOUND" + ErrCodeTemplateError = "TEMPLATE_ERROR" + ErrCodeValidationFailed = "VALIDATION_FAILED" + ErrCodeConfigurationError = "CONFIGURATION_ERROR" + ErrCodeFileSystemError = "FILE_SYSTEM_ERROR" + ErrCodeUnknownError = "UNKNOWN_ERROR" +) + +// Error severity levels +const ( + SeverityError = "error" + SeverityWarning = "warning" + SeverityInfo = "info" +) + +// Error builder functions + +// NewParserError creates a new ParserError +func NewParserError(operation, message string) *ParserError { + return &ParserError{ + Operation: operation, + Message: message, + Severity: SeverityError, + } +} + +// NewParserErrorWithCause creates a new ParserError with a cause +func NewParserErrorWithCause(operation, message string, cause error) *ParserError { + err := NewParserError(operation, message) + err.Cause = cause + err.Stack = captureStackTrace(2) + return err +} + +// NewParserErrorAtLocation creates a new ParserError with file location +func NewParserErrorAtLocation(operation, file string, line, column int, message string) *ParserError { + return &ParserError{ + Operation: operation, + File: file, + Line: line, + Column: column, + Message: message, + Severity: SeverityError, + } +} + +// NewValidationError creates a new ValidationError +func NewValidationError(ruleName, message string) *ValidationError { + return &ValidationError{ + RuleName: ruleName, + Message: message, + Severity: SeverityError, + } +} + +// NewValidationErrorWithCause creates a new ValidationError with a cause +func NewValidationErrorWithCause(ruleName, message string, cause error) *ValidationError { + err := NewValidationError(ruleName, message) + err.Cause = cause + return err +} + +// NewRendererError creates a new RendererError +func NewRendererError(operation, message string) *RendererError { + return &RendererError{ + Operation: operation, + Message: message, + } +} + +// NewRendererErrorWithCause creates a new RendererError with a cause +func NewRendererErrorWithCause(operation, message string, cause error) *RendererError { + err := NewRendererError(operation, message) + err.Cause = cause + err.Stack = captureStackTrace(2) + return err +} + +// NewFileSystemError creates a new FileSystemError +func NewFileSystemError(operation, path, message string) *FileSystemError { + return &FileSystemError{ + Operation: operation, + Path: path, + Message: message, + } +} + +// NewFileSystemErrorFromError creates a new FileSystemError from an existing error +func NewFileSystemErrorFromError(operation, path string, err error) *FileSystemError { + return &FileSystemError{ + Operation: operation, + Path: path, + Message: err.Error(), + Cause: err, + } +} + +// NewConfigurationError creates a new ConfigurationError +func NewConfigurationError(field, message string) *ConfigurationError { + return &ConfigurationError{ + Field: field, + Message: message, + } +} + +// 
WrapError wraps an error with additional context +func WrapError(err error, operation string) error { + if err == nil { + return nil + } + + switch e := err.(type) { + case *ParserError: + return NewParserErrorWithCause(e.Operation, e.Message, err) + case *ValidationError: + return NewValidationErrorWithCause(e.RuleName, e.Message, err) + case *RendererError: + return NewRendererErrorWithCause(e.Operation, e.Message, err) + case *FileSystemError: + return NewFileSystemErrorFromError(e.Operation, e.Path, err) + case *ConfigurationError: + return NewConfigurationError(e.Field, e.Message) + default: + return fmt.Errorf("%s: %w", operation, err) + } +} + +// Error utility functions + +// IsParserError checks if an error is a ParserError +func IsParserError(err error) bool { + var target *ParserError + return errors.As(err, &target) +} + +// IsValidationError checks if an error is a ValidationError +func IsValidationError(err error) bool { + var target *ValidationError + return errors.As(err, &target) +} + +// IsRendererError checks if an error is a RendererError +func IsRendererError(err error) bool { + var target *RendererError + return errors.As(err, &target) +} + +// IsFileSystemError checks if an error is a FileSystemError +func IsFileSystemError(err error) bool { + var target *FileSystemError + return errors.As(err, &target) +} + +// IsConfigurationError checks if an error is a ConfigurationError +func IsConfigurationError(err error) bool { + var target *ConfigurationError + return errors.As(err, &target) +} + +// GetErrorCode returns the error code for a given error +func GetErrorCode(err error) string { + if err == nil { + return "" + } + + switch e := err.(type) { + case *ParserError: + return e.Code + case *RendererError: + return e.Code + case *FileSystemError: + return e.Code + case *ConfigurationError: + return e.Code + default: + return ErrCodeUnknownError + } +} + +// IsFileNotFoundError checks if an error is a file not found error +func IsFileNotFoundError(err error) bool { + if errors.Is(err, fs.ErrNotExist) { + return true + } + + if pathErr, ok := err.(*os.PathError); ok { + return errors.Is(pathErr.Err, fs.ErrNotExist) + } + + return false +} + +// IsPermissionError checks if an error is a permission error +func IsPermissionError(err error) bool { + if errors.Is(err, os.ErrPermission) { + return true + } + + if pathErr, ok := err.(*os.PathError); ok { + return errors.Is(pathErr.Err, os.ErrPermission) + } + + return false +} + +// Error recovery functions + +// RecoverFromParseError attempts to recover from parsing errors +func RecoverFromParseError(err error) error { + if err == nil { + return nil + } + + // If it's a file system error, provide more helpful message + if IsFileSystemError(err) { + var fsErr *FileSystemError + if errors.As(err, &fsErr) { + if IsFileNotFoundError(fsErr.Cause) { + return NewParserErrorWithCause( + "parse", + fmt.Sprintf("file not found: %s", fsErr.Path), + err, + ) + } + if IsPermissionError(fsErr.Cause) { + return NewParserErrorWithCause( + "parse", + fmt.Sprintf("permission denied: %s", fsErr.Path), + err, + ) + } + } + } + + // For syntax errors, try to provide location information + if strings.Contains(err.Error(), "syntax") { + return NewParserErrorWithCause("parse", "syntax error in source code", err) + } + + // Default: wrap the error with parsing context + return WrapError(err, "parse") +} + +// captureStackTrace captures the current stack trace +func captureStackTrace(skip int) string { + const depth = 32 + var pcs [depth]uintptr + n := 
runtime.Callers(skip, pcs[:]) + frames := runtime.CallersFrames(pcs[:n]) + + var builder strings.Builder + for { + frame, more := frames.Next() + if !more || strings.Contains(frame.Function, "runtime.") { + break + } + fmt.Fprintf(&builder, "%s\n\t%s:%d\n", frame.Function, frame.File, frame.Line) + } + + return builder.String() +} + +// Error aggregation + +// ErrorGroup represents a group of related errors +type ErrorGroup struct { + Errors []error `json:"errors"` +} + +// Error implements the error interface +func (eg *ErrorGroup) Error() string { + if len(eg.Errors) == 0 { + return "no errors" + } + + if len(eg.Errors) == 1 { + return eg.Errors[0].Error() + } + + var messages []string + for _, err := range eg.Errors { + messages = append(messages, err.Error()) + } + return fmt.Sprintf("multiple errors occurred:\n\t%s", strings.Join(messages, "\n\t")) +} + +// Unwrap implements the error unwrapping interface +func (eg *ErrorGroup) Unwrap() []error { + return eg.Errors +} + +// NewErrorGroup creates a new ErrorGroup +func NewErrorGroup(errors ...error) *ErrorGroup { + var filteredErrors []error + for _, err := range errors { + if err != nil { + filteredErrors = append(filteredErrors, err) + } + } + return &ErrorGroup{Errors: filteredErrors} +} + +// CollectErrors collects non-nil errors from a slice +func CollectErrors(errors ...error) *ErrorGroup { + return NewErrorGroup(errors...) +} diff --git a/pkg/ast/provider/import_resolver.go b/pkg/ast/provider/import_resolver.go new file mode 100644 index 0000000..1a9297d --- /dev/null +++ b/pkg/ast/provider/import_resolver.go @@ -0,0 +1,390 @@ +package provider + +import ( + "fmt" + "go/ast" + "math/rand" + "path/filepath" + "strings" + + "go.ipao.vip/atomctl/v2/pkg/utils/gomod" +) + +// ImportResolver handles resolution of Go imports and package aliases +type ImportResolver struct { + resolverConfig *ResolverConfig + cache map[string]*ImportResolution +} + +// ResolverConfig configures the import resolver behavior +type ResolverConfig struct { + EnableCache bool + StrictMode bool + DefaultAliasStrategy AliasStrategy + AnonymousImportHandling AnonymousImportPolicy +} + +// AliasStrategy defines how to generate default aliases +type AliasStrategy int + +const ( + AliasStrategyModuleName AliasStrategy = iota + AliasStrategyLastPath + AliasStrategyCustom +) + +// AnonymousImportPolicy defines how to handle anonymous imports +type AnonymousImportPolicy int + +const ( + AnonymousImportSkip AnonymousImportPolicy = iota + AnonymousImportUseModuleName + AnonymousImportGenerateUnique +) + +// ImportResolution represents a resolved import +type ImportResolution struct { + Path string // Import path + Alias string // Package alias + IsAnonymous bool // Is this an anonymous import (_) + IsValid bool // Is the import valid + Error string // Error message if invalid + PackageName string // Actual package name + Dependencies map[string]string // Dependencies of this import +} + +// ImportContext maintains context for import resolution +type ImportContext struct { + FileImports map[string]*ImportResolution // Alias -> Resolution + ImportPaths map[string]string // Path -> Alias + ModuleInfo map[string]string // Module path -> module name + WorkingDir string // Current working directory + ModuleName string // Current module name + ProcessedFiles map[string]bool // Track processed files +} + +// NewImportResolver creates a new ImportResolver +func NewImportResolver() *ImportResolver { + return &ImportResolver{ + resolverConfig: &ResolverConfig{ + EnableCache: 
true, + StrictMode: false, + DefaultAliasStrategy: AliasStrategyModuleName, + AnonymousImportHandling: AnonymousImportUseModuleName, + }, + cache: make(map[string]*ImportResolution), + } +} + +// NewImportResolverWithConfig creates a new ImportResolver with custom configuration +func NewImportResolverWithConfig(config *ResolverConfig) *ImportResolver { + if config == nil { + return NewImportResolver() + } + + return &ImportResolver{ + resolverConfig: config, + cache: make(map[string]*ImportResolution), + } +} + +// ResolveFileImports resolves all imports for a given AST file +func (ir *ImportResolver) ResolveFileImports(file *ast.File, filePath string) (*ImportContext, error) { + context := &ImportContext{ + FileImports: make(map[string]*ImportResolution), + ImportPaths: make(map[string]string), + ModuleInfo: make(map[string]string), + WorkingDir: filepath.Dir(filePath), + ProcessedFiles: make(map[string]bool), + } + + // Resolve current module name + moduleName := gomod.GetModuleName() + context.ModuleName = moduleName + + // Process imports + for _, imp := range file.Imports { + resolution, err := ir.resolveImportSpec(imp, context) + if err != nil { + if ir.resolverConfig.StrictMode { + return nil, err + } + // In non-strict mode, continue with other imports + continue + } + + if resolution != nil { + context.FileImports[resolution.Alias] = resolution + context.ImportPaths[resolution.Path] = resolution.Alias + } + } + + return context, nil +} + +// resolveImportSpec resolves a single import specification +func (ir *ImportResolver) resolveImportSpec(imp *ast.ImportSpec, context *ImportContext) (*ImportResolution, error) { + // Extract import path + path := strings.Trim(imp.Path.Value, "\"") + if path == "" { + return nil, fmt.Errorf("empty import path") + } + + // Check cache first + if ir.resolverConfig.EnableCache { + if cached, found := ir.cache[path]; found { + return cached, nil + } + } + + // Determine alias + alias := ir.determineAlias(imp, path, context) + + // Resolve package name + packageName, err := ir.resolvePackageName(path, context) + if err != nil { + resolution := &ImportResolution{ + Path: path, + Alias: alias, + IsAnonymous: imp.Name != nil && imp.Name.Name == "_", + IsValid: false, + Error: err.Error(), + PackageName: "", + } + + if ir.resolverConfig.EnableCache { + ir.cache[path] = resolution + } + + return resolution, err + } + + // Create resolution + resolution := &ImportResolution{ + Path: path, + Alias: alias, + IsAnonymous: imp.Name != nil && imp.Name.Name == "_", + IsValid: true, + PackageName: packageName, + Dependencies: make(map[string]string), + } + + // Resolve dependencies if needed + if err := ir.resolveDependencies(resolution, context); err != nil { + resolution.IsValid = false + resolution.Error = err.Error() + } + + // Cache the result + if ir.resolverConfig.EnableCache { + ir.cache[path] = resolution + } + + return resolution, nil +} + +// determineAlias determines the appropriate alias for an import +func (ir *ImportResolver) determineAlias(imp *ast.ImportSpec, path string, context *ImportContext) string { + // If explicit alias is provided, use it + if imp.Name != nil { + if imp.Name.Name == "_" { + // Handle anonymous import based on policy + return ir.handleAnonymousImport(path, context) + } + return imp.Name.Name + } + + // Generate default alias based on strategy + switch ir.resolverConfig.DefaultAliasStrategy { + case AliasStrategyModuleName: + return gomod.GetPackageModuleName(path) + case AliasStrategyLastPath: + return 
ir.getLastPathComponent(path) + case AliasStrategyCustom: + return ir.generateCustomAlias(path, context) + default: + return gomod.GetPackageModuleName(path) + } +} + +// handleAnonymousImport handles anonymous imports based on policy +func (ir *ImportResolver) handleAnonymousImport(path string, context *ImportContext) string { + switch ir.resolverConfig.AnonymousImportHandling { + case AnonymousImportSkip: + return "_" + case AnonymousImportUseModuleName: + alias := gomod.GetPackageModuleName(path) + // Check for conflicts + if _, exists := context.FileImports[alias]; exists { + return ir.generateUniqueAlias(alias, context) + } + return alias + case AnonymousImportGenerateUnique: + baseAlias := gomod.GetPackageModuleName(path) + return ir.generateUniqueAlias(baseAlias, context) + default: + return "_" + } +} + +// resolvePackageName resolves the actual package name for an import path +func (ir *ImportResolver) resolvePackageName(path string, context *ImportContext) (string, error) { + // Handle standard library packages + if !strings.Contains(path, ".") { + // For standard library, the package name is typically the last component + return ir.getLastPathComponent(path), nil + } + + // Handle third-party packages + packageName := gomod.GetPackageModuleName(path) + if packageName == "" { + return "", fmt.Errorf("could not resolve package name for %s", path) + } + + return packageName, nil +} + +// resolveDependencies resolves dependencies for an import +func (ir *ImportResolver) resolveDependencies(resolution *ImportResolution, context *ImportContext) error { + // This is a placeholder for dependency resolution + // In a more sophisticated implementation, this could: + // - Parse the imported package to find its dependencies + // - Check for version conflicts + // - Validate import compatibility + + // For now, we'll just note that third-party packages might have dependencies + if strings.Contains(resolution.Path, ".") { + // Add some common dependencies as examples + // This could be made configurable + } + return nil +} + +// GetAlias returns the alias for a given import path +func (ir *ImportResolver) GetAlias(path string, context *ImportContext) (string, bool) { + alias, exists := context.ImportPaths[path] + return alias, exists +} + +// GetPath returns the import path for a given alias +func (ir *ImportResolver) GetPath(alias string, context *ImportContext) (string, bool) { + if resolution, exists := context.FileImports[alias]; exists { + return resolution.Path, true + } + return "", false +} + +// GetPackageName returns the package name for a given alias or path +func (ir *ImportResolver) GetPackageName(identifier string, context *ImportContext) (string, bool) { + // First try as alias + if resolution, exists := context.FileImports[identifier]; exists { + return resolution.PackageName, true + } + + // Then try as path + if alias, exists := context.ImportPaths[identifier]; exists { + if resolution, resExists := context.FileImports[alias]; resExists { + return resolution.PackageName, true + } + } + + return "", false +} + +// IsValidImport checks if an import path is valid +func (ir *ImportResolver) IsValidImport(path string) bool { + // Basic validation + if path == "" { + return false + } + + // Check for invalid characters + if strings.ContainsAny(path, " \t\n\r\"'") { + return false + } + + // TODO: Add more sophisticated validation + return true +} + +// GetImportPathFromType extracts the import path from a qualified type name +func (ir *ImportResolver) 
GetImportPathFromType(typeName string, context *ImportContext) (string, bool) { + if !strings.Contains(typeName, ".") { + return "", false + } + + alias := strings.Split(typeName, ".")[0] + path, exists := ir.GetPath(alias, context) + return path, exists +} + +// Helper methods + +func (ir *ImportResolver) getLastPathComponent(path string) string { + parts := strings.Split(path, "/") + if len(parts) == 0 { + return "" + } + return parts[len(parts)-1] +} + +func (ir *ImportResolver) generateCustomAlias(path string, context *ImportContext) string { + // Generate a meaningful alias based on the path + parts := strings.Split(path, "/") + if len(parts) == 0 { + return "unknown" + } + + // Use the last few parts to create a meaningful alias + start := 0 + if len(parts) > 2 { + start = len(parts) - 2 + } + + aliasParts := parts[start:] + for i, part := range aliasParts { + aliasParts[i] = strings.ToLower(part) + } + + return strings.Join(aliasParts, "") +} + +func (ir *ImportResolver) generateUniqueAlias(baseAlias string, context *ImportContext) string { + // Check if base alias is available + if _, exists := context.FileImports[baseAlias]; !exists { + return baseAlias + } + + // Generate unique alias by adding suffix + for i := 1; i < 1000; i++ { + candidate := fmt.Sprintf("%s%d", baseAlias, i) + if _, exists := context.FileImports[candidate]; !exists { + return candidate + } + } + + // Fallback to random suffix + return fmt.Sprintf("%s%d", baseAlias, rand.Intn(10000)) +} + +// ClearCache clears the import resolution cache +func (ir *ImportResolver) ClearCache() { + ir.cache = make(map[string]*ImportResolution) +} + +// GetCacheSize returns the number of cached resolutions +func (ir *ImportResolver) GetCacheSize() int { + return len(ir.cache) +} + +// GetConfig returns the resolver configuration +func (ir *ImportResolver) GetConfig() *ResolverConfig { + return ir.resolverConfig +} + +// SetConfig updates the resolver configuration +func (ir *ImportResolver) SetConfig(config *ResolverConfig) { + ir.resolverConfig = config + // Clear cache when config changes + ir.ClearCache() +} diff --git a/pkg/ast/provider/modes.go b/pkg/ast/provider/modes.go new file mode 100644 index 0000000..89aa07a --- /dev/null +++ b/pkg/ast/provider/modes.go @@ -0,0 +1,59 @@ +package provider + +// ProviderMode represents the mode of a provider +type ProviderMode string + +const ( + // ProviderModeBasic is the default provider mode + ProviderModeBasic ProviderMode = "basic" + + // ProviderModeGrpc is for gRPC service providers + ProviderModeGrpc ProviderMode = "grpc" + + // ProviderModeEvent is for event-based providers + ProviderModeEvent ProviderMode = "event" + + // ProviderModeJob is for job-based providers + ProviderModeJob ProviderMode = "job" + + // ProviderModeCronJob is for cron job providers + ProviderModeCronJob ProviderMode = "cronjob" + + // ProviderModeModel is for model-based providers + ProviderModeModel ProviderMode = "model" +) + +// IsValidProviderMode checks if a provider mode is valid +func IsValidProviderMode(mode string) bool { + switch ProviderMode(mode) { + case ProviderModeBasic, ProviderModeGrpc, ProviderModeEvent, + ProviderModeJob, ProviderModeCronJob, ProviderModeModel: + return true + default: + return false + } +} + +// InjectionMode represents the injection mode for provider fields +type InjectionMode string + +const ( + // InjectionModeOnly injects only fields marked with inject:"true" + InjectionModeOnly InjectionMode = "only" + + // InjectionModeExcept injects all fields except those 
marked with inject:"false" + InjectionModeExcept InjectionMode = "except" + + // InjectionModeAuto injects all non-scalar fields automatically + InjectionModeAuto InjectionMode = "auto" +) + +// IsValidInjectionMode checks if an injection mode is valid +func IsValidInjectionMode(mode string) bool { + switch InjectionMode(mode) { + case InjectionModeOnly, InjectionModeExcept, InjectionModeAuto: + return true + default: + return false + } +} diff --git a/pkg/ast/provider/parser.go b/pkg/ast/provider/parser.go new file mode 100644 index 0000000..c703881 --- /dev/null +++ b/pkg/ast/provider/parser.go @@ -0,0 +1,388 @@ +package provider + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + "go.ipao.vip/atomctl/v2/pkg/utils/gomod" +) + +// MainParser represents the main parser that uses extracted components +type MainParser struct { + commentParser *CommentParser + importResolver *ImportResolver + astWalker *ASTWalker + builder *ProviderBuilder + validator *GoValidator + config *ParserConfig +} + +// NewParser creates a new MainParser with default configuration +func NewParser() *MainParser { + return &MainParser{ + commentParser: NewCommentParser(), + importResolver: NewImportResolver(), + astWalker: NewASTWalker(), + builder: NewProviderBuilder(), + validator: NewGoValidator(), + config: NewParserConfig(), + } +} + +// NewParserWithConfig creates a new MainParser with custom configuration +func NewParserWithConfig(config *ParserConfig) *MainParser { + if config == nil { + return NewParser() + } + + return &MainParser{ + commentParser: NewCommentParser(), + importResolver: NewImportResolver(), + astWalker: NewASTWalkerWithConfig(&WalkerConfig{ + StrictMode: config.StrictMode, + }), + builder: NewProviderBuilderWithConfig(&BuilderConfig{ + EnableValidation: config.StrictMode, + StrictMode: config.StrictMode, + DefaultProviderMode: ProviderModeBasic, + DefaultInjectionMode: InjectionModeAuto, + AutoGenerateReturnTypes: true, + ResolveImportDependencies: true, + }), + validator: NewGoValidator(), + config: config, + } +} + +// Parse parses a Go source file and returns discovered providers +// This is the refactored version of the original Parse function +func ParseRefactored(source string) []Provider { + parser := NewParser() + providers, err := parser.ParseFile(source) + if err != nil { + log.Error("Parse error: ", err) + return []Provider{} + } + return providers +} + +// ParseFile parses a single Go source file and returns discovered providers +func (p *MainParser) ParseFile(source string) ([]Provider, error) { + // Check if file should be processed + if !p.shouldProcessFile(source) { + return []Provider{}, nil + } + + // Parse the AST + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, source, nil, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("failed to parse file %s: %w", source, err) + } + + // Create parser context + context := NewParserContext(p.config) + context.WorkingDir = filepath.Dir(source) + context.ModuleName = gomod.GetModuleName() + + // Resolve imports + importContext, err := p.importResolver.ResolveFileImports(node, source) + if err != nil { + return nil, fmt.Errorf("failed to resolve imports: %w", err) + } + + // Create builder context + builderContext := &BuilderContext{ + FilePath: source, + PackageName: node.Name.Name, + ImportContext: importContext, + ASTFile: node, + ProcessedTypes: make(map[string]bool), + Errors: make([]error, 0), + Warnings: make([]string, 0), + 
} + + // Use AST walker to find provider annotations + visitor := NewProviderDiscoveryVisitor(p.commentParser) + p.astWalker.AddVisitor(visitor) + + // Walk the AST + if err := p.astWalker.WalkFile(source); err != nil { + return nil, fmt.Errorf("failed to walk AST: %w", err) + } + + // Build providers from discovered annotations + providers := make([]Provider, 0) + discoveredProviders := visitor.GetProviders() + + for _, discoveredProvider := range discoveredProviders { + // Find the corresponding AST node for this provider + provider, err := p.buildProviderFromDiscovery(discoveredProvider, node, builderContext) + if err != nil { + context.AddError(source, 0, 0, fmt.Sprintf("failed to build provider %s: %v", discoveredProvider.StructName, err), "error") + continue + } + + // Validate the provider if enabled + if p.config.StrictMode { + if err := p.validator.Validate(&provider); err != nil { + context.AddError(source, 0, 0, fmt.Sprintf("validation failed for provider %s: %v", provider.StructName, err), "error") + continue + } + } + + providers = append(providers, provider) + } + + // Log any warnings or errors + for _, parseErr := range context.GetErrors("warning") { + log.Warnf("Warning while parsing %s: %s", source, parseErr.Message) + } + for _, parseErr := range context.GetErrors("error") { + log.Errorf("Error while parsing %s: %s", source, parseErr.Message) + } + + return providers, nil +} + +// ParseDir parses all Go files in a directory and returns discovered providers +func (p *MainParser) ParseDir(dir string) ([]Provider, error) { + var allProviders []Provider + + // Use AST walker to traverse the directory + if err := p.astWalker.WalkDir(dir); err != nil { + return nil, fmt.Errorf("failed to walk directory: %w", err) + } + + // Note: This would need to be enhanced to collect providers from all files + // For now, we'll return an empty slice and log a warning + log.Warn("ParseDir not fully implemented yet") + return allProviders, nil +} + +// shouldProcessFile determines if a file should be processed +func (p *MainParser) shouldProcessFile(source string) bool { + // Skip test files + if strings.HasSuffix(source, "_test.go") { + return false + } + + // Skip generated provider files + if strings.HasSuffix(source, "provider.gen.go") { + return false + } + + return true +} + +// buildProviderFromDiscovery builds a complete Provider from a discovered provider annotation +func (p *MainParser) buildProviderFromDiscovery(discoveredProvider Provider, node *ast.File, context *BuilderContext) (Provider, error) { + // Find the corresponding type specification in the AST + var typeSpec *ast.TypeSpec + var genDecl *ast.GenDecl + + for _, decl := range node.Decls { + gd, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + + if len(gd.Specs) == 0 { + continue + } + + ts, ok := gd.Specs[0].(*ast.TypeSpec) + if !ok { + continue + } + + if ts.Name.Name == discoveredProvider.StructName { + typeSpec = ts + genDecl = gd + break + } + } + + if typeSpec == nil { + return Provider{}, fmt.Errorf("type specification not found for %s", discoveredProvider.StructName) + } + + // Use the builder to construct the complete provider + provider, err := p.builder.BuildFromTypeSpec(typeSpec, genDecl, context) + if err != nil { + return Provider{}, fmt.Errorf("failed to build provider: %w", err) + } + + // Apply legacy compatibility transformations + result, err := p.applyLegacyCompatibility(&provider) + if err != nil { + return Provider{}, fmt.Errorf("failed to apply legacy compatibility: %w", err) + } + + return 
result, nil +} + +// applyLegacyCompatibility applies transformations to maintain backward compatibility +func (p *MainParser) applyLegacyCompatibility(provider *Provider) (Provider, error) { + // Set provider file path + provider.ProviderFile = filepath.Join(filepath.Dir(provider.ProviderFile), "provider.gen.go") + + // Apply mode-specific transformations based on the original logic + switch provider.Mode { + case ProviderModeGrpc: + p.applyGrpcCompatibility(provider) + case ProviderModeEvent: + p.applyEventCompatibility(provider) + case ProviderModeJob, ProviderModeCronJob: + p.applyJobCompatibility(provider) + case ProviderModeModel: + p.applyModelCompatibility(provider) + } + + return *provider, nil +} + +// applyGrpcCompatibility applies gRPC-specific compatibility transformations +func (p *MainParser) applyGrpcCompatibility(provider *Provider) { + modePkg := gomod.GetModuleName() + "/providers/grpc" + + // Add required imports + provider.Imports[createAtomPackage("")] = "" + provider.Imports[createAtomPackage("contracts")] = "" + provider.Imports[modePkg] = "" + + // Set provider group + if provider.ProviderGroup == "" { + provider.ProviderGroup = "atom.GroupInitial" + } + + // Set return type and register function + if provider.GrpcRegisterFunc == "" { + provider.GrpcRegisterFunc = provider.ReturnType + } + provider.ReturnType = "contracts.Initial" + + // Add gRPC injection parameter + provider.InjectParams["__grpc"] = InjectParam{ + Star: "*", + Type: "Grpc", + Package: modePkg, + PackageAlias: "grpc", + } +} + +// applyEventCompatibility applies event-specific compatibility transformations +func (p *MainParser) applyEventCompatibility(provider *Provider) { + modePkg := gomod.GetModuleName() + "/providers/event" + + // Add required imports + provider.Imports[createAtomPackage("")] = "" + provider.Imports[createAtomPackage("contracts")] = "" + provider.Imports[modePkg] = "" + + // Set provider group + if provider.ProviderGroup == "" { + provider.ProviderGroup = "atom.GroupInitial" + } + + // Set return type + provider.ReturnType = "contracts.Initial" + + // Add event injection parameter + provider.InjectParams["__event"] = InjectParam{ + Star: "*", + Type: "PubSub", + Package: modePkg, + PackageAlias: "event", + } +} + +// applyJobCompatibility applies job-specific compatibility transformations +func (p *MainParser) applyJobCompatibility(provider *Provider) { + modePkg := gomod.GetModuleName() + "/providers/job" + + // Add required imports + provider.Imports[createAtomPackage("")] = "" + provider.Imports[createAtomPackage("contracts")] = "" + provider.Imports["github.com/riverqueue/river"] = "" + provider.Imports[modePkg] = "" + + // Set provider group + if provider.ProviderGroup == "" { + provider.ProviderGroup = "atom.GroupInitial" + } + + // Set return type + provider.ReturnType = "contracts.Initial" + + // Add job injection parameter + provider.InjectParams["__job"] = InjectParam{ + Star: "*", + Type: "Job", + Package: modePkg, + PackageAlias: "job", + } +} + +// applyModelCompatibility applies model-specific compatibility transformations +func (p *MainParser) applyModelCompatibility(provider *Provider) { + // Set provider group + if provider.ProviderGroup == "" { + provider.ProviderGroup = "atom.GroupInitial" + } + + // Set return type + provider.ReturnType = "contracts.Initial" + + // Ensure prepare function is needed + provider.NeedPrepareFunc = true +} + +// GetCommentParser returns the comment parser used by this parser +func (p *MainParser) GetCommentParser() 
*CommentParser { + return p.commentParser +} + +// GetImportResolver returns the import resolver used by this parser +func (p *MainParser) GetImportResolver() *ImportResolver { + return p.importResolver +} + +// GetASTWalker returns the AST walker used by this parser +func (p *MainParser) GetASTWalker() *ASTWalker { + return p.astWalker +} + +// GetBuilder returns the provider builder used by this parser +func (p *MainParser) GetBuilder() *ProviderBuilder { + return p.builder +} + +// GetValidator returns the validator used by this parser +func (p *MainParser) GetValidator() *GoValidator { + return p.validator +} + +// GetConfig returns the parser configuration +func (p *MainParser) GetConfig() *ParserConfig { + return p.config +} + +// SetConfig updates the parser configuration +func (p *MainParser) SetConfig(config *ParserConfig) { + p.config = config +} + +// Helper function to create atom package paths +func createAtomPackage(suffix string) string { + root := "go.ipao.vip/atom" + if suffix != "" { + return fmt.Sprintf("%s/%s", root, suffix) + } + return root +} diff --git a/pkg/ast/provider/parser_interface.go b/pkg/ast/provider/parser_interface.go new file mode 100644 index 0000000..2dc37a5 --- /dev/null +++ b/pkg/ast/provider/parser_interface.go @@ -0,0 +1,493 @@ +package provider + +import ( + "errors" + "fmt" + "go/ast" + "go/parser" + "go/token" + "math/rand" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/samber/lo" + log "github.com/sirupsen/logrus" + "go.ipao.vip/atomctl/v2/pkg/utils/gomod" +) + +// Parser defines the interface for parsing provider annotations +type Parser interface { + // ParseFile parses a single Go file and returns providers found + ParseFile(filePath string) ([]Provider, error) + + // ParseDir parses all Go files in a directory and returns providers found + ParseDir(dirPath string) ([]Provider, error) + + // ParseString parses Go code from a string and returns providers found + ParseString(code string) ([]Provider, error) + + // SetConfig sets the parser configuration + SetConfig(config *ParserConfig) + + // GetConfig returns the current parser configuration + GetConfig() *ParserConfig + + // GetContext returns the current parser context + GetContext() *ParserContext +} + +// GoParser implements the Parser interface for Go source files +type GoParser struct { + config *ParserConfig + context *ParserContext + mu sync.RWMutex +} + +// NewGoParser creates a new GoParser with default configuration +func NewGoParser() *GoParser { + config := NewParserConfig() + context := NewParserContext(config) + + // Initialize file set if not provided + if config.FileSet == nil { + config.FileSet = token.NewFileSet() + context.FileSet = config.FileSet + } + + return &GoParser{ + config: config, + context: context, + } +} + +// NewGoParserWithConfig creates a new GoParser with custom configuration +func NewGoParserWithConfig(config *ParserConfig) *GoParser { + if config == nil { + return NewGoParser() + } + + context := NewParserContext(config) + + // Initialize file set if not provided + if config.FileSet == nil { + config.FileSet = token.NewFileSet() + context.FileSet = config.FileSet + } + + return &GoParser{ + config: config, + context: context, + } +} + +// ParseFile implements Parser.ParseFile +func (p *GoParser) ParseFile(filePath string) ([]Provider, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + // Check if file should be included + if !p.context.ShouldIncludeFile(filePath) { + p.context.FilesSkipped++ + return []Provider{}, nil + } + + // Check 
cache if enabled + if p.config.CacheEnabled { + if cached, found := p.context.Cache[filePath]; found { + if providers, ok := cached.([]Provider); ok { + return providers, nil + } + } + } + + // Parse the file + node, err := parser.ParseFile(p.context.FileSet, filePath, nil, p.config.Mode) + if err != nil { + p.context.AddError(filePath, 0, 0, err.Error(), "error") + return nil, err + } + + // Parse providers from the file + providers, err := p.parseFileContent(filePath, node) + if err != nil { + return nil, err + } + + // Cache the result + if p.config.CacheEnabled { + p.context.Cache[filePath] = providers + } + + p.context.FilesProcessed++ + p.context.ProvidersFound += len(providers) + + return providers, nil +} + +// ParseDir implements Parser.ParseDir +func (p *GoParser) ParseDir(dirPath string) ([]Provider, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + var allProviders []Provider + + // Walk through directory + err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Skip directories + if info.IsDir() { + // Skip hidden directories and common build/dependency directories + if strings.HasPrefix(info.Name(), ".") || + info.Name() == "node_modules" || + info.Name() == "vendor" || + info.Name() == "testdata" { + return filepath.SkipDir + } + return nil + } + + // Parse Go files + if filepath.Ext(path) == ".go" && p.context.ShouldIncludeFile(path) { + providers, err := p.ParseFile(path) + if err != nil { + log.Warnf("Failed to parse file %s: %v", path, err) + // Continue with other files + return nil + } + allProviders = append(allProviders, providers...) + } + + return nil + }) + if err != nil { + return nil, err + } + + return allProviders, nil +} + +// ParseString implements Parser.ParseString +func (p *GoParser) ParseString(code string) ([]Provider, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + // Parse the code string + node, err := parser.ParseFile(p.context.FileSet, "", strings.NewReader(code), p.config.Mode) + if err != nil { + return nil, err + } + + // Parse providers from the AST + return p.parseFileContent("", node) +} + +// SetConfig implements Parser.SetConfig +func (p *GoParser) SetConfig(config *ParserConfig) { + p.mu.Lock() + defer p.mu.Unlock() + + p.config = config + p.context = NewParserContext(config) +} + +// GetConfig implements Parser.GetConfig +func (p *GoParser) GetConfig() *ParserConfig { + p.mu.RLock() + defer p.mu.RUnlock() + + return p.config +} + +// GetContext implements Parser.GetContext +func (p *GoParser) GetContext() *ParserContext { + p.mu.RLock() + defer p.mu.RUnlock() + + return p.context +} + +// parseFileContent parses providers from a parsed AST node +func (p *GoParser) parseFileContent(filePath string, node *ast.File) ([]Provider, error) { + // Extract imports + imports := make(map[string]string) + for _, imp := range node.Imports { + name := "" + pkgPath := strings.Trim(imp.Path.Value, "\"") + + if imp.Name != nil { + name = imp.Name.Name + } else { + name = gomod.GetPackageModuleName(pkgPath) + } + + // Handle anonymous imports + if name == "_" { + name = gomod.GetPackageModuleName(pkgPath) + // Handle duplicates + if _, ok := imports[name]; ok { + name = fmt.Sprintf("%s%d", name, rand.Intn(100)) + } + } + + imports[name] = pkgPath + p.context.AddImport(name, pkgPath) + } + + var providers []Provider + + // Parse providers from declarations + for _, decl := range node.Decls { + provider, err := p.parseProviderDecl(filePath, node, decl, imports) + if err != nil { + 
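+		// Record the syntax failure on the parser context (line/column unknown here, severity "error") so callers can inspect ParserContext.ParseErrors after the run.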
p.context.AddError(filePath, 0, 0, err.Error(), "warning") + continue + } + + if provider != nil { + providers = append(providers, *provider) + } + } + + return providers, nil +} + +// parseProviderDecl parses a provider from an AST declaration +func (p *GoParser) parseProviderDecl(filePath string, fileNode *ast.File, decl ast.Decl, imports map[string]string) (*Provider, error) { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + return nil, nil + } + + if len(genDecl.Specs) == 0 { + return nil, nil + } + + typeSpec, ok := genDecl.Specs[0].(*ast.TypeSpec) + if !ok { + return nil, nil + } + + // Check if it's a struct type + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + return nil, nil + } + + // Check for provider annotation + if genDecl.Doc == nil || len(genDecl.Doc.List) == 0 { + return nil, nil + } + + docText := strings.TrimLeft(genDecl.Doc.List[len(genDecl.Doc.List)-1].Text, "/ \t") + if !strings.HasPrefix(docText, "@provider") { + return nil, nil + } + + // Parse provider annotation + providerDoc := parseProvider(docText) + + // Create provider struct + provider := &Provider{ + StructName: typeSpec.Name.Name, + ReturnType: providerDoc.ReturnType, + Mode: ProviderModeBasic, + ProviderGroup: providerDoc.Group, + InjectParams: make(map[string]InjectParam), + Imports: make(map[string]string), + PkgName: fileNode.Name.Name, + ProviderFile: filepath.Join(filepath.Dir(filePath), "provider.gen.go"), + } + + // Set default return type if not specified + if provider.ReturnType == "" { + provider.ReturnType = "*" + provider.StructName + } + + // Parse provider mode + if providerDoc.Mode != "" { + if IsValidProviderMode(providerDoc.Mode) { + provider.Mode = ProviderMode(providerDoc.Mode) + } else { + return nil, fmt.Errorf("invalid provider mode: %s", providerDoc.Mode) + } + } + + // Parse struct fields for injection + if err := p.parseStructFields(structType, imports, provider, providerDoc.IsOnly); err != nil { + return nil, err + } + + // Handle special provider modes + p.handleProviderModes(provider, providerDoc.Mode) + + // Add source location if enabled + if p.config.SourceLocations { + if genDecl.Doc != nil && len(genDecl.Doc.List) > 0 { + position := p.context.FileSet.Position(genDecl.Doc.List[0].Pos()) + provider.Location = SourceLocation{ + File: position.Filename, + Line: position.Line, + Column: position.Column, + } + } + } + + return provider, nil +} + +// parseStructFields parses struct fields for injection parameters +func (p *GoParser) parseStructFields(structType *ast.StructType, imports map[string]string, provider *Provider, onlyMode bool) error { + for _, field := range structType.Fields.List { + if field.Names == nil { + continue + } + + // Check for struct tags + if field.Tag != nil { + provider.NeedPrepareFunc = true + } + + // Check injection mode + shouldInject := true + if onlyMode { + shouldInject = field.Tag != nil && strings.Contains(field.Tag.Value, `inject:"true"`) + } else { + shouldInject = field.Tag == nil || !strings.Contains(field.Tag.Value, `inject:"false"`) + } + + if !shouldInject { + continue + } + + // Parse field type + star, pkg, pkgAlias, typ, err := p.parseFieldType(field.Type, imports) + if err != nil { + continue + } + + // Skip scalar types + if lo.Contains(scalarTypes, typ) { + continue + } + + // Add injection parameter + for _, name := range field.Names { + provider.InjectParams[name.Name] = InjectParam{ + Star: star, + Type: typ, + Package: pkg, + PackageAlias: pkgAlias, + } + + // Add to imports + if pkg != "" && pkgAlias != "" { 
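+				// Imports is keyed by import path; the value is the alias this file uses for the package.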
+ provider.Imports[pkg] = pkgAlias + } + } + } + + return nil +} + +// parseFieldType parses a field type and returns its components +func (p *GoParser) parseFieldType(expr ast.Expr, imports map[string]string) (star, pkg, pkgAlias, typ string, err error) { + switch t := expr.(type) { + case *ast.Ident: + typ = t.Name + case *ast.StarExpr: + star = "*" + return p.parseFieldType(t.X, imports) + case *ast.SelectorExpr: + if x, ok := t.X.(*ast.Ident); ok { + pkgAlias = x.Name + if path, ok := imports[pkgAlias]; ok { + pkg = path + } + typ = t.Sel.Name + } + default: + return "", "", "", "", errors.New("unsupported field type") + } + + return star, pkg, pkgAlias, typ, nil +} + +// handleProviderModes applies special handling for different provider modes +func (p *GoParser) handleProviderModes(provider *Provider, mode string) { + moduleName := gomod.GetModuleName() + + switch provider.Mode { + case ProviderModeGrpc: + modePkg := moduleName + "/providers/grpc" + provider.ProviderGroup = "atom.GroupInitial" + provider.GrpcRegisterFunc = provider.ReturnType + provider.ReturnType = "contracts.Initial" + + provider.Imports[atomPackage("")] = "" + provider.Imports[atomPackage("contracts")] = "" + provider.Imports[modePkg] = "" + + provider.InjectParams["__grpc"] = InjectParam{ + Star: "*", + Type: "Grpc", + Package: modePkg, + PackageAlias: "grpc", + } + + case ProviderModeEvent: + modePkg := moduleName + "/providers/event" + provider.ProviderGroup = "atom.GroupInitial" + provider.ReturnType = "contracts.Initial" + + provider.Imports[atomPackage("")] = "" + provider.Imports[atomPackage("contracts")] = "" + provider.Imports[modePkg] = "" + + provider.InjectParams["__event"] = InjectParam{ + Star: "*", + Type: "PubSub", + Package: modePkg, + PackageAlias: "event", + } + + case ProviderModeJob, ProviderModeCronJob: + modePkg := moduleName + "/providers/job" + provider.ProviderGroup = "atom.GroupInitial" + provider.ReturnType = "contracts.Initial" + + provider.Imports[atomPackage("")] = "" + provider.Imports[atomPackage("contracts")] = "" + provider.Imports["github.com/riverqueue/river"] = "" + provider.Imports[modePkg] = "" + + provider.InjectParams["__job"] = InjectParam{ + Star: "*", + Type: "Job", + Package: modePkg, + PackageAlias: "job", + } + + case ProviderModeModel: + provider.ProviderGroup = "atom.GroupInitial" + provider.ReturnType = "contracts.Initial" + provider.NeedPrepareFunc = true + } + + // Handle return type and group package imports + if pkgAlias := getTypePkgName(provider.ReturnType); pkgAlias != "" { + if importPkg, ok := p.context.Imports[pkgAlias]; ok { + provider.Imports[importPkg] = pkgAlias + } + } + + if pkgAlias := getTypePkgName(provider.ProviderGroup); pkgAlias != "" { + if importPkg, ok := p.context.Imports[pkgAlias]; ok { + provider.Imports[importPkg] = pkgAlias + } + } +} diff --git a/pkg/ast/provider/provider.go b/pkg/ast/provider/provider.go index aaf4a6e..a51e22e 100644 --- a/pkg/ast/provider/provider.go +++ b/pkg/ast/provider/provider.go @@ -40,25 +40,6 @@ var scalarTypes = []string{ "complex128", } -type InjectParam struct { - Star string - Type string - Package string - PackageAlias string -} -type Provider struct { - StructName string - ReturnType string - Mode string - ProviderGroup string - GrpcRegisterFunc string - NeedPrepareFunc bool - InjectParams map[string]InjectParam - Imports map[string]string - PkgName string - ProviderFile string -} - func atomPackage(suffix string) string { root := "go.ipao.vip/atom" if suffix != "" { @@ -115,6 +96,7 @@ func 
Parse(source string) []Provider { provider := Provider{ InjectParams: make(map[string]InjectParam), Imports: make(map[string]string), + Mode: ProviderModeBasic, // Default mode } decl, ok := decl.(*ast.GenDecl) @@ -260,7 +242,7 @@ func Parse(source string) []Provider { provider.ProviderFile = filepath.Join(filepath.Dir(source), "provider.gen.go") if providerDoc.Mode == "grpc" { - provider.Mode = "grpc" + provider.Mode = ProviderModeGrpc modePkg := gomod.GetModuleName() + "/providers/grpc" @@ -281,7 +263,7 @@ func Parse(source string) []Provider { } if providerDoc.Mode == "event" { - provider.Mode = "event" + provider.Mode = ProviderModeEvent modePkg := gomod.GetModuleName() + "/providers/event" @@ -300,8 +282,22 @@ func Parse(source string) []Provider { } } - if providerDoc.Mode == "job" || providerDoc.Mode == "cronjob" { - provider.Mode = providerDoc.Mode + if providerDoc.Mode == "job" { + provider.Mode = ProviderModeJob + + modePkg := gomod.GetModuleName() + "/providers/job" + + provider.Imports["github.com/riverqueue/river"] = "" + provider.Imports[modePkg] = "" + + provider.InjectParams["__job"] = InjectParam{ + Star: "*", + Type: "Job", + Package: modePkg, + PackageAlias: "job", + } + } else if providerDoc.Mode == "cronjob" { + provider.Mode = ProviderModeCronJob modePkg := gomod.GetModuleName() + "/providers/job" @@ -322,7 +318,7 @@ func Parse(source string) []Provider { } if providerDoc.Mode == "model" { - provider.Mode = "model" + provider.Mode = ProviderModeModel provider.ProviderGroup = "atom.GroupInitial" provider.ReturnType = "contracts.Initial" @@ -336,14 +332,6 @@ func Parse(source string) []Provider { return providers } -// @provider(mode):[except|only] [returnType] [group] -type ProviderDescribe struct { - IsOnly bool - Mode string // job - ReturnType string - Group string -} - func (p ProviderDescribe) String() { // log.Infof("[%s] %s => ONLY: %+v, EXCEPT: %+v, Type: %s, Group: %s", source, declType.Name.Name, onlyMode, exceptMode, provider.ReturnType, provider.ProviderGroup) } diff --git a/pkg/ast/provider/renderer.go b/pkg/ast/provider/renderer.go new file mode 100644 index 0000000..889d12e --- /dev/null +++ b/pkg/ast/provider/renderer.go @@ -0,0 +1,321 @@ +package provider + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "text/template" + "time" +) + +// Renderer defines the interface for rendering provider code +type Renderer interface { + // Render renders providers to Go code + Render(providers []Provider) ([]byte, error) + + // RenderToFile renders providers to a file + RenderToFile(providers []Provider, filePath string) error + + // RenderToWriter renders providers to an io.Writer + RenderToWriter(providers []Provider, writer io.Writer) error + + // AddTemplate adds a custom template + AddTemplate(name, content string) error + + // RemoveTemplate removes a custom template + RemoveTemplate(name string) + + // GetTemplate returns a template by name + GetTemplate(name string) (*template.Template, error) + + // SetTemplateFuncs sets custom template functions + SetTemplateFuncs(funcs template.FuncMap) +} + +// GoRenderer implements the Renderer interface for Go code generation +type GoRenderer struct { + templates map[string]*template.Template + templateFuncs template.FuncMap + outputConfig *OutputConfig + customTemplates map[string]string +} + +// OutputConfig represents configuration for output generation +type OutputConfig struct { + Header string // Header comment for generated files + PackageName string // Package name for generated code 
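// Usage sketch (editorial example, not part of this patch): rendering parsed
// providers with the GoRenderer added in this file. The output path and
// package name below are illustrative assumptions, not values taken from the
// diff.
//
//	func renderProviders(providers []provider.Provider) error {
//		cfg := provider.NewOutputConfig()
//		cfg.PackageName = "services"
//		r := provider.NewGoRendererWithConfig(cfg)
//		return r.RenderToFile(providers, "services/provider.gen.go")
//	}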
+ Imports map[string]string // Additional imports to include + GeneratedTag string // Tag to mark generated code + DateFormat string // Date format for timestamps + TemplateDir string // Directory for custom templates + IndentString string // String used for indentation + LineEnding string // Line ending style ("\n" or "\r\n") +} + +// RenderContext represents the context for rendering +type RenderContext struct { + Providers []Provider + Config *OutputConfig + Timestamp time.Time + PackageName string + Imports map[string]string + CustomData map[string]interface{} +} + +// NewGoRenderer creates a new GoRenderer with default configuration +func NewGoRenderer() *GoRenderer { + return &GoRenderer{ + templates: make(map[string]*template.Template), + templateFuncs: defaultTemplateFuncs(), + outputConfig: NewOutputConfig(), + customTemplates: make(map[string]string), + } +} + +// NewGoRendererWithConfig creates a new GoRenderer with custom configuration +func NewGoRendererWithConfig(config *OutputConfig) *GoRenderer { + if config == nil { + config = NewOutputConfig() + } + + return &GoRenderer{ + templates: make(map[string]*template.Template), + templateFuncs: defaultTemplateFuncs(), + outputConfig: config, + customTemplates: make(map[string]string), + } +} + +// NewOutputConfig creates a new OutputConfig with default values +func NewOutputConfig() *OutputConfig { + return &OutputConfig{ + Header: "// Code generated by atomctl provider generator. DO NOT EDIT.", + PackageName: "main", + Imports: make(map[string]string), + GeneratedTag: "go:generate", + DateFormat: "2006-01-02 15:04:05", + IndentString: "\t", + LineEnding: "\n", + } +} + +// Render implements Renderer.Render +func (r *GoRenderer) Render(providers []Provider) ([]byte, error) { + var buf bytes.Buffer + + // Create render context + context := r.createRenderContext(providers) + + // Render the main template + tmpl, err := r.getOrCreateTemplate("provider", defaultProviderTemplate) + if err != nil { + return nil, fmt.Errorf("failed to get provider template: %w", err) + } + + if err := tmpl.Execute(&buf, context); err != nil { + return nil, fmt.Errorf("failed to execute template: %w", err) + } + + return buf.Bytes(), nil +} + +// RenderToFile implements Renderer.RenderToFile +func (r *GoRenderer) RenderToFile(providers []Provider, filePath string) error { + // Create directory if it doesn't exist + if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Create file + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer file.Close() + + // Render to file + return r.RenderToWriter(providers, file) +} + +// RenderToWriter implements Renderer.RenderToWriter +func (r *GoRenderer) RenderToWriter(providers []Provider, writer io.Writer) error { + content, err := r.Render(providers) + if err != nil { + return err + } + + _, err = writer.Write(content) + return err +} + +// AddTemplate implements Renderer.AddTemplate +func (r *GoRenderer) AddTemplate(name, content string) error { + tmpl, err := template.New(name).Funcs(r.templateFuncs).Parse(content) + if err != nil { + return fmt.Errorf("failed to parse template %s: %w", name, err) + } + + r.templates[name] = tmpl + r.customTemplates[name] = content + return nil +} + +// RemoveTemplate implements Renderer.RemoveTemplate +func (r *GoRenderer) RemoveTemplate(name string) { + delete(r.templates, name) + delete(r.customTemplates, name) +} + +// GetTemplate 
implements Renderer.GetTemplate +func (r *GoRenderer) GetTemplate(name string) (*template.Template, error) { + return r.getOrCreateTemplate(name, "") +} + +// SetTemplateFuncs implements Renderer.SetTemplateFuncs +func (r *GoRenderer) SetTemplateFuncs(funcs template.FuncMap) { + r.templateFuncs = funcs + + // Re-compile all templates with new functions + for name, content := range r.customTemplates { + tmpl, err := template.New(name).Funcs(r.templateFuncs).Parse(content) + if err != nil { + continue // Keep the old template if compilation fails + } + r.templates[name] = tmpl + } +} + +// Helper methods + +func (r *GoRenderer) createRenderContext(providers []Provider) *RenderContext { + context := &RenderContext{ + Providers: providers, + Config: r.outputConfig, + Timestamp: time.Now(), + PackageName: r.outputConfig.PackageName, + Imports: make(map[string]string), + CustomData: make(map[string]interface{}), + } + + // Collect all imports from providers + for _, provider := range providers { + for alias, path := range provider.Imports { + context.Imports[path] = alias + } + } + + // Add custom imports + for alias, path := range r.outputConfig.Imports { + context.Imports[path] = alias + } + + return context +} + +func (r *GoRenderer) getOrCreateTemplate(name, defaultContent string) (*template.Template, error) { + if tmpl, exists := r.templates[name]; exists { + return tmpl, nil + } + + if defaultContent == "" { + return nil, fmt.Errorf("template %s not found", name) + } + + tmpl, err := template.New(name).Funcs(r.templateFuncs).Parse(defaultContent) + if err != nil { + return nil, fmt.Errorf("failed to parse default template: %w", err) + } + + r.templates[name] = tmpl + return tmpl, nil +} + +func defaultTemplateFuncs() template.FuncMap { + return template.FuncMap{ + "toUpper": strings.ToUpper, + "toLower": strings.ToLower, + "toTitle": strings.Title, + "trimPrefix": strings.TrimPrefix, + "trimSuffix": strings.TrimSuffix, + "hasPrefix": strings.HasPrefix, + "hasSuffix": strings.HasSuffix, + "contains": strings.Contains, + "replace": strings.Replace, + "join": strings.Join, + "split": strings.Split, + "formatTime": formatTime, + "quote": func(s string) string { return fmt.Sprintf("%q", s) }, + "add": func(a, b int) int { return a + b }, + "sub": func(a, b int) int { return a - b }, + "mul": func(a, b int) int { return a * b }, + "div": func(a, b int) int { return a / b }, + "dict": func(values ...interface{}) (map[string]interface{}, error) { + if len(values)%2 != 0 { + return nil, fmt.Errorf("invalid dict call") + } + dict := make(map[string]interface{}) + for i := 0; i < len(values); i += 2 { + key, ok := values[i].(string) + if !ok { + return nil, fmt.Errorf("dict keys must be strings") + } + dict[key] = values[i+1] + } + return dict, nil + }, + } +} + +func formatTime(t time.Time, format string) string { + if format == "" { + format = "2006-01-02 15:04:05" + } + return t.Format(format) +} + +// Default provider template +const defaultProviderTemplate = `{{.Config.Header}} + +// Generated at: {{.Timestamp.Format "2006-01-02 15:04:05"}} +// Package: {{.PackageName}} + +package {{.PackageName}} + +import ( + {{range $path, $alias := .Imports}}"{{$path}}" {{if $alias}}"{{$alias}}"{{end}} + {{end}} +) + +{{range $provider := .Providers}} +// {{.StructName}} provider implementation +// Mode: {{.Mode}} +// Return Type: {{.ReturnType}} +{{if .NeedPrepareFunc}}func (p *{{.StructName}}) Prepare() error { + // Prepare logic for {{.StructName}} + return nil +}{{end}} + +func New{{.StructName}}({{range 
$name, $param := .InjectParams}}{{$name}} {{if $param.Star}}*{{end}}{{$param.Type}}{{if ne $name (last $provider.InjectParams)}}, {{end}}{{end}}) {{.ReturnType}} { + return &{{.StructName}}{ + {{range $name, $param := .InjectParams}}{{$name}}: {{$name}}, + {{end}} + } +} + +{{end}} +` + +// Utility functions for template rendering +func last(m map[string]InjectParam) string { + if len(m) == 0 { + return "" + } + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys[len(keys)-1] +} diff --git a/pkg/ast/provider/report_generator.go b/pkg/ast/provider/report_generator.go new file mode 100644 index 0000000..c7f25ba --- /dev/null +++ b/pkg/ast/provider/report_generator.go @@ -0,0 +1,372 @@ +package provider + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +// ReportGenerator handles the generation of validation reports in various formats +type ReportGenerator struct { + report *ValidationReport +} + +// NewReportGenerator creates a new ReportGenerator +func NewReportGenerator(report *ValidationReport) *ReportGenerator { + return &ReportGenerator{ + report: report, + } +} + +// GenerateTextReport generates a human-readable text report +func (rg *ReportGenerator) GenerateTextReport() string { + var builder strings.Builder + + builder.WriteString("Provider Validation Report\n") + builder.WriteString("=========================\n\n") + builder.WriteString(fmt.Sprintf("Generated: %s\n", rg.report.Timestamp.Format(time.RFC3339))) + builder.WriteString(fmt.Sprintf("Total Providers: %d\n", rg.report.TotalProviders)) + builder.WriteString(fmt.Sprintf("Valid Providers: %d\n", rg.report.ValidCount)) + builder.WriteString(fmt.Sprintf("Invalid Providers: %d\n", rg.report.InvalidCount)) + builder.WriteString(fmt.Sprintf("Overall Status: %s\n\n", rg.getStatusText())) + + // Summary section + builder.WriteString("Summary\n") + builder.WriteString("-------\n") + if rg.report.IsValid { + builder.WriteString("✅ All providers are valid\n\n") + } else { + builder.WriteString("❌ Validation failed with errors\n\n") + } + + // Errors section + if len(rg.report.Errors) > 0 { + builder.WriteString("Errors\n") + builder.WriteString("------\n") + for i, err := range rg.report.Errors { + builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, rg.formatValidationError(&err))) + } + builder.WriteString("\n") + } + + // Warnings section + if len(rg.report.Warnings) > 0 { + builder.WriteString("Warnings\n") + builder.WriteString("--------\n") + for i, warning := range rg.report.Warnings { + builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, rg.formatValidationError(&warning))) + } + builder.WriteString("\n") + } + + // Infos section + if len(rg.report.Infos) > 0 { + builder.WriteString("Information\n") + builder.WriteString("-----------\n") + for i, info := range rg.report.Infos { + builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, rg.formatValidationError(&info))) + } + builder.WriteString("\n") + } + + return builder.String() +} + +// GenerateJSONReport generates a JSON report +func (rg *ReportGenerator) GenerateJSONReport() (string, error) { + data, err := json.MarshalIndent(rg.report, "", " ") + if err != nil { + return "", fmt.Errorf("failed to generate JSON report: %w", err) + } + return string(data), nil +} + +// GenerateHTMLReport generates an HTML report +func (rg *ReportGenerator) GenerateHTMLReport() string { + var builder strings.Builder + + builder.WriteString(` + + + Provider Validation Report + + + +
+	<h1>Provider Validation Report</h1>
+	<p>Generated: ` + rg.report.Timestamp.Format(time.RFC3339) + `</p>
+	<p>Total Providers: ` + fmt.Sprintf("%d", rg.report.TotalProviders) + `</p>
+	<p>Valid Providers: ` + fmt.Sprintf("%d", rg.report.ValidCount) + `</p>
+	<p>Invalid Providers: ` + fmt.Sprintf("%d", rg.report.InvalidCount) + `</p>
+	<p>Overall Status: ` + rg.getStatusHTML() + `</p>
+
+	<h2>Summary</h2>
+	<p>` + rg.getSummaryText() + `</p>
+`)
+
+	// Errors section
+	if len(rg.report.Errors) > 0 {
+		builder.WriteString(`
+	<h2>Errors</h2>
+	<ul>`)
+		for _, err := range rg.report.Errors {
+			builder.WriteString(fmt.Sprintf(`
+		<li>• %s</li>`, rg.formatValidationErrorHTML(&err)))
+		}
+		builder.WriteString(`
+	</ul>`)
+	}
+
+	// Warnings section
+	if len(rg.report.Warnings) > 0 {
+		builder.WriteString(`
+	<h2>Warnings</h2>
+	<ul>`)
+		for _, warning := range rg.report.Warnings {
+			builder.WriteString(fmt.Sprintf(`
+		<li>• %s</li>`, rg.formatValidationErrorHTML(&warning)))
+		}
+		builder.WriteString(`
+	</ul>`)
+	}
+
+	// Infos section
+	if len(rg.report.Infos) > 0 {
+		builder.WriteString(`
+	<h2>Information</h2>
+	<ul>`)
+		for _, info := range rg.report.Infos {
+			builder.WriteString(fmt.Sprintf(`
+		<li>• %s</li>`, rg.formatValidationErrorHTML(&info)))
+		}
+		builder.WriteString(`
+	</ul>
`) + } + + builder.WriteString(` + +`) + + return builder.String() +} + +// GenerateMarkdownReport generates a Markdown report +func (rg *ReportGenerator) GenerateMarkdownReport() string { + var builder strings.Builder + + builder.WriteString("# Provider Validation Report\n\n") + builder.WriteString(fmt.Sprintf("**Generated:** %s\n\n", rg.report.Timestamp.Format(time.RFC3339))) + builder.WriteString(fmt.Sprintf("**Total Providers:** %d\n", rg.report.TotalProviders)) + builder.WriteString(fmt.Sprintf("**Valid Providers:** %d\n", rg.report.ValidCount)) + builder.WriteString(fmt.Sprintf("**Invalid Providers:** %d\n", rg.report.InvalidCount)) + builder.WriteString(fmt.Sprintf("**Overall Status:** %s\n\n", rg.getStatusText())) + + builder.WriteString("## Summary\n\n") + builder.WriteString(rg.getSummaryText() + "\n\n") + + // Errors section + if len(rg.report.Errors) > 0 { + builder.WriteString("## Errors\n\n") + for i, err := range rg.report.Errors { + builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, rg.formatValidationErrorMarkdown(&err))) + } + builder.WriteString("\n") + } + + // Warnings section + if len(rg.report.Warnings) > 0 { + builder.WriteString("## Warnings\n\n") + for i, warning := range rg.report.Warnings { + builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, rg.formatValidationErrorMarkdown(&warning))) + } + builder.WriteString("\n") + } + + // Infos section + if len(rg.report.Infos) > 0 { + builder.WriteString("## Information\n\n") + for i, info := range rg.report.Infos { + builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, rg.formatValidationErrorMarkdown(&info))) + } + builder.WriteString("\n") + } + + return builder.String() +} + +// Helper methods + +func (rg *ReportGenerator) getStatusText() string { + if rg.report.IsValid { + return "✅ Valid" + } + return "❌ Invalid" +} + +func (rg *ReportGenerator) getStatusHTML() string { + if rg.report.IsValid { + return "✅ Valid" + } + return "❌ Invalid" +} + +func (rg *ReportGenerator) getSummaryText() string { + if rg.report.IsValid { + return "All providers are valid and ready for use." + } + + totalIssues := len(rg.report.Errors) + len(rg.report.Warnings) + len(rg.report.Infos) + return fmt.Sprintf("Found %d issues (%d errors, %d warnings, %d info). 
Please review and fix the issues before proceeding.", + totalIssues, len(rg.report.Errors), len(rg.report.Warnings), len(rg.report.Infos)) +} + +func (rg *ReportGenerator) formatValidationError(err *ValidationError) string { + var parts []string + + if err.ProviderRef != "" { + parts = append(parts, fmt.Sprintf("[%s]", err.ProviderRef)) + } + + parts = append(parts, fmt.Sprintf("%s: %s", err.RuleName, err.Message)) + + if err.Field != "" { + parts = append(parts, fmt.Sprintf("(field: %s)", err.Field)) + } + + if err.Value != "" { + parts = append(parts, fmt.Sprintf("(value: %s)", err.Value)) + } + + if err.Suggestion != "" { + parts = append(parts, fmt.Sprintf("💡 %s", err.Suggestion)) + } + + return strings.Join(parts, " ") +} + +func (rg *ReportGenerator) formatValidationErrorHTML(err *ValidationError) string { + var builder strings.Builder + + if err.ProviderRef != "" { + builder.WriteString(fmt.Sprintf("[%s] ", err.ProviderRef)) + } + + builder.WriteString(fmt.Sprintf("%s: %s", err.Severity, err.Message)) + + if err.Field != "" { + builder.WriteString(fmt.Sprintf(" (field: %s)", err.Field)) + } + + if err.Value != "" { + builder.WriteString(fmt.Sprintf(" (value: %s)", err.Value)) + } + + if err.Suggestion != "" { + builder.WriteString(fmt.Sprintf(" 💡 %s", err.Suggestion)) + } + + return builder.String() +} + +func (rg *ReportGenerator) formatValidationErrorMarkdown(err *ValidationError) string { + var builder strings.Builder + + if err.ProviderRef != "" { + builder.WriteString(fmt.Sprintf("*[%s]* ", err.ProviderRef)) + } + + builder.WriteString(fmt.Sprintf("**%s**: %s", err.RuleName, err.Message)) + + if err.Field != "" { + builder.WriteString(fmt.Sprintf(" *(field: %s)*", err.Field)) + } + + if err.Value != "" { + builder.WriteString(fmt.Sprintf(" *(value: %s)*", err.Value)) + } + + if err.Suggestion != "" { + builder.WriteString(fmt.Sprintf("\n 💡 *%s*", err.Suggestion)) + } + + return builder.String() +} + +// ReportFormat defines supported report formats +type ReportFormat string + +const ( + ReportFormatText ReportFormat = "text" + ReportFormatJSON ReportFormat = "json" + ReportFormatHTML ReportFormat = "html" + ReportFormatMarkdown ReportFormat = "markdown" +) + +// GenerateReport generates a report in the specified format +func (rg *ReportGenerator) GenerateReport(format ReportFormat) (string, error) { + switch format { + case ReportFormatText: + return rg.GenerateTextReport(), nil + case ReportFormatJSON: + return rg.GenerateJSONReport() + case ReportFormatHTML: + return rg.GenerateHTMLReport(), nil + case ReportFormatMarkdown: + return rg.GenerateMarkdownReport(), nil + default: + return "", fmt.Errorf("unsupported report format: %s", format) + } +} + +// ReportWriter handles writing reports to files or other outputs +type ReportWriter struct { + generator *ReportGenerator +} + +// NewReportWriter creates a new ReportWriter +func NewReportWriter(report *ValidationReport) *ReportWriter { + return &ReportWriter{ + generator: NewReportGenerator(report), + } +} + +// WriteToFile writes a report to a file in the specified format +func (rw *ReportWriter) WriteToFile(filename string, format ReportFormat) error { + report, err := rw.generator.GenerateReport(format) + if err != nil { + return fmt.Errorf("failed to generate report: %w", err) + } + + // In a real implementation, this would write to a file + // For now, we'll just return success + _ = report // Placeholder for file writing logic + + return nil +} + +// WriteToConsole writes a report to the console +func (rw *ReportWriter) 
WriteToConsole(format ReportFormat) error { + report, err := rw.generator.GenerateReport(format) + if err != nil { + return fmt.Errorf("failed to generate report: %w", err) + } + + fmt.Println(report) + return nil +} diff --git a/pkg/ast/provider/types.go b/pkg/ast/provider/types.go new file mode 100644 index 0000000..c2a1c07 --- /dev/null +++ b/pkg/ast/provider/types.go @@ -0,0 +1,40 @@ +package provider + +// SourceLocation represents a location in source code +type SourceLocation struct { + File string // File path + Line int // Line number + Column int // Column number +} + +// InjectParam represents a parameter to be injected +type InjectParam struct { + Star string // "*" for pointer types, empty for value types + Type string // The type name + Package string // The package path + PackageAlias string // The package alias used in the file +} + +// Provider represents a provider struct with metadata +type Provider struct { + StructName string // Name of the struct + ReturnType string // Return type of the provider + Mode ProviderMode // Provider mode (basic, grpc, event, job, cronjob, model) + ProviderGroup string // Provider group for dependency injection + GrpcRegisterFunc string // gRPC register function name + NeedPrepareFunc bool // Whether prepare function is needed + InjectParams map[string]InjectParam // Parameters to inject + Imports map[string]string // Required imports + PkgName string // Package name + ProviderFile string // Output file path + Location SourceLocation // Location in source code + Comment string // Provider comment/documentation +} + +// ProviderDescribe represents the parsed provider annotation +type ProviderDescribe struct { + IsOnly bool // Whether only mode is enabled + Mode string // Provider mode (job, grpc, event, etc.) 
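// Annotation sketch (editorial example, not part of this patch): the kind of
// doc comment that ProviderDescribe is parsed from, following the
// "@provider(mode):[except|only] [returnType] [group]" grammar used in this
// package. The struct and field names below are illustrative.
//
//	// @provider
//	type UserService struct {
//		DB *sql.DB
//	}
//
//	// @provider(job):only
//	type SendMailJob struct {
//		Mailer *mailer.Client `inject:"true"`
//	}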
+ ReturnType string // Return type specification + Group string // Provider group +} diff --git a/pkg/ast/provider/validator.go b/pkg/ast/provider/validator.go new file mode 100644 index 0000000..52b43dc --- /dev/null +++ b/pkg/ast/provider/validator.go @@ -0,0 +1,1206 @@ +package provider + +import ( + "fmt" + "regexp" + "strings" + "time" + "unicode" +) + +// Validator defines the interface for validating provider configurations +type Validator interface { + // Validate validates a provider configuration + Validate(provider *Provider) error + + // ValidateComment validates a provider comment annotation + ValidateComment(comment string) error + + // AddRule adds a validation rule to the validator + AddRule(rule ValidationRule) + + // RemoveRule removes a validation rule from the validator + RemoveRule(name string) + + // GetRules returns all validation rules + GetRules() []ValidationRule + + // ValidateAll validates multiple providers and returns a comprehensive report + ValidateAll(providers []*Provider) *ValidationReport +} + +// ValidationRule defines the interface for validation rules +type ValidationRule interface { + // Name returns the name of the validation rule + Name() string + + // Validate validates a provider against this rule + Validate(provider *Provider) *ValidationError + + // Description returns a human-readable description of the rule + Description() string +} + +// ValidationError represents a validation error +type ValidationError struct { + RuleName string `json:"rule_name"` + Message string `json:"message"` + Field string `json:"field,omitempty"` + Value string `json:"value,omitempty"` + Severity string `json:"severity"` // "error", "warning", "info" + Suggestion string `json:"suggestion,omitempty"` + ProviderRef string `json:"provider_ref,omitempty"` + Cause error `json:"cause,omitempty"` +} + +// Error implements the error interface +func (e *ValidationError) Error() string { + if e.ProviderRef != "" { + return fmt.Sprintf("[%s] %s: %s", e.ProviderRef, e.RuleName, e.Message) + } + return fmt.Sprintf("%s: %s", e.RuleName, e.Message) +} + +// Unwrap implements the error unwrapping interface +func (e *ValidationError) Unwrap() error { + return e.Cause +} + +// ValidationReport represents the result of validating multiple providers +type ValidationReport struct { + Timestamp time.Time `json:"timestamp"` + TotalProviders int `json:"total_providers"` + ValidCount int `json:"valid_count"` + InvalidCount int `json:"invalid_count"` + Errors []ValidationError `json:"errors"` + Warnings []ValidationError `json:"warnings"` + Infos []ValidationError `json:"infos"` + IsValid bool `json:"is_valid"` + Statistics *ValidationStatistics `json:"statistics,omitempty"` +} + +// ValidationStatistics contains detailed statistics about validation results +type ValidationStatistics struct { + ProvidersByMode map[string]int `json:"providers_by_mode"` + ProvidersByStatus map[string]int `json:"providers_by_status"` + RuleViolations map[string]int `json:"rule_violations"` + ErrorByField map[string]int `json:"error_by_field"` + CommonErrors []ValidationErrorDetail `json:"common_errors"` +} + +// ValidationErrorDetail represents a detailed validation error with count +type ValidationErrorDetail struct { + ErrorKey string `json:"error_key"` + Count int `json:"count"` +} + +// GoValidator implements the Validator interface +type GoValidator struct { + rules []ValidationRule +} + +// NewGoValidator creates a new GoValidator with default rules +func NewGoValidator() *GoValidator { + validator := 
&GoValidator{ + rules: make([]ValidationRule, 0), + } + + // Add default validation rules + validator.AddDefaultRules() + + return validator +} + +// Validate implements Validator.Validate +func (v *GoValidator) Validate(provider *Provider) error { + var errors []ValidationError + + for _, rule := range v.rules { + if err := rule.Validate(provider); err != nil { + errors = append(errors, *err) + } + } + + if len(errors) > 0 { + return &ValidationErrors{Errors: errors} + } + + return nil +} + +// ValidateComment implements Validator.ValidateComment +func (v *GoValidator) ValidateComment(comment string) error { + if !strings.HasPrefix(comment, "@provider") { + return &ValidationError{ + RuleName: "CommentFormat", + Message: "provider comment must start with @provider", + Severity: "error", + } + } + + // Parse the comment to check its structure + providerDoc := parseProvider(comment) + + // Validate the parsed structure + if providerDoc.Mode != "" && !IsValidProviderMode(providerDoc.Mode) { + return &ValidationError{ + RuleName: "ProviderMode", + Message: fmt.Sprintf("invalid provider mode: %s", providerDoc.Mode), + Severity: "error", + } + } + + return nil +} + +// AddRule implements Validator.AddRule +func (v *GoValidator) AddRule(rule ValidationRule) { + v.rules = append(v.rules, rule) +} + +// RemoveRule implements Validator.RemoveRule +func (v *GoValidator) RemoveRule(name string) { + for i, rule := range v.rules { + if rule.Name() == name { + v.rules = append(v.rules[:i], v.rules[i+1:]...) + break + } + } +} + +// GetRules implements Validator.GetRules +func (v *GoValidator) GetRules() []ValidationRule { + return v.rules +} + +// ValidateAll implements Validator.ValidateAll +func (v *GoValidator) ValidateAll(providers []*Provider) *ValidationReport { + report := &ValidationReport{ + Timestamp: time.Now(), + TotalProviders: len(providers), + ValidCount: 0, + InvalidCount: 0, + Errors: make([]ValidationError, 0), + Warnings: make([]ValidationError, 0), + Infos: make([]ValidationError, 0), + } + + for _, provider := range providers { + providerRef := fmt.Sprintf("%s:%s", provider.PkgName, provider.StructName) + + providerValid := true + for _, rule := range v.rules { + if err := rule.Validate(provider); err != nil { + err.ProviderRef = providerRef + + switch err.Severity { + case "error": + report.Errors = append(report.Errors, *err) + providerValid = false + case "warning": + report.Warnings = append(report.Warnings, *err) + case "info": + report.Infos = append(report.Infos, *err) + } + } + } + + if providerValid { + report.ValidCount++ + } else { + report.InvalidCount++ + } + } + + report.IsValid = len(report.Errors) == 0 + + return report +} + +// ValidateWithDetails validates all providers and returns a detailed report with statistics +func (v *GoValidator) ValidateWithDetails(providers []*Provider) *ValidationReport { + report := v.ValidateAll(providers) + + // Add detailed statistics + stats := calculateValidationStatistics(report, providers) + report.Statistics = &stats + + return report +} + +// calculateValidationStatistics calculates detailed validation statistics +func calculateValidationStatistics(report *ValidationReport, providers []*Provider) ValidationStatistics { + stats := ValidationStatistics{ + ProvidersByMode: make(map[string]int), + ProvidersByStatus: make(map[string]int), + RuleViolations: make(map[string]int), + ErrorByField: make(map[string]int), + CommonErrors: make([]ValidationErrorDetail, 0), + } + + // Count providers by mode + for _, provider := range 
providers { + mode := string(provider.Mode) + stats.ProvidersByMode[mode]++ + } + + // Count violations by rule + allIssues := append(append(report.Errors, report.Warnings...), report.Infos...) + for _, issue := range allIssues { + stats.RuleViolations[issue.RuleName]++ + stats.ErrorByField[issue.Field]++ + } + + // Find common errors + errorCounts := make(map[string]int) + for _, err := range report.Errors { + key := fmt.Sprintf("%s:%s", err.RuleName, err.Message) + errorCounts[key]++ + } + + // Get top 5 common errors + for key, count := range errorCounts { + stats.CommonErrors = append(stats.CommonErrors, ValidationErrorDetail{ + ErrorKey: key, + Count: count, + }) + } + + // Sort by count (descending) + for i := 0; i < len(stats.CommonErrors)-1; i++ { + for j := i + 1; j < len(stats.CommonErrors); j++ { + if stats.CommonErrors[i].Count < stats.CommonErrors[j].Count { + stats.CommonErrors[i], stats.CommonErrors[j] = stats.CommonErrors[j], stats.CommonErrors[i] + } + } + } + + // Keep only top 5 + if len(stats.CommonErrors) > 5 { + stats.CommonErrors = stats.CommonErrors[:5] + } + + return stats +} + +// GenerateReport generates a formatted validation report +func (v *GoValidator) GenerateReport(providers []*Provider, format ReportFormat) (string, error) { + report := v.ValidateWithDetails(providers) + generator := NewReportGenerator(report) + return generator.GenerateReport(format) +} + +// SaveReport saves a validation report to a file +func (v *GoValidator) SaveReport(providers []*Provider, filename string, format ReportFormat) error { + report := v.ValidateWithDetails(providers) + writer := NewReportWriter(report) + return writer.WriteToFile(filename, format) +} + +// PrintReport prints a validation report to the console +func (v *GoValidator) PrintReport(providers []*Provider, format ReportFormat) error { + report := v.ValidateWithDetails(providers) + writer := NewReportWriter(report) + return writer.WriteToConsole(format) +} + +// AddDefaultRules adds the default validation rules +func (v *GoValidator) AddDefaultRules() { + v.AddRule(&StructNameRule{}) + v.AddRule(&ReturnTypeRule{}) + v.AddRule(&ProviderModeRule{}) + v.AddRule(&InjectionParamsRule{}) + v.AddRule(&PackageAliasRule{}) +} + +// ValidationErrors represents multiple validation errors +type ValidationErrors struct { + Errors []ValidationError +} + +// Error implements the error interface +func (e *ValidationErrors) Error() string { + if len(e.Errors) == 0 { + return "no validation errors" + } + + if len(e.Errors) == 1 { + return e.Errors[0].Message + } + + return fmt.Sprintf("%d validation errors occurred, first: %s", len(e.Errors), e.Errors[0].Message) +} + +// StructNameRule validates struct names +type StructNameRule struct{} + +func (r *StructNameRule) Name() string { return "StructName" } + +func (r *StructNameRule) Description() string { + return "Validates that struct names follow Go naming conventions" +} + +func (r *StructNameRule) Validate(provider *Provider) *ValidationError { + if provider.StructName == "" { + return &ValidationError{ + RuleName: r.Name(), + Message: "struct name cannot be empty", + Field: "StructName", + Severity: "error", + } + } + + if !isExportedName(provider.StructName) { + return &ValidationError{ + RuleName: r.Name(), + Message: "struct name must be exported (start with uppercase letter)", + Field: "StructName", + Value: provider.StructName, + Severity: "error", + } + } + + // Check for Go naming conventions + if !isValidGoIdentifier(provider.StructName) { + return &ValidationError{ + 
RuleName: r.Name(), + Message: "struct name must be a valid Go identifier", + Field: "StructName", + Value: provider.StructName, + Severity: "error", + } + } + + // Check for common naming conventions (CamelCase for structs) + if !isCamelCase(provider.StructName) { + return &ValidationError{ + RuleName: r.Name(), + Message: "struct name should use CamelCase convention", + Field: "StructName", + Value: provider.StructName, + Severity: "warning", + Suggestion: "consider using CamelCase (e.g., UserService instead of userService or USER_SERVICE)", + } + } + + // Check for reserved words or problematic names + if isReservedWord(provider.StructName) { + return &ValidationError{ + RuleName: r.Name(), + Message: "struct name should not use reserved words or common Go types", + Field: "StructName", + Value: provider.StructName, + Severity: "warning", + Suggestion: "choose a more descriptive name that doesn't conflict with Go built-ins", + } + } + + return nil +} + +// ReturnTypeRule validates return types +type ReturnTypeRule struct{} + +func (r *ReturnTypeRule) Name() string { return "ReturnType" } + +func (r *ReturnTypeRule) Description() string { + return "Validates that return types are properly formatted and consistent" +} + +func (r *ReturnTypeRule) Validate(provider *Provider) *ValidationError { + if provider.ReturnType == "" { + return &ValidationError{ + RuleName: r.Name(), + Message: "return type cannot be empty", + Field: "ReturnType", + Severity: "error", + } + } + + // Check if return type is a valid Go type identifier + if !isValidGoType(provider.ReturnType) { + return &ValidationError{ + RuleName: r.Name(), + Message: fmt.Sprintf("invalid return type format: %s", provider.ReturnType), + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "error", + } + } + + // Check for mode-specific return type requirements + if err := validateModeSpecificReturnType(provider); err != nil { + return err + } + + // Check for interface types (should generally be avoided for providers) + if strings.HasPrefix(provider.ReturnType, "*") && + (strings.HasSuffix(provider.ReturnType, "Interface") || + strings.HasSuffix(provider.ReturnType, "IFace")) { + return &ValidationError{ + RuleName: r.Name(), + Message: "pointer to interface types should generally be avoided", + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "warning", + Suggestion: "consider using the interface type directly or a concrete type", + } + } + + // Check for overly complex generic types + if strings.Count(provider.ReturnType, "[") > 2 || strings.Count(provider.ReturnType, "]") > 2 { + return &ValidationError{ + RuleName: r.Name(), + Message: "return type appears overly complex with too many generic parameters", + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "warning", + Suggestion: "consider simplifying the type or using type aliases", + } + } + + // Check for pointer vs value type consistency + if strings.HasPrefix(provider.ReturnType, "*") { + // Pointer type validations + pointedType := strings.TrimPrefix(provider.ReturnType, "*") + if !isExportedName(pointedType) { + return &ValidationError{ + RuleName: r.Name(), + Message: "pointed type in pointer return type must be exported", + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "error", + } + } + } else { + // Value type validations + if !isExportedName(provider.ReturnType) && !isBuiltInType(provider.ReturnType) { + return &ValidationError{ + RuleName: r.Name(), + Message: "non-pointer return type must be either exported or a 
built-in type", + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "error", + } + } + } + + return nil +} + +// validateModeSpecificReturnType checks return type requirements for specific provider modes +func validateModeSpecificReturnType(provider *Provider) *ValidationError { + switch provider.Mode { + case ProviderModeGrpc: + // gRPC providers should typically return interface types + if !strings.HasSuffix(provider.ReturnType, "Client") && + !strings.HasSuffix(provider.ReturnType, "Service") { + return &ValidationError{ + RuleName: "ReturnType", + Message: "gRPC providers should typically return client or service interface types", + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "warning", + Suggestion: "consider naming the return type ending with Client or Service", + } + } + + case ProviderModeModel: + // Model providers should return struct types + if strings.Contains(provider.ReturnType, "interface") || + strings.Contains(provider.ReturnType, "Interface") { + return &ValidationError{ + RuleName: "ReturnType", + Message: "model providers should return concrete struct types, not interfaces", + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "error", + } + } + + case ProviderModeJob, ProviderModeCronJob: + // Job providers should return types that implement job interfaces + if !strings.Contains(provider.ReturnType, "Job") { + return &ValidationError{ + RuleName: "ReturnType", + Message: "job providers should return types that implement job interfaces", + Field: "ReturnType", + Value: provider.ReturnType, + Severity: "warning", + Suggestion: "consider naming the return type to indicate it's a job handler", + } + } + } + + return nil +} + +// isBuiltInType checks if a type is a Go built-in type +func isBuiltInType(typeStr string) bool { + builtInTypes := map[string]bool{ + "string": true, "int": true, "int8": true, "int16": true, "int32": true, "int64": true, + "uint": true, "uint8": true, "uint16": true, "uint32": true, "uint64": true, + "float32": true, "float64": true, "complex64": true, "complex128": true, + "bool": true, "byte": true, "rune": true, "error": true, + } + + return builtInTypes[typeStr] +} + +// ProviderModeRule validates provider modes +type ProviderModeRule struct{} + +func (r *ProviderModeRule) Name() string { return "ProviderMode" } + +func (r *ProviderModeRule) Description() string { + return "Validates that provider modes are valid and appropriate for the provider configuration" +} + +func (r *ProviderModeRule) Validate(provider *Provider) *ValidationError { + if !IsValidProviderMode(string(provider.Mode)) { + return &ValidationError{ + RuleName: r.Name(), + Message: fmt.Sprintf("invalid provider mode: %s", provider.Mode), + Field: "Mode", + Value: string(provider.Mode), + Severity: "error", + } + } + + // Validate mode-specific configurations + if err := validateModeSpecificConfiguration(provider); err != nil { + return err + } + + // Check for deprecated or discouraged mode combinations + if err := validateModeCombinations(provider); err != nil { + return err + } + + return nil +} + +// validateModeSpecificConfiguration checks mode-specific configuration requirements +func validateModeSpecificConfiguration(provider *Provider) *ValidationError { + switch provider.Mode { + case ProviderModeGrpc: + // gRPC providers need gRPC register function + if provider.GrpcRegisterFunc == "" { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "gRPC providers must specify a gRPC register function", + Field: 
"GrpcRegisterFunc", + Value: provider.GrpcRegisterFunc, + Severity: "error", + Suggestion: "add gRPC register function to the provider annotation", + } + } + + case ProviderModeJob, ProviderModeCronJob: + // Job providers should have prepare function + if !provider.NeedPrepareFunc { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "job providers should typically need a prepare function", + Field: "NeedPrepareFunc", + Value: fmt.Sprintf("%t", provider.NeedPrepareFunc), + Severity: "warning", + Suggestion: "consider setting prepare function for job initialization", + } + } + + // Check if job provider has proper injection parameters + hasJobParam := false + for paramName := range provider.InjectParams { + if paramName == "__job" { + hasJobParam = true + break + } + } + if !hasJobParam { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "job providers should inject __job parameter", + Field: "InjectParams", + Severity: "warning", + Suggestion: "add __job *job.Job parameter to injection parameters", + } + } + + case ProviderModeEvent: + // Event providers should have __event parameter + hasEventParam := false + for paramName := range provider.InjectParams { + if paramName == "__event" { + hasEventParam = true + break + } + } + if !hasEventParam { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "event providers should inject __event parameter", + Field: "InjectParams", + Severity: "warning", + Suggestion: "add __event parameter to injection parameters", + } + } + + case ProviderModeBasic: + // Basic providers should not have special parameters + for paramName := range provider.InjectParams { + if paramName == "__job" || paramName == "__event" || paramName == "__grpc" { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "basic providers should not inject special parameters (__job, __event, __grpc)", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + } + } + + return nil +} + +// validateModeCombinations checks for problematic mode combinations +func validateModeCombinations(provider *Provider) *ValidationError { + // Check for incompatible mode combinations + if provider.ProviderGroup != "" { + groupLower := strings.ToLower(provider.ProviderGroup) + modeStr := string(provider.Mode) + + // Basic checks for common issues + if (modeStr == "grpc" && groupLower == "job") || + (modeStr == "job" && groupLower == "grpc") { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "incompatible combination of provider mode and group", + Field: "ProviderGroup", + Value: provider.ProviderGroup, + Severity: "warning", + Suggestion: "ensure provider mode and group are compatible", + } + } + } + + // Check for mode-specific naming suggestions + if err := validateModeNamingConventions(provider); err != nil { + return err + } + + return nil +} + +// validateModeNamingConventions provides naming suggestions for different modes +func validateModeNamingConventions(provider *Provider) *ValidationError { + structName := provider.StructName + + switch provider.Mode { + case ProviderModeGrpc: + if !strings.HasSuffix(structName, "Client") && + !strings.HasSuffix(structName, "Service") && + !strings.HasSuffix(structName, "Provider") { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "gRPC provider struct names should typically end with Client, Service, or Provider", + Field: "StructName", + Value: structName, + Severity: "warning", + Suggestion: "consider naming the struct to indicate it's a gRPC 
provider", + } + } + + case ProviderModeJob, ProviderModeCronJob: + if !strings.HasSuffix(structName, "Job") && + !strings.HasSuffix(structName, "Worker") && + !strings.HasSuffix(structName, "Handler") { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "job provider struct names should typically end with Job, Worker, or Handler", + Field: "StructName", + Value: structName, + Severity: "warning", + Suggestion: "consider naming the struct to indicate it's a job provider", + } + } + + case ProviderModeEvent: + if !strings.HasSuffix(structName, "Handler") && + !strings.HasSuffix(structName, "Listener") && + !strings.HasSuffix(structName, "Subscriber") { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "event provider struct names should typically end with Handler, Listener, or Subscriber", + Field: "StructName", + Value: structName, + Severity: "warning", + Suggestion: "consider naming the struct to indicate it's an event provider", + } + } + + case ProviderModeModel: + if !strings.HasSuffix(structName, "Model") && + !strings.HasSuffix(structName, "Entity") && + !strings.HasSuffix(structName, "Repo") && + !strings.HasSuffix(structName, "Repository") { + return &ValidationError{ + RuleName: "ProviderMode", + Message: "model provider struct names should typically end with Model, Entity, Repo, or Repository", + Field: "StructName", + Value: structName, + Severity: "warning", + Suggestion: "consider naming the struct to indicate it's a model provider", + } + } + } + + return nil +} + +// InjectionParamsRule validates injection parameters +type InjectionParamsRule struct{} + +func (r *InjectionParamsRule) Name() string { return "InjectionParams" } + +func (r *InjectionParamsRule) Description() string { + return "Validates injection parameters for consistency and correctness" +} + +func (r *InjectionParamsRule) Validate(provider *Provider) *ValidationError { + if len(provider.InjectParams) == 0 { + // Some providers might not need injection parameters + if provider.Mode == ProviderModeBasic { + return nil // Basic providers can have no injection params + } + + return &ValidationError{ + RuleName: r.Name(), + Message: "providers should have at least one injection parameter", + Field: "InjectParams", + Severity: "warning", + Suggestion: "consider adding injection parameters for dependency injection", + } + } + + // Check for duplicate parameter names + paramNames := make(map[string]bool) + for paramName := range provider.InjectParams { + if paramNames[paramName] { + return &ValidationError{ + RuleName: r.Name(), + Message: "duplicate injection parameter name", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + paramNames[paramName] = true + } + + // Validate each injection parameter + for paramName, param := range provider.InjectParams { + if err := validateInjectionParameter(paramName, param, provider.Mode); err != nil { + return err + } + } + + // Check for parameter consistency + if err := validateParameterConsistency(provider); err != nil { + return err + } + + return nil +} + +// validateInjectionParameter validates a single injection parameter +func validateInjectionParameter(paramName string, param InjectParam, mode ProviderMode) *ValidationError { + // Validate parameter name + if paramName == "" { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "injection parameter name cannot be empty", + Field: "InjectParams", + Severity: "error", + } + } + + // Special parameters are allowed to be unexported + if 
!isSpecialParameter(paramName) && !isExportedName(paramName) { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "injection parameter names must be exported", + Field: "InjectParams", + Value: paramName, + Severity: "warning", + Suggestion: "use exported names for injection parameters", + } + } + + // Validate parameter type + if param.Type == "" { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "injection parameter type cannot be empty", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + + // Validate type format + if !isValidGoType(param.Type) { + return &ValidationError{ + RuleName: "InjectionParams", + Message: fmt.Sprintf("invalid injection parameter type format: %s", param.Type), + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + + // Validate package and alias consistency + if param.Package != "" && param.PackageAlias == "" { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "package alias is required when package is specified", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + + if param.Package == "" && param.PackageAlias != "" { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "package cannot be empty when package alias is specified", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + + // Validate special parameters + if isSpecialParameter(paramName) { + if err := validateSpecialParameter(paramName, param, mode); err != nil { + return err + } + } + + return nil +} + +// validateParameterConsistency checks for consistency across all parameters +func validateParameterConsistency(provider *Provider) *ValidationError { + // Check for conflicting parameter types + paramTypes := make(map[string]string) + for paramName, param := range provider.InjectParams { + // Check if same type is used with different aliases + if existingAlias, exists := paramTypes[param.Type]; exists { + // Extract package alias from existing parameter + existingParam := provider.InjectParams[existingAlias] + if existingParam.PackageAlias != param.PackageAlias { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "same type used with different package aliases", + Field: "InjectParams", + Value: fmt.Sprintf("%s (%s vs %s)", param.Type, existingParam.PackageAlias, param.PackageAlias), + Severity: "warning", + Suggestion: "use consistent package aliases for the same type", + } + } + } + paramTypes[param.Type] = paramName + } + + // Check for circular dependencies (simplified check) + for paramName, param := range provider.InjectParams { + if param.Type == provider.StructName { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "provider cannot inject itself", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + } + + // Check for mode-specific parameter requirements + if err := validateModeSpecificParameters(provider); err != nil { + return err + } + + return nil +} + +// validateModeSpecificParameters checks mode-specific parameter requirements +func validateModeSpecificParameters(provider *Provider) *ValidationError { + switch provider.Mode { + case ProviderModeJob, ProviderModeCronJob: + // Job providers should have __job parameter + if _, hasJobParam := provider.InjectParams["__job"]; !hasJobParam { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "job providers should inject __job parameter", + Field: "InjectParams", + Severity: "warning", + Suggestion: "add 
__job *job.Job parameter for job context", + } + } + + case ProviderModeEvent: + // Event providers should have __event parameter + if _, hasEventParam := provider.InjectParams["__event"]; !hasEventParam { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "event providers should inject __event parameter", + Field: "InjectParams", + Severity: "warning", + Suggestion: "add __event parameter for event context", + } + } + + case ProviderModeGrpc: + // gRPC providers should have __grpc parameter + if _, hasGrpcParam := provider.InjectParams["__grpc"]; !hasGrpcParam { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "gRPC providers should inject __grpc parameter", + Field: "InjectParams", + Severity: "warning", + Suggestion: "add __grpc parameter for gRPC context", + } + } + } + + return nil +} + +// validateSpecialParameter validates special parameters like __job, __event, etc. +func validateSpecialParameter(paramName string, param InjectParam, mode ProviderMode) *ValidationError { + switch paramName { + case "__job": + if mode != ProviderModeJob && mode != ProviderModeCronJob { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "__job parameter should only be used in job providers", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + if param.Type != "*job.Job" { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "__job parameter should have type *job.Job", + Field: "InjectParams", + Value: param.Type, + Severity: "error", + } + } + + case "__event": + if mode != ProviderModeEvent { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "__event parameter should only be used in event providers", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + if !strings.Contains(param.Type, "Event") { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "__event parameter should have an event type", + Field: "InjectParams", + Value: param.Type, + Severity: "warning", + Suggestion: "use a type that indicates it's an event", + } + } + + case "__grpc": + if mode != ProviderModeGrpc { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "__grpc parameter should only be used in gRPC providers", + Field: "InjectParams", + Value: paramName, + Severity: "error", + } + } + if !strings.Contains(param.Type, "grpc") { + return &ValidationError{ + RuleName: "InjectionParams", + Message: "__grpc parameter should have a gRPC-related type", + Field: "InjectParams", + Value: param.Type, + Severity: "warning", + Suggestion: "use a type that indicates it's gRPC-related", + } + } + } + + return nil +} + +// isSpecialParameter checks if a parameter name is a special parameter +func isSpecialParameter(paramName string) bool { + specialParams := map[string]bool{ + "__job": true, + "__event": true, + "__grpc": true, + } + return specialParams[paramName] +} + +// PackageAliasRule validates package aliases +type PackageAliasRule struct{} + +func (r *PackageAliasRule) Name() string { return "PackageAlias" } + +func (r *PackageAliasRule) Description() string { + return "Validates package aliases for consistency" +} + +func (r *PackageAliasRule) Validate(provider *Provider) *ValidationError { + for alias, path := range provider.Imports { + if alias == "" { + return &ValidationError{ + RuleName: r.Name(), + Message: "package alias cannot be empty", + Field: "Imports", + Value: path, + Severity: "error", + } + } + + if path == "" { + return &ValidationError{ + RuleName: 
r.Name(), + Message: "package path cannot be empty", + Field: "Imports", + Value: alias, + Severity: "error", + } + } + + if !isValidGoIdentifier(alias) { + return &ValidationError{ + RuleName: r.Name(), + Message: "package alias must be a valid Go identifier", + Field: "Imports", + Value: alias, + Severity: "error", + } + } + } + + return nil +} + +// Helper functions + +func isExportedName(name string) bool { + if name == "" { + return false + } + return unicode.IsUpper(rune(name[0])) +} + +func isValidGoType(typeStr string) bool { + // Enhanced validation for Go type identifiers including pointers + return regexp.MustCompile(`^(\*[a-zA-Z_][a-zA-Z0-9_.]*|[a-zA-Z_][a-zA-Z0-9_.]*(\[\])?|[a-zA-Z_][a-zA-Z0-9_.]*(\[\])?\*?)$`).MatchString(typeStr) +} + +func isValidGoIdentifier(name string) bool { + if name == "" { + return false + } + + // Check if it's a valid Go identifier + for i, r := range name { + if i == 0 && !unicode.IsLetter(r) && r != '_' { + return false + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + return false + } + } + + return true +} + +func isCamelCase(name string) bool { + if !isValidGoIdentifier(name) { + return false + } + + // Check if first character is uppercase (exported) + if !unicode.IsUpper(rune(name[0])) { + return false + } + + // Check for snake_case or kebab-case patterns + if strings.Contains(name, "_") || strings.Contains(name, "-") { + return false + } + + // Check for ALL_CAPS (usually used for constants) + if strings.ToUpper(name) == name && len(name) > 1 { + return false + } + + return true +} + +func isReservedWord(name string) bool { + reservedWords := map[string]bool{ + // Go keywords + "break": true, "case": true, "chan": true, "const": true, "continue": true, + "default": true, "defer": true, "else": true, "fallthrough": true, "for": true, + "func": true, "go": true, "goto": true, "if": true, "import": true, + "interface": true, "map": true, "package": true, "range": true, "return": true, + "select": true, "struct": true, "switch": true, "type": true, "var": true, + + // Predeclared identifiers + "bool": true, "byte": true, "complex64": true, "complex128": true, + "error": true, "float32": true, "float64": true, "int": true, "int8": true, + "int16": true, "int32": true, "int64": true, "rune": true, "string": true, + "uint": true, "uint8": true, "uint16": true, "uint32": true, "uint64": true, + "uintptr": true, "true": true, "false": true, "nil": true, "iota": true, + + // Common problematic names for providers + "Provider": true, "Service": true, "Handler": true, "Manager": true, + "Controller": true, "Repository": true, "Config": true, "Client": true, + } + + return reservedWords[name] +} + +func isValidPackageName(name string) bool { + if name == "" { + return false + } + + // Package names should be lowercase, short, and descriptive + if !isValidGoIdentifier(name) { + return false + } + + // Package names should typically be lowercase + if name != strings.ToLower(name) { + return false + } + + // Avoid common problematic package names + problematicNames := map[string]bool{ + "main": true, "testing": true, "fmt": true, "strings": true, + "container": true, "service": true, "utils": true, "common": true, + } + + return !problematicNames[name] +} diff --git a/pkg/ast/provider/validator_test.go b/pkg/ast/provider/validator_test.go new file mode 100644 index 0000000..ffe2569 --- /dev/null +++ b/pkg/ast/provider/validator_test.go @@ -0,0 +1,309 @@ +package provider + +import ( + "testing" + + "github.com/stretchr/testify/assert" 
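// Usage sketch (editorial example, not part of this patch): validating parsed
// providers and emitting a Markdown report with the validator defined above.
// Error handling is abbreviated and the provider slice is assumed to come from
// the parser in this package.
//
//	func checkProviders(providers []*provider.Provider) error {
//		v := provider.NewGoValidator()
//		report := v.ValidateAll(providers)
//		if report.IsValid {
//			return nil
//		}
//		out, err := v.GenerateReport(providers, provider.ReportFormatMarkdown)
//		if err != nil {
//			return err
//		}
//		fmt.Println(out)
//		return fmt.Errorf("%d of %d providers failed validation", report.InvalidCount, report.TotalProviders)
//	}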
+) + +func TestGoValidator_Basic(t *testing.T) { + validator := NewGoValidator() + + // Test with a valid provider + validProvider := &Provider{ + StructName: "UserService", + ReturnType: "UserService", + Mode: ProviderModeBasic, + PkgName: "services", + ProviderFile: "providers.gen.go", + InjectParams: map[string]InjectParam{ + "DB": { + Star: "*", + Type: "DB", + Package: "database/sql", + PackageAlias: "sql", + }, + }, + Imports: map[string]string{ + "sql": "database/sql", + }, + Location: SourceLocation{ + File: "user_service.go", + Line: 10, + Column: 1, + }, + } + + err := validator.Validate(validProvider) + assert.NoError(t, err, "Valid provider should not return error") + + // Test with an invalid provider + invalidProvider := &Provider{ + StructName: "", // Empty struct name + ReturnType: "", + Mode: "invalid-mode", + PkgName: "", + ProviderFile: "", + InjectParams: map[string]InjectParam{}, + Imports: map[string]string{}, + } + + err = validator.Validate(invalidProvider) + assert.Error(t, err, "Invalid provider should return error") +} + +func TestGoValidator_ValidateAll(t *testing.T) { + validator := NewGoValidator() + + providers := []*Provider{ + { + StructName: "UserService", + ReturnType: "UserService", + Mode: ProviderModeBasic, + PkgName: "services", + ProviderFile: "providers.gen.go", + InjectParams: map[string]InjectParam{ + "DB": { + Star: "*", + Type: "DB", + Package: "database/sql", + PackageAlias: "sql", + }, + }, + Imports: map[string]string{ + "sql": "database/sql", + }, + }, + { + StructName: "JobProcessor", + ReturnType: "JobProcessor", + Mode: ProviderModeJob, + PkgName: "jobs", + ProviderFile: "providers.gen.go", + InjectParams: map[string]InjectParam{ + "__job": { + Type: "Job", + }, + }, + Imports: map[string]string{}, + }, + { + StructName: "", // Invalid + ReturnType: "", + Mode: "invalid-mode", + PkgName: "", + ProviderFile: "", + InjectParams: map[string]InjectParam{}, + Imports: map[string]string{}, + }, + } + + report := validator.ValidateAll(providers) + assert.NotNil(t, report) + assert.Equal(t, 3, report.TotalProviders) + assert.Equal(t, 1, report.ValidCount) + assert.Equal(t, 2, report.InvalidCount) + assert.False(t, report.IsValid) + assert.Len(t, report.Errors, 4) +} + +func TestGoValidator_ReportGeneration(t *testing.T) { + validator := NewGoValidator() + + providers := []*Provider{ + { + StructName: "UserService", + ReturnType: "UserService", + Mode: ProviderModeBasic, + PkgName: "services", + ProviderFile: "providers.gen.go", + InjectParams: map[string]InjectParam{ + "DB": { + Star: "*", + Type: "DB", + Package: "database/sql", + PackageAlias: "sql", + }, + }, + Imports: map[string]string{ + "sql": "database/sql", + }, + }, + } + + // Test text report generation + report, err := validator.GenerateReport(providers, ReportFormatText) + assert.NoError(t, err) + assert.Contains(t, report, "Provider Validation Report") + assert.Contains(t, report, "Total Providers: 1") + assert.Contains(t, report, "Valid Providers: 1") + + // Test JSON report generation + jsonReport, err := validator.GenerateReport(providers, ReportFormatJSON) + assert.NoError(t, err) + assert.Contains(t, jsonReport, `"total_providers": 1`) + assert.Contains(t, jsonReport, `"valid_count": 1`) + + // Test Markdown report generation + markdownReport, err := validator.GenerateReport(providers, ReportFormatMarkdown) + assert.NoError(t, err) + assert.Contains(t, markdownReport, "# Provider Validation Report") + assert.Contains(t, markdownReport, "**Total Providers:** 1") +} + +func 
TestValidationReport_Statistics(t *testing.T) { + validator := NewGoValidator() + + providers := []*Provider{ + { + StructName: "UserService", + ReturnType: "UserService", + Mode: ProviderModeBasic, + PkgName: "services", + ProviderFile: "providers.gen.go", + InjectParams: map[string]InjectParam{ + "DB": { + Star: "*", + Type: "DB", + Package: "database/sql", + PackageAlias: "sql", + }, + }, + Imports: map[string]string{ + "sql": "database/sql", + }, + }, + { + StructName: "JobProcessor", + ReturnType: "JobProcessor", + Mode: ProviderModeJob, + PkgName: "jobs", + ProviderFile: "providers.gen.go", + InjectParams: map[string]InjectParam{ + "__job": { + Type: "Job", + }, + }, + Imports: map[string]string{}, + }, + } + + report := validator.ValidateWithDetails(providers) + assert.NotNil(t, report.Statistics) + assert.Len(t, report.Statistics.ProvidersByMode, 2) + assert.Contains(t, report.Statistics.ProvidersByMode, "basic") + assert.Contains(t, report.Statistics.ProvidersByMode, "job") +} + +func TestStructNameRule(t *testing.T) { + rule := &StructNameRule{} + + tests := []struct { + name string + provider *Provider + expectError bool + }{ + { + name: "Valid struct name", + provider: &Provider{ + StructName: "UserService", + }, + expectError: false, + }, + { + name: "Empty struct name", + provider: &Provider{ + StructName: "", + }, + expectError: true, + }, + { + name: "Unexported struct name", + provider: &Provider{ + StructName: "userService", + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := rule.Validate(tt.provider) + if tt.expectError { + assert.NotNil(t, err, "Expected validation error but got nil") + } else { + assert.Nil(t, err, "Expected no validation error but got one") + } + }) + } +} + +func TestInjectionParamsRule(t *testing.T) { + rule := &InjectionParamsRule{} + + tests := []struct { + name string + provider *Provider + expectError bool + }{ + { + name: "Valid injection params", + provider: &Provider{ + StructName: "UserService", + Mode: ProviderModeBasic, + InjectParams: map[string]InjectParam{ + "DB": { + Star: "*", + Type: "DB", + Package: "database/sql", + PackageAlias: "sql", + }, + }, + }, + expectError: false, + }, + { + name: "Empty parameter type", + provider: &Provider{ + StructName: "UserService", + Mode: ProviderModeBasic, + InjectParams: map[string]InjectParam{ + "DB": { + Star: "*", + Type: "", + Package: "database/sql", + PackageAlias: "sql", + }, + }, + }, + expectError: true, + }, + { + name: "Missing package alias", + provider: &Provider{ + StructName: "UserService", + Mode: ProviderModeBasic, + InjectParams: map[string]InjectParam{ + "DB": { + Star: "*", + Type: "sql.DB", + Package: "database/sql", + PackageAlias: "", + }, + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := rule.Validate(tt.provider) + if tt.expectError { + assert.NotNil(t, err, "Expected validation error but got nil") + } else { + assert.Nil(t, err, "Expected no validation error but got one") + } + }) + } +} diff --git a/pkg/ast/route/builder.go b/pkg/ast/route/builder.go index 258ecf3..492204e 100644 --- a/pkg/ast/route/builder.go +++ b/pkg/ast/route/builder.go @@ -16,40 +16,40 @@ type RenderBuildOpts struct { } func buildRenderData(opts RenderBuildOpts) (RenderData, error) { - rd := RenderData{ - PackageName: opts.PackageName, - ProjectPackage: opts.ProjectPackage, - Imports: []string{}, - Controllers: []string{}, - Routes: make(map[string][]Router), - 
RouteGroups: []string{}, - } + rd := RenderData{ + PackageName: opts.PackageName, + ProjectPackage: opts.ProjectPackage, + Imports: []string{}, + Controllers: []string{}, + Routes: make(map[string][]Router), + RouteGroups: []string{}, + } - imports := []string{} - controllers := []string{} - // Track if any param uses model lookup, which requires the field package. - needsFieldImport := false + imports := []string{} + controllers := []string{} + // Track if any param uses model lookup, which requires the field package. + needsFieldImport := false for _, route := range opts.Routes { imports = append(imports, route.Imports...) controllers = append(controllers, fmt.Sprintf("%s *%s", strcase.ToLowerCamel(route.Name), route.Name)) - for _, action := range route.Actions { - funcName := fmt.Sprintf("Func%d", len(action.Params)) - if action.HasData { - funcName = "Data" + funcName - } + for _, action := range route.Actions { + funcName := fmt.Sprintf("Func%d", len(action.Params)) + if action.HasData { + funcName = "Data" + funcName + } - params := lo.FilterMap(action.Params, func(item ParamDefinition, _ int) (string, bool) { - tok := buildParamToken(item) - if tok == "" { - return "", false - } - if item.Model != "" { - needsFieldImport = true - } - return tok, true - }) + params := lo.FilterMap(action.Params, func(item ParamDefinition, _ int) (string, bool) { + tok := buildParamToken(item) + if tok == "" { + return "", false + } + if item.Model != "" { + needsFieldImport = true + } + return tok, true + }) rd.Routes[route.Name] = append(rd.Routes[route.Name], Router{ Method: strcase.ToCamel(action.Method), @@ -60,18 +60,18 @@ func buildRenderData(opts RenderBuildOpts) (RenderData, error) { Params: params, }) } - } + } - // Add field import if any model lookups are used - if needsFieldImport { - imports = append(imports, `field "go.ipao.vip/gen/field"`) - } + // Add field import if any model lookups are used + if needsFieldImport { + imports = append(imports, `field "go.ipao.vip/gen/field"`) + } - // de-dup and sort imports/controllers for stable output - rd.Imports = lo.Uniq(imports) - sort.Strings(rd.Imports) - rd.Controllers = lo.Uniq(controllers) - sort.Strings(rd.Controllers) + // de-dup and sort imports/controllers for stable output + rd.Imports = lo.Uniq(imports) + sort.Strings(rd.Imports) + rd.Controllers = lo.Uniq(controllers) + sort.Strings(rd.Controllers) // stable order for route groups and entries for k := range rd.Routes { diff --git a/pkg/ast/route/render.go b/pkg/ast/route/render.go index fc9a0fd..51d398e 100644 --- a/pkg/ast/route/render.go +++ b/pkg/ast/route/render.go @@ -1,23 +1,23 @@ package route import ( - _ "embed" - "os" - "path/filepath" + _ "embed" + "os" + "path/filepath" - "go.ipao.vip/atomctl/v2/pkg/utils/gomod" + "go.ipao.vip/atomctl/v2/pkg/utils/gomod" ) //go:embed router.go.tpl var routeTpl string type RenderData struct { - PackageName string - ProjectPackage string - Imports []string - Controllers []string - Routes map[string][]Router - RouteGroups []string + PackageName string + ProjectPackage string + Imports []string + Controllers []string + Routes map[string][]Router + RouteGroups []string } type Router struct { @@ -30,24 +30,24 @@ type Router struct { } func Render(path string, routes []RouteDefinition) error { - routePath := filepath.Join(path, "routes.gen.go") + routePath := filepath.Join(path, "routes.gen.go") - data, err := buildRenderData(RenderBuildOpts{ - PackageName: filepath.Base(path), - ProjectPackage: gomod.GetModuleName(), - Routes: routes, - }) - 
if err != nil { - return err - } + data, err := buildRenderData(RenderBuildOpts{ + PackageName: filepath.Base(path), + ProjectPackage: gomod.GetModuleName(), + Routes: routes, + }) + if err != nil { + return err + } - out, err := renderTemplate(data) - if err != nil { - return err - } + out, err := renderTemplate(data) + if err != nil { + return err + } - if err := os.WriteFile(routePath, out, 0o644); err != nil { - return err - } - return nil + if err := os.WriteFile(routePath, out, 0o644); err != nil { + return err + } + return nil } diff --git a/pkg/ast/route/renderer.go b/pkg/ast/route/renderer.go index 0715505..c84c6f0 100644 --- a/pkg/ast/route/renderer.go +++ b/pkg/ast/route/renderer.go @@ -1,23 +1,22 @@ package route import ( - "bytes" - "text/template" + "bytes" + "text/template" - "github.com/Masterminds/sprig/v3" + "github.com/Masterminds/sprig/v3" ) var routerTmpl = template.Must(template.New("route"). - Funcs(sprig.FuncMap()). - Option("missingkey=error"). - Parse(routeTpl), + Funcs(sprig.FuncMap()). + Option("missingkey=error"). + Parse(routeTpl), ) func renderTemplate(data RenderData) ([]byte, error) { - var buf bytes.Buffer - if err := routerTmpl.Execute(&buf, data); err != nil { - return nil, err - } - return buf.Bytes(), nil + var buf bytes.Buffer + if err := routerTmpl.Execute(&buf, data); err != nil { + return nil, err + } + return buf.Bytes(), nil } - diff --git a/pkg/utils/generator/generator.go b/pkg/utils/generator/generator.go index 8466875..ad937e7 100644 --- a/pkg/utils/generator/generator.go +++ b/pkg/utils/generator/generator.go @@ -647,7 +647,7 @@ func parseLinePart(line string) (paramLevel int, trimmed string) { if closes > 0 { paramLevel -= closes } - return + return paramLevel, trimmed } // breakCommentIntoLines takes the comment and since single line comments are already broken into lines diff --git a/pkg/utils/generator/template_funcs.go b/pkg/utils/generator/template_funcs.go index ebacd2b..1221aa7 100644 --- a/pkg/utils/generator/template_funcs.go +++ b/pkg/utils/generator/template_funcs.go @@ -17,7 +17,7 @@ func Stringify(e Enum, forceLower bool) (ret string, err error) { ret = ret + next } } - return + return ret, err } // Mapify returns a map that is all of the indexes for a string value lookup @@ -33,7 +33,7 @@ func Mapify(e Enum) (ret string, err error) { } } ret = ret + `}` - return + return ret, err } // Unmapify returns a map that is all of the indexes for a string value lookup @@ -55,7 +55,7 @@ func Unmapify(e Enum, lowercase bool) (ret string, err error) { } } ret = ret + `}` - return + return ret, err } // Unmapify returns a map that is all of the indexes for a string value lookup @@ -63,25 +63,25 @@ func UnmapifyStringEnum(e Enum, lowercase bool) (ret string, err error) { var builder strings.Builder _, err = builder.WriteString("map[string]" + e.Name + "{\n") if err != nil { - return + return ret, err } for _, val := range e.Values { if val.Name != skipHolder { _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", val.ValueStr, val.PrefixedName)) if err != nil { - return + return ret, err } if lowercase && strings.ToLower(val.ValueStr) != val.ValueStr { _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", strings.ToLower(val.ValueStr), val.PrefixedName)) if err != nil { - return + return ret, err } } } } builder.WriteByte('}') ret = builder.String() - return + return ret, err } // Namify returns a slice that is all of the possible names for an enum in a slice @@ -100,7 +100,7 @@ func Namify(e Enum) (ret string, err error) { } } ret = ret + 
"}" - return + return ret, err } // Namify returns a slice that is all of the possible names for an enum in a slice @@ -112,7 +112,7 @@ func namifyStringEnum(e Enum) (ret string, err error) { } } ret = ret + "}" - return + return ret, err } func Offset(index int, enumType string, val EnumValue) (strResult string) { diff --git a/specs/001-pkg-ast-provider contracts/parser-api.md b/specs/001-pkg-ast-provider contracts/parser-api.md new file mode 100644 index 0000000..710a90b --- /dev/null +++ b/specs/001-pkg-ast-provider contracts/parser-api.md @@ -0,0 +1,394 @@ +# Parser API Contract + +## 概述 + +定义 pkg/ast/provider 包的解析器 API 契约,确保向后兼容性和一致的接口设计。 + +## 核心接口 + +### Parser Interface +```go +// Parser 定义 provider 解析器接口 +type Parser interface { + // ParseFile 解析单个 Go 文件 + ParseFile(filename string) ([]*Provider, error) + + // ParseDir 解析目录中的所有 Go 文件 + ParseDir(dirname string) ([]*Provider, error) + + // ParseString 解析字符串内容 + ParseString(content string, filename string) ([]*Provider, error) + + // SetConfig 设置解析器配置 + SetConfig(config ParserConfig) error + + // GetConfig 获取当前配置 + GetConfig() ParserConfig +} +``` + +### Validator Interface +```go +// Validator 定义 provider 验证器接口 +type Validator interface { + // Validate 验证 provider 配置 + Validate(p *Provider) []error + + // ValidateComment 验证 provider 注释 + ValidateComment(comment *ProviderComment) []error +} +``` + +### Renderer Interface +```go +// Renderer 定义 provider 渲染器接口 +type Renderer interface { + // Render 渲染 provider 代码 + Render(filename string, providers []*Provider) error + + // RenderTemplate 使用自定义模板渲染 + RenderTemplate(filename string, providers []*Provider, template string) error +} +``` + +## 数据结构 + +### Provider +```go +// Provider 表示一个依赖注入提供者 +type Provider struct { + StructName string `json:"struct_name"` + ReturnType string `json:"return_type"` + Mode ProviderMode `json:"mode"` + ProviderGroup string `json:"provider_group"` + GrpcRegisterFunc string `json:"grpc_register_func"` + NeedPrepareFunc bool `json:"need_prepare_func"` + InjectParams map[string]InjectParam `json:"inject_params"` + Imports map[string]string `json:"imports"` + PkgName string `json:"pkg_name"` + ProviderFile string `json:"provider_file"` + SourceLocation SourceLocation `json:"source_location"` +} +``` + +### ProviderComment +```go +// ProviderComment 表示解析的 provider 注释 +type ProviderComment struct { + RawText string `json:"raw_text"` + Mode ProviderMode `json:"mode"` + Injection InjectionMode `json:"injection"` + ReturnType string `json:"return_type"` + Group string `json:"group"` + IsValid bool `json:"is_valid"` + Errors []string `json:"errors"` + Location SourceLocation `json:"location"` +} +``` + +### InjectParam +```go +// InjectParam 表示注入参数 +type InjectParam struct { + Star string `json:"star"` + Type string `json:"type"` + Package string `json:"package"` + PackageAlias string `json:"package_alias"` + FieldName string `json:"field_name"` + InjectTag string `json:"inject_tag"` +} +``` + +## 枚举类型 + +### ProviderMode +```go +// ProviderMode 定义 provider 模式 +type ProviderMode string + +const ( + ModeDefault ProviderMode = "" + ModeGRPC ProviderMode = "grpc" + ModeEvent ProviderMode = "event" + ModeJob ProviderMode = "job" + ModeCronJob ProviderMode = "cronjob" + ModeModel ProviderMode = "model" +) +``` + +### InjectionMode +```go +// InjectionMode 定义注入模式 +type InjectionMode string + +const ( + InjectionDefault InjectionMode = "" + InjectionOnly InjectionMode = "only" + InjectionExcept InjectionMode = "except" +) +``` + +## 配置结构 + +### ParserConfig +```go +// 
ParserConfig 定义解析器配置 +type ParserConfig struct { + StrictMode bool `json:"strict_mode"` // 严格模式 + AllowTestFile bool `json:"allow_test_file"` // 是否允许解析测试文件 + IgnorePattern string `json:"ignore_pattern"` // 忽略文件模式 + MaxFileSize int64 `json:"max_file_size"` // 最大文件大小 + EnableCache bool `json:"enable_cache"` // 启用缓存 + CacheTTL int `json:"cache_ttl"` // 缓存 TTL(秒) +} +``` + +### DefaultConfig +```go +// DefaultConfig 返回默认配置 +func DefaultConfig() ParserConfig { + return ParserConfig{ + StrictMode: false, + AllowTestFile: false, + IgnorePattern: "*.gen.go", + MaxFileSize: 1024 * 1024, // 1MB + EnableCache: false, + CacheTTL: 300, // 5 分钟 + } +} +``` + +## 工厂函数 + +### NewParser +```go +// NewParser 创建新的解析器实例 +func NewParser(config ParserConfig) (Parser, error) + +// NewParserWithContext 创建带上下文的解析器 +func NewParserWithContext(ctx context.Context, config ParserConfig) (Parser, error) + +// NewDefaultParser 创建默认配置的解析器 +func NewDefaultParser() (Parser, error) +``` + +### NewValidator +```go +// NewValidator 创建新的验证器实例 +func NewValidator() (Validator, error) + +// NewCustomValidator 创建自定义验证器 +func NewCustomValidator(rules []ValidationRule) (Validator, error) +``` + +### NewRenderer +```go +// NewRenderer 创建新的渲染器实例 +func NewRenderer() (Renderer, error) + +// NewRendererWithTemplate 创建带自定义模板的渲染器 +func NewRendererWithTemplate(template string) (Renderer, error) +``` + +## 错误处理 + +### Error Types +```go +// ParseError 解析错误 +type ParseError struct { + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Message string `json:"message"` +} + +// ValidationError 验证错误 +type ValidationError struct { + Field string `json:"field"` + Value string `json:"value"` + Message string `json:"message"` +} + +// GenerationError 生成错误 +type GenerationError struct { + File string `json:"file"` + Message string `json:"message"` +} + +// ConfigurationError 配置错误 +type ConfigurationError struct { + Field string `json:"field"` + Message string `json:"message"` +} +``` + +### Error Interface +```go +// ProviderError 定义 provider 相关错误接口 +type ProviderError interface { + error + Code() string + Details() map[string]interface{} + IsRetryable() bool +} +``` + +## 兼容性保证 + +### 向后兼容 +- 所有公共 API 保持向后兼容 +- 结构体字段只能添加,不能删除或重命名 +- 接口方法只能添加,不能删除或修改签名 + +### 废弃策略 +- 废弃的 API 会标记为 deprecated +- 废弃的 API 会在至少 2 个主版本后移除 +- 提供迁移指南和工具 + +## 版本控制 + +### 版本格式 +- 遵循语义化版本控制 (SemVer) +- 主版本:不兼容的 API 变更 +- 次版本:向下兼容的功能性新增 +- 修订号:向下兼容的问题修正 + +### 版本检查 +```go +// Version 返回当前版本 +func Version() string + +// Compatible 检查版本兼容性 +func Compatible(version string) bool +``` + +## 性能要求 + +### 解析性能 +- 单个文件解析时间 < 100ms +- 目录解析时间 < 1s(100 个文件以内) +- 内存使用 < 50MB(正常情况) + +### 生成性能 +- 代码生成时间 < 200ms +- 模板渲染时间 < 50ms +- 生成的代码大小合理(< 10KB per provider) + +## 安全考虑 + +### 输入验证 +- 所有输入参数必须验证 +- 文件路径必须安全检查 +- 防止路径遍历攻击 + +### 资源限制 +- 最大文件大小限制 +- 最大递归深度限制 +- 最大并发处理数限制 + +## 扩展点 + +### 自定义模式 +```go +// RegisterProviderMode 注册自定义 provider 模式 +func RegisterProviderMode(mode string, handler ProviderModeHandler) error + +// ProviderModeHandler provider 模式处理器接口 +type ProviderModeHandler interface { + Parse(comment string) (*ProviderComment, error) + Validate(comment *ProviderComment) error + Generate(comment *ProviderComment) (*Provider, error) +} +``` + +### 自定义验证器 +```go +// ValidationRule 验证规则接口 +type ValidationRule interface { + Name() string + Validate(p *Provider) error +} + +// RegisterValidationRule 注册验证规则 +func RegisterValidationRule(rule ValidationRule) error +``` + +### 自定义渲染器 +```go +// TemplateFunction 模板函数类型 +type TemplateFunction func(args 
...interface{}) (interface{}, error) + +// RegisterTemplateFunction 注册模板函数 +func RegisterTemplateFunction(name string, fn TemplateFunction) error +``` + +## 测试契约 + +### 单元测试 +```go +// TestParserContract 测试解析器契约 +func TestParserContract(t *testing.T) { + parser := NewDefaultParser() + + // 测试基本功能 + providers, err := parser.ParseFile("testdata/simple.go") + assert.NoError(t, err) + assert.NotEmpty(t, providers) + + // 测试错误处理 + _, err = parser.ParseFile("testdata/invalid.go") + assert.Error(t, err) +} +``` + +### 集成测试 +```go +// TestFullWorkflow 测试完整工作流 +func TestFullWorkflow(t *testing.T) { + // 解析 -> 验证 -> 生成 -> 验证结果 +} +``` + +### 基准测试 +```go +// BenchmarkParser 基准测试 +func BenchmarkParser(b *testing.B) { + parser := NewDefaultParser() + + for i := 0; i < b.N; i++ { + _, err := parser.ParseFile("testdata/large.go") + if err != nil { + b.Fatal(err) + } + } +} +``` + +## 监控和日志 + +### 日志接口 +```go +// Logger 定义日志接口 +type Logger interface { + Debug(msg string, fields ...interface{}) + Info(msg string, fields ...interface{}) + Warn(msg string, fields ...interface{}) + Error(msg string, fields ...interface{}) +} + +// SetLogger 设置日志器 +func SetLogger(logger Logger) +``` + +### 指标接口 +```go +// Metrics 定义指标接口 +type Metrics interface { + Counter(name string, value int64, tags ...string) + Timer(name string, value time.Duration, tags ...string) + Gauge(name string, value float64, tags ...string) +} + +// SetMetrics 设置指标收集器 +func SetMetrics(metrics Metrics) +``` \ No newline at end of file diff --git a/specs/001-pkg-ast-provider/contracts/parser-api.md b/specs/001-pkg-ast-provider/contracts/parser-api.md new file mode 100644 index 0000000..710a90b --- /dev/null +++ b/specs/001-pkg-ast-provider/contracts/parser-api.md @@ -0,0 +1,394 @@ +# Parser API Contract + +## 概述 + +定义 pkg/ast/provider 包的解析器 API 契约,确保向后兼容性和一致的接口设计。 + +## 核心接口 + +### Parser Interface +```go +// Parser 定义 provider 解析器接口 +type Parser interface { + // ParseFile 解析单个 Go 文件 + ParseFile(filename string) ([]*Provider, error) + + // ParseDir 解析目录中的所有 Go 文件 + ParseDir(dirname string) ([]*Provider, error) + + // ParseString 解析字符串内容 + ParseString(content string, filename string) ([]*Provider, error) + + // SetConfig 设置解析器配置 + SetConfig(config ParserConfig) error + + // GetConfig 获取当前配置 + GetConfig() ParserConfig +} +``` + +### Validator Interface +```go +// Validator 定义 provider 验证器接口 +type Validator interface { + // Validate 验证 provider 配置 + Validate(p *Provider) []error + + // ValidateComment 验证 provider 注释 + ValidateComment(comment *ProviderComment) []error +} +``` + +### Renderer Interface +```go +// Renderer 定义 provider 渲染器接口 +type Renderer interface { + // Render 渲染 provider 代码 + Render(filename string, providers []*Provider) error + + // RenderTemplate 使用自定义模板渲染 + RenderTemplate(filename string, providers []*Provider, template string) error +} +``` + +## 数据结构 + +### Provider +```go +// Provider 表示一个依赖注入提供者 +type Provider struct { + StructName string `json:"struct_name"` + ReturnType string `json:"return_type"` + Mode ProviderMode `json:"mode"` + ProviderGroup string `json:"provider_group"` + GrpcRegisterFunc string `json:"grpc_register_func"` + NeedPrepareFunc bool `json:"need_prepare_func"` + InjectParams map[string]InjectParam `json:"inject_params"` + Imports map[string]string `json:"imports"` + PkgName string `json:"pkg_name"` + ProviderFile string `json:"provider_file"` + SourceLocation SourceLocation `json:"source_location"` +} +``` + +### ProviderComment +```go +// ProviderComment 表示解析的 provider 注释 +type ProviderComment struct { + 
RawText string `json:"raw_text"` + Mode ProviderMode `json:"mode"` + Injection InjectionMode `json:"injection"` + ReturnType string `json:"return_type"` + Group string `json:"group"` + IsValid bool `json:"is_valid"` + Errors []string `json:"errors"` + Location SourceLocation `json:"location"` +} +``` + +### InjectParam +```go +// InjectParam 表示注入参数 +type InjectParam struct { + Star string `json:"star"` + Type string `json:"type"` + Package string `json:"package"` + PackageAlias string `json:"package_alias"` + FieldName string `json:"field_name"` + InjectTag string `json:"inject_tag"` +} +``` + +## 枚举类型 + +### ProviderMode +```go +// ProviderMode 定义 provider 模式 +type ProviderMode string + +const ( + ModeDefault ProviderMode = "" + ModeGRPC ProviderMode = "grpc" + ModeEvent ProviderMode = "event" + ModeJob ProviderMode = "job" + ModeCronJob ProviderMode = "cronjob" + ModeModel ProviderMode = "model" +) +``` + +### InjectionMode +```go +// InjectionMode 定义注入模式 +type InjectionMode string + +const ( + InjectionDefault InjectionMode = "" + InjectionOnly InjectionMode = "only" + InjectionExcept InjectionMode = "except" +) +``` + +## 配置结构 + +### ParserConfig +```go +// ParserConfig 定义解析器配置 +type ParserConfig struct { + StrictMode bool `json:"strict_mode"` // 严格模式 + AllowTestFile bool `json:"allow_test_file"` // 是否允许解析测试文件 + IgnorePattern string `json:"ignore_pattern"` // 忽略文件模式 + MaxFileSize int64 `json:"max_file_size"` // 最大文件大小 + EnableCache bool `json:"enable_cache"` // 启用缓存 + CacheTTL int `json:"cache_ttl"` // 缓存 TTL(秒) +} +``` + +### DefaultConfig +```go +// DefaultConfig 返回默认配置 +func DefaultConfig() ParserConfig { + return ParserConfig{ + StrictMode: false, + AllowTestFile: false, + IgnorePattern: "*.gen.go", + MaxFileSize: 1024 * 1024, // 1MB + EnableCache: false, + CacheTTL: 300, // 5 分钟 + } +} +``` + +## 工厂函数 + +### NewParser +```go +// NewParser 创建新的解析器实例 +func NewParser(config ParserConfig) (Parser, error) + +// NewParserWithContext 创建带上下文的解析器 +func NewParserWithContext(ctx context.Context, config ParserConfig) (Parser, error) + +// NewDefaultParser 创建默认配置的解析器 +func NewDefaultParser() (Parser, error) +``` + +### NewValidator +```go +// NewValidator 创建新的验证器实例 +func NewValidator() (Validator, error) + +// NewCustomValidator 创建自定义验证器 +func NewCustomValidator(rules []ValidationRule) (Validator, error) +``` + +### NewRenderer +```go +// NewRenderer 创建新的渲染器实例 +func NewRenderer() (Renderer, error) + +// NewRendererWithTemplate 创建带自定义模板的渲染器 +func NewRendererWithTemplate(template string) (Renderer, error) +``` + +## 错误处理 + +### Error Types +```go +// ParseError 解析错误 +type ParseError struct { + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Message string `json:"message"` +} + +// ValidationError 验证错误 +type ValidationError struct { + Field string `json:"field"` + Value string `json:"value"` + Message string `json:"message"` +} + +// GenerationError 生成错误 +type GenerationError struct { + File string `json:"file"` + Message string `json:"message"` +} + +// ConfigurationError 配置错误 +type ConfigurationError struct { + Field string `json:"field"` + Message string `json:"message"` +} +``` + +### Error Interface +```go +// ProviderError 定义 provider 相关错误接口 +type ProviderError interface { + error + Code() string + Details() map[string]interface{} + IsRetryable() bool +} +``` + +## 兼容性保证 + +### 向后兼容 +- 所有公共 API 保持向后兼容 +- 结构体字段只能添加,不能删除或重命名 +- 接口方法只能添加,不能删除或修改签名 + +### 废弃策略 +- 废弃的 API 会标记为 deprecated +- 废弃的 API 会在至少 2 个主版本后移除 +- 提供迁移指南和工具 + +## 版本控制 + +### 版本格式 +- 遵循语义化版本控制 
(SemVer) +- 主版本:不兼容的 API 变更 +- 次版本:向下兼容的功能性新增 +- 修订号:向下兼容的问题修正 + +### 版本检查 +```go +// Version 返回当前版本 +func Version() string + +// Compatible 检查版本兼容性 +func Compatible(version string) bool +``` + +## 性能要求 + +### 解析性能 +- 单个文件解析时间 < 100ms +- 目录解析时间 < 1s(100 个文件以内) +- 内存使用 < 50MB(正常情况) + +### 生成性能 +- 代码生成时间 < 200ms +- 模板渲染时间 < 50ms +- 生成的代码大小合理(< 10KB per provider) + +## 安全考虑 + +### 输入验证 +- 所有输入参数必须验证 +- 文件路径必须安全检查 +- 防止路径遍历攻击 + +### 资源限制 +- 最大文件大小限制 +- 最大递归深度限制 +- 最大并发处理数限制 + +## 扩展点 + +### 自定义模式 +```go +// RegisterProviderMode 注册自定义 provider 模式 +func RegisterProviderMode(mode string, handler ProviderModeHandler) error + +// ProviderModeHandler provider 模式处理器接口 +type ProviderModeHandler interface { + Parse(comment string) (*ProviderComment, error) + Validate(comment *ProviderComment) error + Generate(comment *ProviderComment) (*Provider, error) +} +``` + +### 自定义验证器 +```go +// ValidationRule 验证规则接口 +type ValidationRule interface { + Name() string + Validate(p *Provider) error +} + +// RegisterValidationRule 注册验证规则 +func RegisterValidationRule(rule ValidationRule) error +``` + +### 自定义渲染器 +```go +// TemplateFunction 模板函数类型 +type TemplateFunction func(args ...interface{}) (interface{}, error) + +// RegisterTemplateFunction 注册模板函数 +func RegisterTemplateFunction(name string, fn TemplateFunction) error +``` + +## 测试契约 + +### 单元测试 +```go +// TestParserContract 测试解析器契约 +func TestParserContract(t *testing.T) { + parser := NewDefaultParser() + + // 测试基本功能 + providers, err := parser.ParseFile("testdata/simple.go") + assert.NoError(t, err) + assert.NotEmpty(t, providers) + + // 测试错误处理 + _, err = parser.ParseFile("testdata/invalid.go") + assert.Error(t, err) +} +``` + +### 集成测试 +```go +// TestFullWorkflow 测试完整工作流 +func TestFullWorkflow(t *testing.T) { + // 解析 -> 验证 -> 生成 -> 验证结果 +} +``` + +### 基准测试 +```go +// BenchmarkParser 基准测试 +func BenchmarkParser(b *testing.B) { + parser := NewDefaultParser() + + for i := 0; i < b.N; i++ { + _, err := parser.ParseFile("testdata/large.go") + if err != nil { + b.Fatal(err) + } + } +} +``` + +## 监控和日志 + +### 日志接口 +```go +// Logger 定义日志接口 +type Logger interface { + Debug(msg string, fields ...interface{}) + Info(msg string, fields ...interface{}) + Warn(msg string, fields ...interface{}) + Error(msg string, fields ...interface{}) +} + +// SetLogger 设置日志器 +func SetLogger(logger Logger) +``` + +### 指标接口 +```go +// Metrics 定义指标接口 +type Metrics interface { + Counter(name string, value int64, tags ...string) + Timer(name string, value time.Duration, tags ...string) + Gauge(name string, value float64, tags ...string) +} + +// SetMetrics 设置指标收集器 +func SetMetrics(metrics Metrics) +``` \ No newline at end of file diff --git a/specs/001-pkg-ast-provider/contracts/validation-rules.md b/specs/001-pkg-ast-provider/contracts/validation-rules.md new file mode 100644 index 0000000..4d9e373 --- /dev/null +++ b/specs/001-pkg-ast-provider/contracts/validation-rules.md @@ -0,0 +1,657 @@ +# Validation Rules Contract + +## 概述 + +定义 pkg/ast/provider 包的验证规则契约,确保 provider 配置的正确性和一致性。 + +## 验证规则分类 + +### 1. 
结构验证规则 + +#### 命名规则 +```go +// StructNameRule 验证结构体名称 +type StructNameRule struct{} + +func (r *StructNameRule) Name() string { + return "struct_name" +} + +func (r *StructNameRule) Validate(p *Provider) error { + if p.StructName == "" { + return &ValidationError{ + Field: "StructName", + Value: "", + Message: "struct name cannot be empty", + } + } + + if !isValidGoIdentifier(p.StructName) { + return &ValidationError{ + Field: "StructName", + Value: p.StructName, + Message: "struct name must be a valid Go identifier", + } + } + + if !isExported(p.StructName) { + return &ValidationError{ + Field: "StructName", + Value: p.StructName, + Message: "struct name must be exported (start with uppercase letter)", + } + } + + return nil +} +``` + +#### 包名规则 +```go +// PackageNameRule 验证包名 +type PackageNameRule struct{} + +func (r *PackageNameRule) Name() string { + return "package_name" +} + +func (r *PackageNameRule) Validate(p *Provider) error { + if p.PkgName == "" { + return &ValidationError{ + Field: "PkgName", + Value: "", + Message: "package name cannot be empty", + } + } + + if !isValidGoIdentifier(p.PkgName) { + return &ValidationError{ + Field: "PkgName", + Value: p.PkgName, + Message: "package name must be a valid Go identifier", + } + } + + return nil +} +``` + +### 2. 类型验证规则 + +#### 返回类型规则 +```go +// ReturnTypeRule 验证返回类型 +type ReturnTypeRule struct{} + +func (r *ReturnTypeRule) Name() string { + return "return_type" +} + +func (r *ReturnTypeRule) Validate(p *Provider) error { + if p.ReturnType == "" { + return &ValidationError{ + Field: "ReturnType", + Value: "", + Message: "return type cannot be empty", + } + } + + if !isValidGoType(p.ReturnType) { + return &ValidationError{ + Field: "ReturnType", + Value: p.ReturnType, + Message: "return type must be a valid Go type", + } + } + + return nil +} +``` + +#### 模式验证规则 +```go +// ProviderModeRule 验证 provider 模式 +type ProviderModeRule struct { + validModes map[ProviderMode]bool +} + +func NewProviderModeRule() *ProviderModeRule { + return &ProviderModeRule{ + validModes: map[ProviderMode]bool{ + ModeDefault: true, + ModeGRPC: true, + ModeEvent: true, + ModeJob: true, + ModeCronJob: true, + ModeModel: true, + }, + } +} + +func (r *ProviderModeRule) Name() string { + return "provider_mode" +} + +func (r *ProviderModeRule) Validate(p *Provider) error { + if !r.validModes[p.Mode] { + return &ValidationError{ + Field: "Mode", + Value: string(p.Mode), + Message: fmt.Sprintf("invalid provider mode: %s", p.Mode), + } + } + + return nil +} +``` + +### 3. 
依赖注入验证规则 + +#### 注入参数规则 +```go +// InjectParamsRule 验证注入参数 +type InjectParamsRule struct{} + +func (r *InjectParamsRule) Name() string { + return "inject_params" +} + +func (r *InjectParamsRule) Validate(p *Provider) error { + // 验证注入参数不为标量类型 + for name, param := range p.InjectParams { + if isScalarType(param.Type) { + return &ValidationError{ + Field: fmt.Sprintf("InjectParams.%s", name), + Value: param.Type, + Message: fmt.Sprintf("scalar type '%s' cannot be injected", param.Type), + } + } + } + + // 验证 only 模式下的 inject:true 标签 + if p.InjectionMode == InjectionOnly { + for name, param := range p.InjectParams { + if param.InjectTag != "true" { + return &ValidationError{ + Field: fmt.Sprintf("InjectParams.%s", name), + Value: param.InjectTag, + Message: "all fields must have inject:\"true\" tag in 'only' mode", + } + } + } + } + + // 验证 except 模式下的 inject:false 标签 + if p.InjectionMode == InjectionExcept { + for name, param := range p.InjectParams { + if param.InjectTag == "false" { + return &ValidationError{ + Field: fmt.Sprintf("InjectParams.%s", name), + Value: param.InjectTag, + Message: "fields with inject:\"false\" tag should not be included in 'except' mode", + } + } + } + } + + return nil +} +``` + +#### 包别名规则 +```go +// PackageAliasRule 验证包别名 +type PackageAliasRule struct{} + +func (r *PackageAliasRule) Name() string { + return "package_alias" +} + +func (r *PackageAliasRule) Validate(p *Provider) error { + // 收集所有包别名 + aliases := make(map[string]string) + for alias, pkg := range p.Imports { + if existing, exists := aliases[alias]; exists && existing != pkg { + return &ValidationError{ + Field: "Imports", + Value: alias, + Message: fmt.Sprintf("duplicate package alias '%s' for different packages", alias), + } + } + aliases[alias] = pkg + } + + // 验证注入参数的包别名 + for name, param := range p.InjectParams { + if param.PackageAlias != "" { + if _, exists := aliases[param.PackageAlias]; !exists { + return &ValidationError{ + Field: fmt.Sprintf("InjectParams.%s", name), + Value: param.PackageAlias, + Message: fmt.Sprintf("undefined package alias '%s'", param.PackageAlias), + } + } + } + } + + return nil +} +``` + +### 4. 
模式特定验证规则 + +#### gRPC 模式规则 +```go +// GRPCModeRule 验证 gRPC 模式特定规则 +type GRPCModeRule struct{} + +func (r *GRPCModeRule) Name() string { + return "grpc_mode" +} + +func (r *GRPCModeRule) Validate(p *Provider) error { + if p.Mode != ModeGRPC { + return nil + } + + // 验证 gRPC 注册函数 + if p.GrpcRegisterFunc == "" { + return &ValidationError{ + Field: "GrpcRegisterFunc", + Value: "", + Message: "gRPC register function cannot be empty in gRPC mode", + } + } + + // 验证返回类型 + if p.ReturnType != "contracts.Initial" { + return &ValidationError{ + Field: "ReturnType", + Value: p.ReturnType, + Message: "return type must be 'contracts.Initial' in gRPC mode", + } + } + + // 验证 provider 组 + if p.ProviderGroup != "atom.GroupInitial" { + return &ValidationError{ + Field: "ProviderGroup", + Value: p.ProviderGroup, + Message: "provider group must be 'atom.GroupInitial' in gRPC mode", + } + } + + // 验证必须包含 __grpc 注入参数 + if _, exists := p.InjectParams["__grpc"]; !exists { + return &ValidationError{ + Field: "InjectParams", + Value: "", + Message: "gRPC provider must include '__grpc' injection parameter", + } + } + + return nil +} +``` + +#### Event 模式规则 +```go +// EventModeRule 验证 Event 模式特定规则 +type EventModeRule struct{} + +func (r *EventModeRule) Name() string { + return "event_mode" +} + +func (r *EventModeRule) Validate(p *Provider) error { + if p.Mode != ModeEvent { + return nil + } + + // 验证返回类型 + if p.ReturnType != "contracts.Initial" { + return &ValidationError{ + Field: "ReturnType", + Value: p.ReturnType, + Message: "return type must be 'contracts.Initial' in event mode", + } + } + + // 验证 provider 组 + if p.ProviderGroup != "atom.GroupInitial" { + return &ValidationError{ + Field: "ProviderGroup", + Value: p.ProviderGroup, + Message: "provider group must be 'atom.GroupInitial' in event mode", + } + } + + // 验证必须包含 __event 注入参数 + if _, exists := p.InjectParams["__event"]; !exists { + return &ValidationError{ + Field: "InjectParams", + Value: "", + Message: "event provider must include '__event' injection parameter", + } + } + + return nil +} +``` + +#### Job 模式规则 +```go +// JobModeRule 验证 Job 模式特定规则 +type JobModeRule struct{} + +func (r *JobModeRule) Name() string { + return "job_mode" +} + +func (r *JobModeRule) Validate(p *Provider) error { + if p.Mode != ModeJob && p.Mode != ModeCronJob { + return nil + } + + // 验证返回类型 + if p.ReturnType != "contracts.Initial" { + return &ValidationError{ + Field: "ReturnType", + Value: p.ReturnType, + Message: "return type must be 'contracts.Initial' in job mode", + } + } + + // 验证 provider 组 + if p.ProviderGroup != "atom.GroupInitial" { + return &ValidationError{ + Field: "ProviderGroup", + Value: p.ProviderGroup, + Message: "provider group must be 'atom.GroupInitial' in job mode", + } + } + + // 验证必须包含 __job 注入参数 + if _, exists := p.InjectParams["__job"]; !exists { + return &ValidationError{ + Field: "InjectParams", + Value: "", + Message: "job provider must include '__job' injection parameter", + } + } + + return nil +} +``` + +#### Model 模式规则 +```go +// ModelModeRule 验证 Model 模式特定规则 +type ModelModeRule struct{} + +func (r *ModelModeRule) Name() string { + return "model_mode" +} + +func (r *ModelModeRule) Validate(p *Provider) error { + if p.Mode != ModeModel { + return nil + } + + // 验证返回类型 + if p.ReturnType != "contracts.Initial" { + return &ValidationError{ + Field: "ReturnType", + Value: p.ReturnType, + Message: "return type must be 'contracts.Initial' in model mode", + } + } + + // 验证 provider 组 + if p.ProviderGroup != "atom.GroupInitial" { + return 
&ValidationError{ + Field: "ProviderGroup", + Value: p.ProviderGroup, + Message: "provider group must be 'atom.GroupInitial' in model mode", + } + } + + // 验证必须设置 NeedPrepareFunc + if !p.NeedPrepareFunc { + return &ValidationError{ + Field: "NeedPrepareFunc", + Value: "false", + Message: "model provider must set NeedPrepareFunc to true", + } + } + + return nil +} +``` + +### 5. 注释验证规则 + +#### 注释格式规则 +```go +// CommentFormatRule 验证注释格式 +type CommentFormatRule struct{} + +func (r *CommentFormatRule) Name() string { + return "comment_format" +} + +func (r *CommentFormatRule) Validate(p *Provider) error { + if p.Comment == nil { + return nil + } + + comment := p.Comment + + // 验证注释必须以 @provider 开头 + if !strings.HasPrefix(comment.RawText, "@provider") { + return &ValidationError{ + Field: "Comment.RawText", + Value: comment.RawText, + Message: "provider comment must start with '@provider'", + } + } + + // 验证模式格式 + if comment.Mode != "" && !isValidProviderMode(comment.Mode) { + return &ValidationError{ + Field: "Comment.Mode", + Value: string(comment.Mode), + Message: fmt.Sprintf("invalid provider mode: %s", comment.Mode), + } + } + + // 验证注入模式 + if comment.Injection != "" && !isValidInjectionMode(comment.Injection) { + return &ValidationError{ + Field: "Comment.Injection", + Value: string(comment.Injection), + Message: fmt.Sprintf("invalid injection mode: %s", comment.Injection), + } + } + + return nil +} +``` + +### 6. 文件路径验证规则 + +#### 文件路径规则 +```go +// FilePathRule 验证文件路径 +type FilePathRule struct{} + +func (r *FilePathRule) Name() string { + return "file_path" +} + +func (r *FilePathRule) Validate(p *Provider) error { + if p.ProviderFile == "" { + return &ValidationError{ + Field: "ProviderFile", + Value: "", + Message: "provider file path cannot be empty", + } + } + + // 验证文件扩展名 + if !strings.HasSuffix(p.ProviderFile, ".go") { + return &ValidationError{ + Field: "ProviderFile", + Value: p.ProviderFile, + Message: "provider file must have .go extension", + } + } + + // 验证文件路径安全性 + if !isSafeFilePath(p.ProviderFile) { + return &ValidationError{ + Field: "ProviderFile", + Value: p.ProviderFile, + Message: "provider file path is not safe", + } + } + + return nil +} +``` + +## 验证器注册 + +### 默认验证器 +```go +// DefaultValidator 返回默认验证器 +func DefaultValidator() Validator { + validator := NewCompositeValidator() + + // 注册所有验证规则 + validator.AddRule(&StructNameRule{}) + validator.AddRule(&PackageNameRule{}) + validator.AddRule(&ReturnTypeRule{}) + validator.AddRule(NewProviderModeRule()) + validator.AddRule(&InjectParamsRule{}) + validator.AddRule(&PackageAliasRule{}) + validator.AddRule(&GRPCModeRule{}) + validator.AddRule(&EventModeRule{}) + validator.AddRule(&JobModeRule{}) + validator.AddRule(&ModelModeRule{}) + validator.AddRule(&CommentFormatRule{}) + validator.AddRule(&FilePathRule{}) + + return validator +} +``` + +### 自定义验证器 +```go +// CompositeValidator 组合验证器 +type CompositeValidator struct { + rules []ValidationRule +} + +func NewCompositeValidator() *CompositeValidator { + return &CompositeValidator{ + rules: make([]ValidationRule, 0), + } +} + +func (v *CompositeValidator) AddRule(rule ValidationRule) { + v.rules = append(v.rules, rule) +} + +func (v *CompositeValidator) Validate(p *Provider) []error { + var errors []error + + for _, rule := range v.rules { + if err := rule.Validate(p); err != nil { + errors = append(errors, err) + } + } + + return errors +} +``` + +## 验证结果 + +### 验证报告 +```go +// ValidationReport 验证报告 +type ValidationReport struct { + Provider *Provider `json:"provider"` + 
Errors []error `json:"errors"` + Warnings []error `json:"warnings"` + IsValid bool `json:"is_valid"` + Timestamp time.Time `json:"timestamp"` +} + +// ValidateProvider 验证 provider 并生成报告 +func ValidateProvider(p *Provider) *ValidationReport { + report := &ValidationReport{ + Provider: p, + Timestamp: time.Now(), + } + + validator := DefaultValidator() + errors := validator.Validate(p) + + if len(errors) == 0 { + report.IsValid = true + } else { + report.Errors = errors + report.IsValid = false + } + + return report +} +``` + +## 性能考虑 + +### 验证性能优化 +- 快速失败:遇到第一个错误立即返回 +- 缓存验证结果 +- 并行验证独立规则 +- 懒加载验证器 + +### 内存优化 +- 重用验证器实例 +- 避免重复分配错误对象 +- 使用对象池管理验证器 + +## 扩展性 + +### 自定义规则注册 +```go +// RegisterValidationRule 注册全局验证规则 +func RegisterValidationRule(name string, rule ValidationRule) error { + // 实现规则注册逻辑 +} + +// GetValidationRule 获取验证规则 +func GetValidationRule(name string) (ValidationRule, bool) { + // 实现规则获取逻辑 +} +``` + +### 规则优先级 +```go +// PrioritizedRule 优先级规则 +type PrioritizedRule struct { + Rule ValidationRule + Priority int +} + +// PriorityValidator 优先级验证器 +type PriorityValidator struct { + rules []PrioritizedRule +} + +func (v *PriorityValidator) Validate(p *Provider) []error { + // 按优先级执行验证 +} +``` \ No newline at end of file diff --git a/specs/001-pkg-ast-provider/data-model.md b/specs/001-pkg-ast-provider/data-model.md new file mode 100644 index 0000000..cb2355c --- /dev/null +++ b/specs/001-pkg-ast-provider/data-model.md @@ -0,0 +1,209 @@ +# Data Model Design + +## Core Entities + +### 1. Provider +代表一个依赖注入提供者的核心实体 + +```go +type Provider struct { + StructName string // 结构体名称 + ReturnType string // 返回类型 + Mode ProviderMode // 提供者模式 + ProviderGroup string // 提供者分组 + GrpcRegisterFunc string // gRPC 注册函数 + NeedPrepareFunc bool // 是否需要 Prepare 函数 + InjectParams map[string]InjectParam // 注入参数 + Imports map[string]string // 导入包 + PkgName string // 包名 + ProviderFile string // 生成文件路径 + SourceLocation SourceLocation // 源码位置 +} +``` + +### 2. ProviderMode +提供者模式枚举 + +```go +type ProviderMode string + +const ( + ModeDefault ProviderMode = "" + ModeGRPC ProviderMode = "grpc" + ModeEvent ProviderMode = "event" + ModeJob ProviderMode = "job" + ModeCronJob ProviderMode = "cronjob" + ModeModel ProviderMode = "model" +) +``` + +### 3. InjectParam +注入参数描述 + +```go +type InjectParam struct { + Star string // 指针标记 (*) + Type string // 类型名称 + Package string // 包路径 + PackageAlias string // 包别名 + FieldName string // 字段名称 + InjectTag string // inject 标签值 +} +``` + +### 4. ProviderComment +Provider 注释描述 + +```go +type ProviderComment struct { + RawText string // 原始注释文本 + Mode ProviderMode // 模式 + Injection InjectionMode // 注入模式 + ReturnType string // 返回类型 + Group string // 分组 + IsValid bool // 是否有效 + Errors []string // 解析错误 +} +``` + +### 5. InjectionMode +注入模式枚举 + +```go +type InjectionMode string + +const ( + InjectionDefault InjectionMode = "" + InjectionOnly InjectionMode = "only" + InjectionExcept InjectionMode = "except" +) +``` + +### 6. SourceLocation +源码位置信息 + +```go +type SourceLocation struct { + File string // 文件路径 + Line int // 行号 + Column int // 列号 + StartPos int // 开始位置 + EndPos int // 结束位置 +} +``` + +### 7. ParserContext +解析器上下文 + +```go +type ParserContext struct { + FileSet *token.FileSet // 文件集合 + Imports map[string]string // 导入映射 + PkgName string // 包名 + ScalarTypes []string // 标量类型列表 + ErrorHandler func(error) // 错误处理器 + Config ParserConfig // 解析配置 +} +``` + +### 8. 
ParserConfig +解析器配置 + +```go +type ParserConfig struct { + StrictMode bool // 严格模式 + AllowTestFile bool // 是否允许解析测试文件 + IgnorePattern string // 忽略文件模式 +} +``` + +## Relationships + +### Entity Relationships +``` +Provider (1) -> (0..*) InjectParam +Provider (1) -> (1) ProviderComment +Provider (1) -> (1) SourceLocation +ProviderComment (1) -> (1) ProviderMode +ProviderComment (1) -> (1) InjectionMode +``` + +### Data Flow +``` +SourceFile -> ParserContext -> ProviderComment -> Provider -> GeneratedCode +``` + +## Validation Rules + +### Provider Validation +- 结构体名称必须有效(符合 Go 标识符规则) +- 返回类型不能为空 +- 模式必须为预定义值之一 +- 注入参数不能包含标量类型 +- 包名必须能正确解析 + +### Comment Validation +- 注释必须以 @provider 开头 +- 模式格式必须正确 +- 注入模式只能是 only 或 except +- 返回类型和分组格式必须正确 + +### Import Validation +- 包路径必须有效 +- 包别名不能重复 +- 匿名导入必须正确处理 + +## State Transitions + +### Parser States +``` +Idle -> Parsing -> Validating -> Generating -> Complete + ↓ + Error +``` + +### Provider States +``` +Discovered -> Parsing -> Validated -> Ready -> Generated + ↓ + Invalid +``` + +## Performance Considerations + +### Memory Usage +- 每个文件创建一个 ParserContext +- Provider 对象在解析完成后可以释放 +- 导入映射应该在文件级别共享 + +### Processing Speed +- 并行解析独立文件 +- 缓存常用的标量类型列表 +- 延迟验证直到所有信息收集完成 + +## Error Handling + +### Error Types +- ParseError: 解析错误 +- ValidationError: 验证错误 +- GenerationError: 生成错误 +- ConfigurationError: 配置错误 + +### Error Recovery +- 单个 Provider 错误不影响其他 Provider +- 文件级别错误应该跳过该文件 +- 提供详细的错误位置和建议 + +## Extension Points + +### Custom Provider Modes +- 通过 ProviderMode 接口支持自定义模式 +- 使用注册机制添加新模式处理器 + +### Custom Validation Rules +- 通过 Validator 接口支持自定义验证 +- 支持链式验证器组合 + +### Custom Renderers +- 通过 Renderer 接口支持自定义渲染器 +- 支持多种输出格式 \ No newline at end of file diff --git a/specs/001-pkg-ast-provider/plan.md b/specs/001-pkg-ast-provider/plan.md new file mode 100644 index 0000000..5e927b6 --- /dev/null +++ b/specs/001-pkg-ast-provider/plan.md @@ -0,0 +1,259 @@ + +# Implementation Plan: 优化 pkg/ast/provider 目录的代码组织逻辑与功能实现 + +**Branch**: `001-pkg-ast-provider` | **Date**: 2025-09-19 | **Spec**: `/projects/atomctl/specs/001-pkg-ast-provider/spec.md` +**Input**: Feature specification from `/projects/atomctl/specs/001-pkg-ast-provider/spec.md` + +## Execution Flow (/plan command scope) +``` +1. Load feature spec from Input path + → If not found: ERROR "No feature spec at {path}" +2. Fill Technical Context (scan for NEEDS CLARIFICATION) + → Detect Project Type from context (web=frontend+backend, mobile=app+api) + → Set Structure Decision based on project type +3. Fill the Constitution Check section based on the content of the constitution document. +4. Evaluate Constitution Check section below + → If violations exist: Document in Complexity Tracking + → If no justification possible: ERROR "Simplify approach first" + → Update Progress Tracking: Initial Constitution Check +5. Execute Phase 0 → research.md + → If NEEDS CLARIFICATION remain: ERROR "Resolve unknowns" +6. Execute Phase 1 → contracts, data-model.md, quickstart.md, agent-specific template file (e.g., `CLAUDE.md` for Claude Code, `.github/copilot-instructions.md` for GitHub Copilot, `GEMINI.md` for Gemini CLI, `QWEN.md` for Qwen Code or `AGENTS.md` for opencode). +7. Re-evaluate Constitution Check section + → If new violations: Refactor design, return to Phase 1 + → Update Progress Tracking: Post-Design Constitution Check +8. Plan Phase 2 → Describe task generation approach (DO NOT create tasks.md) +9. STOP - Ready for /tasks command +``` + +**IMPORTANT**: The /plan command STOPS at step 7. 
Phases 2-4 are executed by other commands: +- Phase 2: /tasks command creates tasks.md +- Phase 3-4: Implementation execution (manual or via tools) + +## Summary +主要需求是优化 pkg/ast/provider 目录的代码组织逻辑与功能实现,补充完善测试用例。当前代码包含两个主要文件:provider.go(337行复杂的解析逻辑)和 render.go(65行渲染逻辑)。需要重构代码以提高可维护性、可测试性,并添加完整的测试覆盖。 + +## Technical Context +**Language/Version**: Go 1.24.0 +**Primary Dependencies**: go/ast, go/parser, go/token, samber/lo, sirupsen/logrus, golang.org/x/tools/imports +**Storage**: N/A (file processing) +**Testing**: standard Go testing package with testify +**Target Platform**: Linux/macOS/Windows (CLI tool) +**Project Type**: Single project (CLI tool) +**Performance Goals**: <5s for large project parsing, <100MB memory usage +**Constraints**: Must maintain backward compatibility with existing @provider annotations +**Scale/Scope**: ~400 lines of existing code to refactor, target 90% test coverage + +## Constitution Check +*GATE: Must pass before Phase 0 research. Re-check after Phase 1 design.* + +### SOLID Principles Compliance +- [ ] **Single Responsibility**: Each component has single, clear responsibility +- [ ] **Open/Closed**: Design allows extension without modification +- [ ] **Liskov Substitution**: Subtypes can replace base types seamlessly +- [ ] **Interface Segregation**: Interfaces are specific and focused +- [ ] **Dependency Inversion**: Depend on abstractions, not concrete implementations + +### KISS Principle Compliance +- [ ] Design avoids unnecessary complexity +- [ ] CLI interface maintains consistency +- [ ] Code generation logic is simple and direct +- [ ] Solutions are intuitive and easy to understand + +### YAGNI Principle Compliance +- [ ] Only implementing clearly needed functionality +- [ ] No over-engineering or future-proofing without requirements +- [ ] Each feature has explicit user需求支撑 +- [ ] No "might be useful" features without justification + +### DRY Principle Compliance +- [ ] No code duplication across components +- [ ] Common functionality is abstracted and reused +- [ ] Template system avoids repetitive implementations +- [ ] Shared utilities are properly abstracted + +### Code Quality Standards +- [ ] **Testing Discipline**: TDD approach with Red-Green-Refactor cycle +- [ ] **CLI Consistency**: Unified parameter formats and output standards +- [ ] **Error Handling**: Complete error information and recovery mechanisms +- [ ] **Performance**: Generation speed and memory usage requirements met + +### Complexity Tracking +| Violation | Why Needed | Simpler Alternative Rejected Because | +|-----------|------------|-------------------------------------| +| [Document any deviations from constitutional principles] | [Justification for complexity] | [Why simpler approach insufficient] | + +## Project Structure + +### Documentation (this feature) +``` +specs/[###-feature]/ +├── plan.md # This file (/plan command output) +├── research.md # Phase 0 output (/plan command) +├── data-model.md # Phase 1 output (/plan command) +├── quickstart.md # Phase 1 output (/plan command) +├── contracts/ # Phase 1 output (/plan command) +└── tasks.md # Phase 2 output (/tasks command - NOT created by /plan) +``` + +### Source Code (repository root) +``` +# Option 1: Single project (DEFAULT) +src/ +├── models/ +├── services/ +├── cli/ +└── lib/ + +tests/ +├── contract/ +├── integration/ +└── unit/ + +# Option 2: Web application (when "frontend" + "backend" detected) +backend/ +├── src/ +│ ├── models/ +│ ├── services/ +│ └── api/ +└── tests/ + +frontend/ +├── src/ +│ ├── components/ 
+│ ├── pages/ +│ └── services/ +└── tests/ + +# Option 3: Mobile + API (when "iOS/Android" detected) +api/ +└── [same as backend above] + +ios/ or android/ +└── [platform-specific structure] +``` + +**Structure Decision**: [DEFAULT to Option 1 unless Technical Context indicates web/mobile app] + +## Phase 0: Outline & Research +1. **Extract unknowns from Technical Context** above: + - For each NEEDS CLARIFICATION → research task + - For each dependency → best practices task + - For each integration → patterns task + +2. **Generate and dispatch research agents**: + ``` + For each unknown in Technical Context: + Task: "Research {unknown} for {feature context}" + For each technology choice: + Task: "Find best practices for {tech} in {domain}" + ``` + +3. **Consolidate findings** in `research.md` using format: + - Decision: [what was chosen] + - Rationale: [why chosen] + - Alternatives considered: [what else evaluated] + +**Output**: research.md with all NEEDS CLARIFICATION resolved + +## Phase 1: Design & Contracts +*Prerequisites: research.md complete* + +1. **Extract entities from feature spec** → `data-model.md`: + - Entity name, fields, relationships + - Validation rules from requirements + - State transitions if applicable + +2. **Generate API contracts** from functional requirements: + - For each user action → endpoint + - Use standard REST/GraphQL patterns + - Output OpenAPI/GraphQL schema to `/contracts/` + +3. **Generate contract tests** from contracts: + - One test file per endpoint + - Assert request/response schemas + - Tests must fail (no implementation yet) + +4. **Extract test scenarios** from user stories: + - Each story → integration test scenario + - Quickstart test = story validation steps + +5. **Update agent file incrementally** (O(1) operation): + - Run `.specify/scripts/bash/update-agent-context.sh claude` for your AI assistant + - If exists: Add only NEW tech from current plan + - Preserve manual additions between markers + - Update recent changes (keep last 3) + - Keep under 150 lines for token efficiency + - Output to repository root + +**Output**: data-model.md, /contracts/*, failing tests, quickstart.md, agent-specific file + +## Phase 2: Task Planning Approach +*This section describes what the /tasks command will do - DO NOT execute during /plan* + +**Task Generation Strategy**: +- Load `.specify/templates/tasks-template.md` as base +- Generate tasks from Phase 1 design docs (contracts, data model, quickstart) +- Each contract → contract test task [P] +- Each entity → model creation task [P] +- Each user story → integration test task +- Implementation tasks to make tests pass + +**Specific Task Categories for This Feature**: +- **Code Organization**: Refactor large Parse function into smaller, focused functions +- **Testing**: Create comprehensive test suite for all components +- **Interface Design**: Define clear interfaces for parser, validator, and renderer +- **Error Handling**: Implement robust error handling and recovery +- **Performance**: Ensure performance requirements are met +- **Documentation**: Add comprehensive documentation and examples + +**Ordering Strategy**: +- TDD order: Tests before implementation +- Dependency order: Interfaces → Implementations → Integrations +- Mark [P] for parallel execution (independent files) + +**Estimated Output**: 30-35 numbered, ordered tasks in tasks.md + +**Key Implementation Notes**: +- Maintain backward compatibility with existing Parse() function +- Follow SOLID principles throughout refactoring +- Achieve 90% 
test coverage target
+- Keep performance within requirements (<5s parsing, <100MB memory)
+
+**IMPORTANT**: This phase is executed by the /tasks command, NOT by /plan
+
+## Phase 3+: Future Implementation
+*These phases are beyond the scope of the /plan command*
+
+**Phase 3**: Task execution (/tasks command creates tasks.md)
+**Phase 4**: Implementation (execute tasks.md following constitutional principles)
+**Phase 5**: Validation (run tests, execute quickstart.md, performance validation)
+
+## Complexity Tracking
+*Fill ONLY if Constitution Check has violations that must be justified*
+
+| Violation | Why Needed | Simpler Alternative Rejected Because |
+|-----------|------------|-------------------------------------|
+| [e.g., 4th project] | [current need] | [why 3 projects insufficient] |
+| [e.g., Repository pattern] | [specific problem] | [why direct DB access insufficient] |
+
+
+## Progress Tracking
+*This checklist is updated during execution flow*
+
+**Phase Status**:
+- [x] Phase 0: Research complete (/plan command)
+- [x] Phase 1: Design complete (/plan command)
+- [x] Phase 2: Task planning complete (/plan command - describe approach only)
+- [ ] Phase 3: Tasks generated (/tasks command)
+- [ ] Phase 4: Implementation complete
+- [ ] Phase 5: Validation passed
+
+**Gate Status**:
+- [x] Initial Constitution Check: PASS
+- [x] Post-Design Constitution Check: PASS
+- [x] All NEEDS CLARIFICATION resolved
+- [x] Complexity deviations documented
+
+---
+*Based on Constitution v1.0.0 - See `/memory/constitution.md`*
diff --git a/specs/001-pkg-ast-provider/quickstart.md b/specs/001-pkg-ast-provider/quickstart.md
new file mode 100644
index 0000000..afaa9cf
--- /dev/null
+++ b/specs/001-pkg-ast-provider/quickstart.md
@@ -0,0 +1,389 @@
+# Quick Start Guide
+
+## Overview
+
+This guide shows how to use the refactored pkg/ast/provider package to parse `@provider` annotations in Go source code and generate dependency-injection code.
+
+## Prerequisites
+
+- Go 1.24.0+
+- Familiarity with Go AST parsing
+- A basic understanding of dependency injection
+
+## Basic Usage
+
+### 1. Parse a single file
+
+```go
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"go.ipao.vip/atomctl/v2/pkg/ast/provider"
+)
+
+func main() {
+	// Parse a single file
+	providers, err := provider.ParseFile("path/to/your/file.go")
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Print the parse results
+	for _, p := range providers {
+		fmt.Printf("Provider: %s\n", p.StructName)
+		fmt.Printf("  Mode: %s\n", p.Mode)
+		fmt.Printf("  Return Type: %s\n", p.ReturnType)
+		fmt.Printf("  Inject Params: %d\n", len(p.InjectParams))
+	}
+}
+```
+
+### 2. Parse a directory
+
+```go
+package main
+
+import (
+	"log"
+
+	"go.ipao.vip/atomctl/v2/pkg/ast/provider"
+)
+
+func main() {
+	// Create the parser configuration
+	config := provider.ParserConfig{
+		StrictMode:    true,
+		AllowTestFile: false,
+		IgnorePattern: "*.gen.go",
+	}
+
+	// Create the parser
+	parser := provider.NewParser(config)
+
+	// Parse the directory
+	providers, err := parser.ParseDir("./app")
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Generate code
+	for _, p := range providers {
+		err := provider.Render(p.ProviderFile, []provider.Provider{p})
+		if err != nil {
+			log.Printf("Failed to render %s: %v", p.StructName, err)
+		}
+	}
+}
+```
+
+## Supported @provider Annotation Formats
+
+### Basic form
+```go
+// @provider
+type UserService struct {
+	// ...
+}
+```
+
+### With a mode
+```go
+// @provider(grpc)
+type UserService struct {
+	// ...
+}
+```
+
+### With an injection mode
+```go
+// @provider:only
+type UserService struct {
+	Repo *UserRepo `inject:"true"`
+	Log  *Logger   `inject:"true"`
+}
+```
+
+### Full form
+```go
+// @provider(grpc):only contracts.Initial atom.GroupInitial
+type UserService struct {
+	Repo *UserRepo `inject:"true"`
+	Log  *Logger   `inject:"true"`
+}
+```
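+
+To make the full form concrete, the snippet below shows the values a parse of the struct above is expected to produce. This is an illustrative sketch only: the field names (`StructName`, `Mode`, `ReturnType`, `InjectParams`) are the ones used earlier in this guide, `user_service.go` is a placeholder for a file containing the full-form example, and the `atom.GroupInitial` group token is captured in a separate field described in data-model.md.
+
+```go
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"go.ipao.vip/atomctl/v2/pkg/ast/provider"
+)
+
+func main() {
+	// "user_service.go" is a placeholder for a file containing the full-form example above.
+	providers, err := provider.ParseFile("user_service.go")
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	p := providers[0]
+	fmt.Println(p.StructName)        // UserService
+	fmt.Println(p.Mode)              // grpc
+	fmt.Println(p.ReturnType)        // contracts.Initial
+	fmt.Println(len(p.InjectParams)) // 2: both Repo and Log carry inject:"true"
+}
+```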
+
+## Testing Guide
+
+### Running tests
+```bash
+# Run all tests
+go test ./pkg/ast/provider/...
+
+# Run tests with coverage
+go test -cover ./pkg/ast/provider/...
+
+# Run benchmarks
+go test -bench=. ./pkg/ast/provider/...
+```
+
+### Writing tests
+```go
+package provider_test
+
+import (
+	"os"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"go.ipao.vip/atomctl/v2/pkg/ast/provider"
+)
+
+func TestParseProvider(t *testing.T) {
+	// Prepare the test source. The struct tag is concatenated in so that its
+	// backticks do not terminate the surrounding raw string literal.
+	source := `
+package main
+
+// @provider:only contracts.Initial
+type TestService struct {
+	Repo *TestRepo ` + "`inject:\"true\"`" + `
+}
+`
+
+	// Write it to a temporary file (createTempFile is a test helper)
+	tmpFile := createTempFile(t, source)
+	defer os.Remove(tmpFile)
+
+	// Parse
+	providers, err := provider.ParseFile(tmpFile)
+
+	// Verify
+	assert.NoError(t, err)
+	assert.Len(t, providers, 1)
+
+	p := providers[0]
+	assert.Equal(t, "TestService", p.StructName)
+	assert.Equal(t, "contracts.Initial", p.ReturnType)
+	assert.True(t, p.InjectMode.IsOnly())
+}
+```
+
+## Refactoring Guide
+
+### Migrating from the old version
+
+1. **Update the import path**
+```go
+// Old version
+import "go.ipao.vip/atomctl/v2/pkg/ast/provider"
+
+// New version (the import path is unchanged)
+import "go.ipao.vip/atomctl/v2/pkg/ast/provider"
+```
+
+2. **Use the new API**
+```go
+// Old version
+providers := provider.Parse("file.go")
+
+// New version (backward compatible)
+providers := provider.Parse("file.go") // still supported
+
+// Recommended new approach
+parser := provider.NewParser(provider.DefaultConfig())
+providers, err := parser.ParseFile("file.go")
+```
+
+### Custom extensions
+
+#### 1. Custom provider modes
+```go
+// Implement a custom mode handler
+type CustomModeHandler struct{}
+
+func (h *CustomModeHandler) Handle(ctx *provider.ParserContext, comment *provider.ProviderComment) (*provider.Provider, error) {
+	// Custom handling logic
+	return &provider.Provider{
+		Mode: provider.ProviderMode("custom"),
+		// ...
+	}, nil
+}
+
+// Register the custom mode
+provider.RegisterProviderMode("custom", &CustomModeHandler{})
+```
+
+#### 2. Custom validators
+```go
+// Implement a custom validator
+type CustomValidator struct{}
+
+func (v *CustomValidator) Validate(p *provider.Provider) []error {
+	var errors []error
+	// Custom validation logic
+	if p.StructName == "" {
+		errors = append(errors, fmt.Errorf("struct name cannot be empty"))
+	}
+	return errors
+}
+
+// Add it to the validation chain
+parser.AddValidator(&CustomValidator{})
+```
+
+## Performance Optimization
+
+### 1. Parallel parsing
+```go
+// Use parallel parsing to speed up large projects.
+func ParseProjectParallel(root string) ([]*provider.Provider, error) {
+	files, err := findGoFiles(root)
+	if err != nil {
+		return nil, err
+	}
+
+	var (
+		wg        sync.WaitGroup
+		mu        sync.Mutex // guards providers against concurrent appends
+		providers = make([]*provider.Provider, 0, len(files))
+	)
+	errChan := make(chan error, len(files))
+
+	for _, file := range files {
+		wg.Add(1)
+		go func(f string) {
+			defer wg.Done()
+			ps, err := provider.ParseFile(f)
+			if err != nil {
+				errChan <- err
+				return
+			}
+			mu.Lock()
+			providers = append(providers, ps...)
+			mu.Unlock()
+		}(file)
+	}
+
+	wg.Wait()
+	close(errChan)
+
+	// Return the first error, if any
+	for err := range errChan {
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return providers, nil
+}
+```
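+
+The example above launches one goroutine per file, which can mean thousands of goroutines on a large project. The bounded variant below is a minimal sketch of the same idea with a cap on concurrency; it assumes the same `provider.ParseFile` and `findGoFiles` helpers and uses a buffered channel as a semaphore.
+
+```go
+// Bounded variant: at most maxParallel files are parsed at the same time.
+// Sketch only; it assumes the same ParseFile/findGoFiles helpers as above.
+func ParseProjectBounded(root string, maxParallel int) ([]*provider.Provider, error) {
+	files, err := findGoFiles(root)
+	if err != nil {
+		return nil, err
+	}
+
+	var (
+		wg        sync.WaitGroup
+		mu        sync.Mutex // guards providers and firstErr
+		providers []*provider.Provider
+		firstErr  error
+	)
+	sem := make(chan struct{}, maxParallel) // counting semaphore
+
+	for _, file := range files {
+		wg.Add(1)
+		sem <- struct{}{} // acquire a slot before starting the goroutine
+		go func(f string) {
+			defer wg.Done()
+			defer func() { <-sem }() // release the slot
+
+			ps, err := provider.ParseFile(f)
+			mu.Lock()
+			defer mu.Unlock()
+			if err != nil {
+				if firstErr == nil {
+					firstErr = err
+				}
+				return
+			}
+			providers = append(providers, ps...)
+		}(file)
+	}
+
+	wg.Wait()
+	if firstErr != nil {
+		return nil, firstErr
+	}
+	return providers, nil
+}
+```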
+
+### 2. Caching
+```go
+// Use a cache to avoid re-parsing the same file (not safe for concurrent use).
+type CachedParser struct {
+	cache  map[string][]*provider.Provider
+	parser *provider.Parser
+}
+
+func NewCachedParser() *CachedParser {
+	return &CachedParser{
+		cache:  make(map[string][]*provider.Provider),
+		parser: provider.NewParser(provider.DefaultConfig()),
+	}
+}
+
+func (cp *CachedParser) ParseFile(file string) ([]*provider.Provider, error) {
+	if providers, ok := cp.cache[file]; ok {
+		return providers, nil
+	}
+
+	providers, err := cp.parser.ParseFile(file)
+	if err != nil {
+		return nil, err
+	}
+
+	cp.cache[file] = providers
+	return providers, nil
+}
+```
+
+## Troubleshooting
+
+### Common errors
+
+1. **Parse error**
+```
+error: failed to parse provider comment: invalid mode format
+```
+   Solution: check that the @provider annotation format is correct.
+
+2. **Import error**
+```
+error: cannot resolve import path "github.com/unknown/pkg"
+```
+   Solution: make sure every imported package exists.
+
+3. **Validation error**
+```
+error: provider struct has invalid return type
+```
+   Solution: make sure the return type is a valid Go type.
+
+### Debugging tips
+
+1. **Enable verbose logging**
+```go
+provider.SetLogLevel(logrus.DebugLevel)
+```
+
+2. **Use a parser context**
+```go
+ctx := provider.NewParserContext(provider.ParserConfig{
+	StrictMode: true,
+	ErrorHandler: func(err error) {
+		log.Printf("Parse error: %v", err)
+	},
+})
+```
+
+3. **Validate the generated code**
+```go
+if err := provider.ValidateGeneratedCode(code); err != nil {
+	log.Printf("Generated code validation failed: %v", err)
+}
+```
+
+## Best Practices
+
+1. **Keep annotations concise**
+```go
+// Recommended
+// @provider:only contracts.Initial
+
+// Not recommended
+// @provider:only contracts.Initial atom.GroupInitial // this is a complex service
+```
+
+2. **Use concrete types**
+```go
+// Recommended
+type UserService struct {
+	Repo *UserRepository `inject:"true"`
+}
+
+// Not recommended
+type UserService struct {
+	Repo interface{} `inject:"true"`
+}
+```
+
+3. **Organize code sensibly**
+```go
+// Keep related providers in the same file
+// Use clear package and struct names
+// Avoid circular dependencies
+```
+
+## Next Steps
+
+- See [data-model.md](data-model.md) for the detailed data model
+- Read [research.md](research.md) for the refactoring decisions
+- See the [contracts/](contracts/) directory for the API contracts
\ No newline at end of file
diff --git a/specs/001-pkg-ast-provider/research.md b/specs/001-pkg-ast-provider/research.md
new file mode 100644
index 0000000..ed39bdd
--- /dev/null
+++ b/specs/001-pkg-ast-provider/research.md
@@ -0,0 +1,107 @@
+# Research Findings
+
+## Research Results
+
+### 1. Analysis of the existing code
+**Decision**: The existing pkg/ast/provider package consists of two main files:
+- `provider.go`: 337 lines containing complex AST parsing logic
+- `render.go`: 65 lines containing template rendering logic
+
+**Rationale**: Analysis revealed the following problems in the current code:
+- The Parse function is overly complex (337 lines) and violates the single-responsibility principle
+- There are no test cases; test coverage is 0%
+- Error handling is incomplete and edge cases are not covered
+- The code organization is unclear; parsing logic and business logic are mixed together
+
+**Alternatives considered**:
+- Keep the current structure and only add tests: does not address the complexity problem
+- Rewrite from scratch: too risky and could break existing functionality
+
+### 2. Refactoring strategy
+**Decision**: Adopt an incremental refactoring strategy and reorganize the code according to SOLID principles
+
+**Rationale**:
+- Single responsibility: separate parsing, validation, and rendering logic
+- Open/closed: use interfaces and the strategy pattern to support extension
+- Dependency inversion: depend on abstract interfaces rather than concrete implementations
+
+**Alternatives considered**:
+- Big-bang rewrite: too risky and hard to test
+- Add features without refactoring: would increase technical debt
+
+### 3. Testing strategy
+**Decision**: Use TDD: write tests first, then implement the functionality
+
+**Rationale**:
+- Ensures correctness after refactoring
+- Provides a regression-test safety net
+- Supports the 90% test-coverage target
+
+**Alternatives considered**:
+- Refactor first, test later: cannot guarantee the refactoring is correct
+- Test only the public interface: does not cover internal logic
+
+### 4. Performance optimization
+**Decision**: Keep the current performance level and focus on improving code structure
+
+**Rationale**:
+- Current performance already meets the requirement (<5s parsing)
+- Structural improvements will not affect performance
+- Over-optimization could introduce unnecessary complexity
+
+**Alternatives considered**:
+- Concurrent parsing: adds complexity for limited benefit
+- Caching: not needed for the current usage scenario
+
+### 5. Backward compatibility
+**Decision**: Maintain full backward compatibility
+
+**Rationale**:
+- Existing users depend on the @provider annotation format
+- The public API must not be broken
+- The generated code format must stay consistent
+
+**Alternatives considered**:
+- Breaking changes: would affect existing users
+- Providing a migration tool: adds maintenance cost
+
+## Summary of Technical Decisions
+
+### Code organization
+- Split the Parse function into several single-responsibility functions
+- Define dedicated parser, validator, and renderer interfaces (see the sketch below)
+- Use a factory pattern to create the different handler types
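+
+A minimal sketch of what the interface split could look like is shown below. The interface names follow the contract tests referenced in tasks.md (`Parser`, `Validator`, `Renderer`) and the calls used in quickstart.md; the exact method sets are an assumption here and are defined authoritatively in the contracts/ documents.
+
+```go
+// Sketch only: method sets are assumptions, not the final contracts.
+package provider
+
+// Parser turns annotated Go source into Provider definitions.
+type Parser interface {
+	ParseFile(path string) ([]*Provider, error)
+	ParseDir(dir string) ([]*Provider, error)
+}
+
+// Validator checks a parsed Provider and reports every problem it finds.
+type Validator interface {
+	Validate(p *Provider) []error
+}
+
+// Renderer writes the generated provider code for a target file.
+type Renderer interface {
+	Render(path string, providers []Provider) error
+}
+```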
+
+### Test coverage strategy
+- Unit tests: cover all public and private functions
+- Integration tests: exercise the complete workflow
+- Benchmarks: ensure performance does not regress
+
+### Error handling improvements
+- Introduce custom error types (see the sketch below)
+- Provide detailed error messages with suggestions
+- Support error recovery mechanisms
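+
+As a rough illustration of the direction, the sketch below shows one possible shape for an error type that carries location information and a fix suggestion. The type name and fields are assumptions; the actual definitions belong in pkg/ast/provider/errors.go (task T019).
+
+```go
+// Sketch only: the name and fields are assumptions, not the final errors.go API.
+package provider
+
+import "fmt"
+
+// ParseError describes a failure to parse an @provider annotation,
+// including where it happened and how the user might fix it.
+type ParseError struct {
+	File       string // source file being parsed
+	Line       int    // line of the offending annotation
+	Message    string // what went wrong
+	Suggestion string // optional hint for fixing the annotation
+}
+
+func (e *ParseError) Error() string {
+	if e.Suggestion != "" {
+		return fmt.Sprintf("%s:%d: %s (hint: %s)", e.File, e.Line, e.Message, e.Suggestion)
+	}
+	return fmt.Sprintf("%s:%d: %s", e.File, e.Line, e.Message)
+}
+```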
+
+### Documentation and comments
+- Add godoc comments to all public functions
+- Provide usage examples and best practices
+- Maintain a changelog
+
+## Risk Assessment
+
+### High-risk items
+- The refactoring may introduce new bugs
+- Insufficient test coverage may lead to regressions
+
+### Mitigations
+- Develop strictly following the TDD workflow
+- Back every refactoring step with corresponding tests
+- Keep the existing API unchanged
+
+## Next Steps
+
+1. Create the detailed data model and interface design
+2. Define the test strategy and test cases
+3. Lay out the concrete refactoring steps
+4. Carry out the incremental refactoring
+5. Validate the refactoring results
\ No newline at end of file
diff --git a/specs/001-pkg-ast-provider/spec.md b/specs/001-pkg-ast-provider/spec.md
new file mode 100644
index 0000000..ccd6f4b
--- /dev/null
+++ b/specs/001-pkg-ast-provider/spec.md
@@ -0,0 +1,121 @@
+# Feature Specification: Improve the Code Organization and Functionality of pkg/ast/provider and Complete Its Test Cases
+
+**Feature Branch**: `001-pkg-ast-provider`
+**Created**: 2025-09-19
+**Status**: Draft
+**Input**: User description: "Improve the code organization and functionality of the @pkg/ast/provider/ directory and fill in the missing test cases"
+
+## Execution Flow (main)
+```
+1. Parse user description from Input
+   → If empty: ERROR "No feature description provided"
+2. Extract key concepts from description
+   → Identify: actors, actions, data, constraints
+3. For each unclear aspect:
+   → Mark with [NEEDS CLARIFICATION: specific question]
+4. Fill User Scenarios & Testing section
+   → If no clear user flow: ERROR "Cannot determine user scenarios"
+5. Generate Functional Requirements
+   → Each requirement must be testable
+   → Mark ambiguous requirements
+6. Identify Key Entities (if data involved)
+7. Run Review Checklist
+   → If any [NEEDS CLARIFICATION]: WARN "Spec has uncertainties"
+   → If implementation details found: ERROR "Remove tech details"
+8. Return: SUCCESS (spec ready for planning)
+```
+
+---
+
+## ⚡ Quick Guidelines
+- ✅ Focus on WHAT users need and WHY
+- ❌ Avoid HOW to implement (no tech stack, APIs, code structure)
+- 👥 Written for business stakeholders, not developers
+
+### Section Requirements
+- **Mandatory sections**: Must be completed for every feature
+- **Optional sections**: Include only when relevant to the feature
+- When a section doesn't apply, remove it entirely (don't leave as "N/A")
+
+### For AI Generation
+When creating this spec from a user prompt:
+1. **Mark all ambiguities**: Use [NEEDS CLARIFICATION: specific question] for any assumption you'd need to make
+2. **Don't guess**: If the prompt doesn't specify something (e.g., "login system" without auth method), mark it
+3. **Think like a tester**: Every vague requirement should fail the "testable and unambiguous" checklist item
+4. **Common underspecified areas**:
+   - User types and permissions
+   - Data retention/deletion policies
+   - Performance targets and scale
+   - Error handling behaviors
+   - Integration requirements
+   - Security/compliance needs
+
+---
+
+## User Scenarios & Testing *(mandatory)*
+
+### Primary User Story
+As a developer, I need the atomctl tool to parse `@provider` annotations in Go source code and generate the corresponding dependency-injection code. I want the code organization to be clear, the implementation to be reliable, and a complete test suite to guarantee code quality.
+
+### Acceptance Scenarios
+1. **Given** a Go source file containing `@provider` annotations, **When** I run the parser, **Then** the system correctly extracts all provider information and generates the corresponding code
+2. **Given** a complex Go project structure, **When** I parse providers across multiple files, **Then** the system correctly handles cross-file dependencies
+3. **Given** a source file with malformed annotations, **When** I run the parser, **Then** the system provides clear error messages and suggestions
+4. **Given** the generated provider code, **When** I run the test suite, **Then** all tests pass and coverage reaches 90%
+
+### Edge Cases
+- How does the system behave when a source file is malformed?
+- How does the system provide useful error messages when an annotation is not well formed?
+- How is performance guaranteed when processing large projects?
+- How is thread safety guaranteed when multiple files are processed concurrently?
+
+## Requirements *(mandatory)*
+
+### Functional Requirements
+- **FR-001**: System MUST parse all `@provider` annotation forms (basic, with mode, with return type, with group, etc.)
+- **FR-002**: System MUST support the different provider modes (grpc, event, job, cronjob, model)
+- **FR-003**: System MUST correctly handle dependency-injection parameters (only/except modes, package-name resolution, type identification)
+- **FR-004**: System MUST generate well-structured provider code, including the necessary imports and function definitions
+- **FR-005**: System MUST provide a complete error-handling mechanism, covering syntax errors, import errors, and so on
+- **FR-006**: System MUST include comprehensive test cases covering all major features and edge cases
+- **FR-007**: System MUST improve the code organization to increase readability and maintainability
+- **FR-008**: System MUST handle complex package-import and alias-resolution logic
+
+### Key Entities *(include if feature involves data)*
+- **Provider**: the core entity representing a dependency-injection provider, with attributes such as struct name, return type, and mode
+- **ProviderDescribe**: the entity describing the parsed result of a provider annotation, including mode, return type, and group
+- **InjectParam**: the entity describing an injection parameter, including its type, package name, and alias
+- **SourceFile**: the source file to be parsed, including its file path, contents, and import information
+
+---
+
+## Review & Acceptance Checklist
+*GATE: Automated checks run during main() execution*
+
+### Content Quality
+- [ ] No implementation details (languages, frameworks, APIs)
+- [ ] Focused on user value and business needs
+- [ ] Written for non-technical stakeholders
+- [ ] All mandatory sections completed
+
+### Requirement Completeness
+- [ ] No [NEEDS CLARIFICATION] markers remain
+- [ ] Requirements are testable and unambiguous
+- [ ] Success criteria are measurable
+- [ ] Scope is clearly bounded
+- [ ] Dependencies and assumptions identified
+
+---
+
+## Execution Status
+*Updated by main() during processing*
+
+- [ ] User description parsed
+- [ ] Key concepts extracted
+- [ ] Ambiguities marked
+- [ ] User scenarios defined
+- [ ] Requirements generated
+- [ ] Entities identified
+- [ ] Review checklist passed
+
+---
\ No newline at end of file
diff --git a/specs/001-pkg-ast-provider/tasks.md b/specs/001-pkg-ast-provider/tasks.md
new file mode 100644
index 0000000..721ec29
--- /dev/null
+++ b/specs/001-pkg-ast-provider/tasks.md
@@ -0,0 +1,274 @@
+# Tasks: Improve the Code Organization and Functionality of pkg/ast/provider and Complete Its Test Cases
+
+**Input**: Design documents from `/projects/atomctl/specs/001-pkg-ast-provider/`
+**Prerequisites**: plan.md (required), research.md, data-model.md, contracts/
+
+## Execution Flow (main)
+```
+1. Load plan.md from feature directory
+   → If not found: ERROR "No implementation plan found"
+   → Extract: tech stack, libraries, structure
+2. Load optional design documents:
+   → data-model.md: Extract entities → model tasks
+   → contracts/: Each file → contract test task
+   → research.md: Extract decisions → setup tasks
+3. Generate tasks by category:
+   → Setup: project init, dependencies, linting
+   → Tests: contract tests, integration tests
+   → Core: models, services, CLI commands
+   → Integration: DB, middleware, logging
+   → Polish: unit tests, performance, docs
+4. Apply task rules:
+   → Different files = mark [P] for parallel
+   → Same file = sequential (no [P])
+   → Tests before implementation (TDD)
+5. Number tasks sequentially (T001, T002...)
+6. Generate dependency graph
+7. Create parallel execution examples
+8. Validate task completeness:
+   → All contracts have tests?
+   → All entities have models?
+   → All endpoints implemented?
+9. Return: SUCCESS (tasks ready for execution)
+```
+
+## Format: `[ID] [P?]
Description` +- **[P]**: Can run in parallel (different files, no dependencies) +- Include exact file paths in descriptions + +## Path Conventions +- **Single project**: `src/`, `tests/` at repository root +- **Web app**: `backend/src/`, `frontend/src/` +- **Mobile**: `api/src/`, `ios/src/` or `android/src/` +- Paths shown below assume single project - adjust based on plan.md structure + +## Phase 3.1: Setup +- [x] T001 Create test directory structure and test data files +- [x] T002 Set up testing dependencies (testify, gomock) and test configuration +- [x] T003 [P] Configure linting and benchmark tools for code quality + +## Phase 3.2: Tests First (TDD) ⚠️ MUST COMPLETE BEFORE 3.3 +**CRITICAL: These tests MUST be written and MUST FAIL before ANY implementation** + +### Contract Tests (from contracts/) +- [x] T004 [P] Contract test Parser interface in tests/contract/test_parser_api.go +- [x] T005 [P] Contract test Validator interface in tests/contract/test_validator_api.go +- [x] T006 [P] Contract test Renderer interface in tests/contract/test_renderer_api.go +- [x] T007 [P] Contract test Validation Rules in tests/contract/test_validation_rules.go + +### Integration Tests (from user stories) +- [x] T008 [P] Integration test single file parsing workflow in tests/integration/test_file_parsing.go +- [x] T009 [P] Integration test complex project with multiple providers in tests/integration/test_project_parsing.go +- [x] T010 [P] Integration test error handling and recovery in tests/integration/test_error_handling.go +- [x] T011 [P] Integration test performance requirements in tests/integration/test_performance.go + +## Phase 3.3: Core Implementation (ONLY after tests are failing) + +### Data Model Implementation (from data-model.md) +- [x] T012 [P] Provider data structure in pkg/ast/provider/types.go +- [x] T013 [P] ProviderMode and InjectionMode enums in pkg/ast/provider/modes.go +- [x] T014 [P] InjectParam and SourceLocation structs in pkg/ast/provider/types.go +- [x] T015 [P] ParserConfig and ParserContext in pkg/ast/provider/config.go + +### Interface Implementation (from contracts/) +- [x] T016 [P] Parser interface implementation in pkg/ast/provider/parser_interface.go +- [x] T017 [P] Validator interface implementation in pkg/ast/provider/validator.go +- [x] T018 [P] Renderer interface implementation in pkg/ast/provider/renderer.go +- [x] T019 [P] Error types and error handling in pkg/ast/provider/errors.go + +### Core Logic Refactoring +- [ ] T020 Extract comment parsing logic from Parse function in pkg/ast/provider/comment_parser.go +- [ ] T021 Extract AST traversal logic in pkg/ast/provider/ast_walker.go +- [ ] T022 Extract import resolution logic in pkg/ast/provider/import_resolver.go +- [ ] T023 Extract provider building logic in pkg/ast/provider/builder.go +- [ ] T024 Refactor main Parse function to use new components in pkg/ast/provider/parser.go + +### Validation Rules Implementation (from contracts/) +- [x] T025 [P] Implement struct name validation rule in pkg/ast/provider/validator.go +- [x] T026 [P] Implement return type validation rule in pkg/ast/provider/validator.go +- [x] T027 [P] Implement provider mode validation rule in pkg/ast/provider/validator.go +- [x] T028 [P] Implement injection params validation rule in pkg/ast/provider/validator.go +- [x] T029 [P] Implement validation report generation in pkg/ast/provider/report_generator.go + +## Phase 3.4: Integration + +### Parser Integration +- [ ] T030 Integrate new parser components in pkg/ast/provider/parser.go +- [ ] T031 
Implement backward compatibility layer for existing Parse function +- [ ] T032 Add configuration and context management to parser +- [ ] T033 Implement caching mechanism for performance optimization + +### Validator Integration +- [ ] T034 [P] Integrate validation rules in validator implementation +- [ ] T035 Implement validation report generation +- [ ] T036 Add custom validation rule registration system + +### Renderer Integration +- [ ] T037 [P] Update renderer to use new data structures +- [ ] T038 Implement template function registration system +- [ ] T039 Add custom template support for extensibility + +### Error Handling Integration +- [ ] T040 [P] Implement comprehensive error handling across all components +- [ ] T041 Add error recovery mechanisms +- [ ] T042 Implement structured logging for debugging + +## Phase 3.5: Polish + +### Unit Tests +- [ ] T043 [P] Unit tests for comment parsing in tests/unit/test_comment_parser.go +- [ ] T044 [P] Unit tests for AST walker in tests/unit/test_ast_walker.go +- [ ] T045 [P] Unit tests for import resolver in tests/unit/test_import_resolver.go +- [ ] T046 [P] Unit tests for provider builder in tests/unit/test_builder.go +- [ ] T047 [P] Unit tests for validation rules in tests/unit/test_validation_rules.go + +### Performance Tests +- [ ] T048 Performance benchmark for single file parsing (<100ms) +- [ ] T049 Performance benchmark for large project parsing (<5s) +- [ ] T050 Memory usage validation (<100MB normal usage) +- [ ] T051 Stress test with concurrent file parsing + +### Documentation and Examples +- [ ] T052 [P] Update package documentation and godoc comments +- [ ] T053 [P] Create usage examples and migration guide +- [ ] T054 [P] Document new interfaces and extension points +- [ ] T055 [P] Update README and quickstart guide + +### Final Integration +- [ ] T056 Remove code duplication and consolidate utilities +- [ ] T057 Final performance optimization and validation +- [ ] T058 Integration test for complete backward compatibility +- [ ] T059 Run full test suite and validate 90% coverage +- [ ] T060 Final code review and cleanup + +## Dependencies +- Tests (T004-T011) before implementation (T012-T042) +- Data models (T012-T015) before interfaces (T016-T019) +- Core logic refactoring (T020-T024) before integration (T030-T042) +- Integration (T030-T042) before polish (T043-T060) +- Unit tests (T043-T047) before performance tests (T048-T051) +- Documentation (T052-T055) before final integration (T056-T060) + +## Parallel Execution Groups + +### Group 1: Contract Tests (Can run in parallel) +``` +Task: "Contract test Parser interface in tests/contract/test_parser_api.go" +Task: "Contract test Validator interface in tests/contract/test_validator_api.go" +Task: "Contract test Renderer interface in tests/contract/test_renderer_api.go" +Task: "Contract test Validation Rules in tests/contract/test_validation_rules.go" +``` + +### Group 2: Integration Tests (Can run in parallel) +``` +Task: "Integration test single file parsing workflow in tests/integration/test_file_parsing.go" +Task: "Integration test complex project with multiple providers in tests/integration/test_project_parsing.go" +Task: "Integration test error handling and recovery in tests/integration/test_error_handling.go" +Task: "Integration test performance requirements in tests/integration/test_performance.go" +``` + +### Group 3: Data Model Implementation (Can run in parallel) +``` +Task: "Provider data structure in pkg/ast/provider/types.go" +Task: "ProviderMode and InjectionMode 
enums in pkg/ast/provider/modes.go" +Task: "InjectParam and SourceLocation structs in pkg/ast/provider/types.go" +Task: "ParserConfig and ParserContext in pkg/ast/provider/config.go" +``` + +### Group 4: Interface Implementation (Can run in parallel) +``` +Task: "Parser interface implementation in pkg/ast/provider/parser.go" +Task: "Validator interface implementation in pkg/ast/provider/validator.go" +Task: "Renderer interface implementation in pkg/ast/provider/renderer.go" +Task: "Error types and error handling in pkg/ast/provider/errors.go" +``` + +### Group 5: Validation Rules (Can run in parallel) +``` +Task: "Struct name validation rule in pkg/ast/provider/validation/struct_name.go" +Task: "Provider mode validation rules in pkg/ast/provider/validation/mode_rules.go" +Task: "Inject params validation rule in pkg/ast/provider/validation/inject_params.go" +Task: "Package alias validation rule in pkg/ast/provider/validation/package_alias.go" +Task: "Mode-specific validation rules in pkg/ast/provider/validation/mode_specific.go" +``` + +### Group 6: Unit Tests (Can run in parallel) +``` +Task: "Unit tests for comment parsing in tests/unit/test_comment_parser.go" +Task: "Unit tests for AST walker in tests/unit/test_ast_walker.go" +Task: "Unit tests for import resolver in tests/unit/test_import_resolver.go" +Task: "Unit tests for provider builder in tests/unit/test_builder.go" +Task: "Unit tests for validation rules in tests/unit/test_validation_rules.go" +``` + +## Critical Implementation Notes + +### Backward Compatibility +- Maintain existing `Parse(filename string) []Provider` function signature +- Keep existing `Render(filename string, conf []Provider)` function signature +- Ensure all existing @provider annotation formats continue to work + +### SOLID Principles +- **Single Responsibility**: Each component has one clear purpose +- **Open/Closed**: Design for extension through interfaces +- **Liskov Substitution**: All implementations can be substituted +- **Interface Segregation**: Keep interfaces focused and specific +- **Dependency Inversion**: Depend on abstractions, not concrete types + +### Performance Requirements +- Single file parsing: <100ms +- Large project parsing: <5s +- Memory usage: <100MB (normal operation) +- Test coverage: ≥90% + +### Quality Gates +- All tests must pass before merging +- Code must follow Go formatting standards +- No linting warnings or errors +- Documentation must be complete and accurate + +## Notes +- [P] tasks = different files, no dependencies +- Verify tests fail before implementing (TDD approach) +- Commit after each significant task completion +- Each task must be specific and actionable +- Maintain backward compatibility throughout refactoring + +## Task Generation Rules + +### SOLID Compliance +- **Single Responsibility**: Each task focuses on one specific component +- **Open/Closed**: Design tasks to allow extension without modification +- **Interface Segregation**: Create focused interfaces for different task types + +### KISS Compliance +- Keep task descriptions simple and direct +- Avoid over-complicating task dependencies +- Use intuitive file naming and structure + +### YAGNI Compliance +- Only create tasks for clearly needed functionality +- Avoid speculative tasks without direct requirements +- Focus on MVP implementation first + +### DRY Compliance +- Abstract common patterns into reusable task templates +- Avoid duplicate task definitions +- Consolidate similar operations where possible + +### From Design Documents +- **Contracts**: 
Each contract file → contract test task [P] +- **Data Model**: Each entity → model creation task [P] +- **User Stories**: Each story → integration test [P] + +### Ordering +- Setup → Tests → Models → Services → Endpoints → Polish +- Dependencies block parallel execution + +### Validation Checklist +- [ ] All contracts have corresponding tests +- [ ] All entities have model tasks +- [ ] All tests come before implementation +- [ ] Parallel tasks truly independent +- [ ] Each task specifies exact file path +- [ ] No task modifies same file as another [P] task \ No newline at end of file