Files
atomctl/pkg/ast/provider/builder.go

522 lines
15 KiB
Go

package provider
import (
	"fmt"
	"go/ast"
	"go/types"
	"strings"
)
// ProviderBuilder turns parsed AST pieces (a type declaration with its provider
// doc comment, or a bare provider comment) into Provider values.
type ProviderBuilder struct {
	// config controls validation, defaults, return-type generation and import
	// resolution; both BuildFromTypeSpec and BuildFromComment read it.
	config *BuilderConfig
	// commentParser parses provider annotations out of comment text.
	commentParser *CommentParser
	// importResolver is not read by the build methods in this file; it is
	// exposed through GetImportResolver.
	importResolver *ImportResolver
	// astWalker is not read by the build methods in this file; it is exposed
	// through GetASTWalker.
	astWalker *ASTWalker
}
// BuilderConfig configures the provider builder behavior.
type BuilderConfig struct {
	// EnableValidation runs validateProvider on every built provider and
	// returns its failure as the build error.
	EnableValidation bool
	// StrictMode is forwarded to the AST walker by NewProviderBuilderWithConfig;
	// the build methods in this file do not read it.
	StrictMode bool
	// DefaultProviderMode is used when a provider comment specifies no mode.
	DefaultProviderMode ProviderMode
	// DefaultInjectionMode is not read by the builder in this file.
	// NOTE(review): presumably consumed elsewhere in the package — confirm.
	DefaultInjectionMode InjectionMode
	// AutoGenerateReturnTypes fills a missing return type with a mode-based
	// default from generateDefaultReturnType. When false, BuildFromTypeSpec
	// falls back to "*TypeName" and BuildFromComment leaves it empty.
	AutoGenerateReturnTypes bool
	// ResolveImportDependencies makes BuildFromTypeSpec populate
	// Provider.Imports via resolveImportDependencies.
	ResolveImportDependencies bool
}
// BuilderContext carries per-file state shared across provider builds.
type BuilderContext struct {
	// FilePath is copied into Provider.ProviderFile and Provider.Location.File.
	FilePath string
	// PackageName is copied into Provider.PkgName.
	PackageName string
	// ImportContext resolves import aliases to paths (ImportPaths) and exposes
	// the source file's imports (FileImports). When nil, no import paths are
	// resolved.
	ImportContext *ImportContext
	// ASTFile is the parsed source file; not read by the builder in this file.
	ASTFile *ast.File
	// ProcessedTypes records type names already built, so BuildFromTypeSpec
	// rejects a second build of the same name.
	ProcessedTypes map[string]bool
	// Errors and Warnings are initialized by BuildFromTypeSpec when it creates
	// a context, but nothing in this file appends to them.
	Errors   []error
	Warnings []string
}
// NewProviderBuilder returns a ProviderBuilder with the default configuration.
// Validation is on and strict mode is off. The modes default to
// ProviderModeBasic / InjectionModeAuto. Return-type generation and import
// resolution are both enabled.
func NewProviderBuilder() *ProviderBuilder {
	defaults := &BuilderConfig{
		EnableValidation:          true,
		StrictMode:                false,
		DefaultProviderMode:       ProviderModeBasic,
		DefaultInjectionMode:      InjectionModeAuto,
		AutoGenerateReturnTypes:   true,
		ResolveImportDependencies: true,
	}

	return &ProviderBuilder{
		config:         defaults,
		commentParser:  NewCommentParser(),
		importResolver: NewImportResolver(),
		astWalker:      NewASTWalker(),
	}
}
// NewProviderBuilderWithConfig returns a ProviderBuilder driven by config.
// A nil config falls back to NewProviderBuilder. The config pointer is stored
// as-is (not copied), and the AST walker is built with config.StrictMode.
func NewProviderBuilderWithConfig(config *BuilderConfig) *ProviderBuilder {
	if config == nil {
		return NewProviderBuilder()
	}

	builder := &ProviderBuilder{
		config:         config,
		commentParser:  NewCommentParser(),
		importResolver: NewImportResolver(),
	}
	walkerCfg := &WalkerConfig{StrictMode: config.StrictMode}
	builder.astWalker = NewASTWalkerWithConfig(walkerCfg)
	return builder
}
// BuildFromTypeSpec builds a Provider for the type declared by typeSpec. It
// uses the doc comment attached to decl as the provider annotation.
//
// Steps, in order:
//   - parse the annotation;
//   - fill in name, mode, group and file metadata;
//   - set the return type;
//   - collect injection parameters from struct fields;
//   - apply mode-specific settings;
//   - resolve imports (when ResolveImportDependencies is set);
//   - validate (when EnableValidation is set).
//
// On success the type name is recorded in context.ProcessedTypes. A later
// build of the same name with the same context returns an error.
//
// context may be nil. A fresh, empty context is then used, so no package
// name, file path or import information is available.
func (pb *ProviderBuilder) BuildFromTypeSpec(typeSpec *ast.TypeSpec, decl *ast.GenDecl, context *BuilderContext) (Provider, error) {
	if typeSpec == nil {
		return Provider{}, fmt.Errorf("type specification cannot be nil")
	}

	// Initialize builder context if not provided.
	if context == nil {
		context = &BuilderContext{
			ProcessedTypes: make(map[string]bool),
			Errors:         make([]error, 0),
			Warnings:       make([]string, 0),
		}
	}
	// A caller-supplied context may leave ProcessedTypes unset; recording the
	// type at the end of a successful build would then write to a nil map.
	if context.ProcessedTypes == nil {
		context.ProcessedTypes = make(map[string]bool)
	}

	typeName := typeSpec.Name.Name
	if context.ProcessedTypes[typeName] {
		return Provider{}, fmt.Errorf("type %s has already been processed", typeName)
	}

	providerComment, err := pb.parseProviderComment(decl)
	if err != nil {
		return Provider{}, fmt.Errorf("failed to parse provider comment: %w", err)
	}

	provider := Provider{
		StructName:    typeName,
		Mode:          pb.determineProviderMode(providerComment),
		ProviderGroup: pb.determineProviderGroup(providerComment),
		InjectParams:  make(map[string]InjectParam),
		Imports:       make(map[string]string),
		PkgName:       context.PackageName,
		ProviderFile:  context.FilePath,
		Location: SourceLocation{
			File: context.FilePath,
		},
	}

	if err := pb.setReturnType(&provider, providerComment, typeSpec); err != nil {
		return Provider{}, err
	}

	// Only struct types contribute injection parameters.
	if structType, ok := typeSpec.Type.(*ast.StructType); ok {
		if err := pb.processStructFields(&provider, structType, context); err != nil {
			return Provider{}, err
		}
	}

	// Mode-specific settings must run before import resolution. The gRPC
	// import lookup in resolveImportDependencies reads GrpcRegisterFunc, and
	// only applyModeSpecificConfig sets it.
	if err := pb.applyModeSpecificConfig(&provider, providerComment); err != nil {
		return Provider{}, err
	}

	if pb.config.ResolveImportDependencies {
		if err := pb.resolveImportDependencies(&provider, context); err != nil {
			return Provider{}, err
		}
	}

	if pb.config.EnableValidation {
		if err := pb.validateProvider(&provider); err != nil {
			return Provider{}, err
		}
	}

	context.ProcessedTypes[typeName] = true
	return provider, nil
}
// BuildFromComment builds a Provider from a provider annotation alone, without
// a type declaration.
//
// Mode, group and return type come from the annotation, falling back to the
// configured defaults. Mode-specific settings are then applied, and the result
// is validated when EnableValidation is set. Struct fields and imports are not
// processed.
//
// context is optional. When non-nil, its package name and file path are copied
// onto the provider, as BuildFromTypeSpec does.
//
// NOTE(review): StructName is never set here. With EnableValidation enabled
// (the NewProviderBuilder default), validateProvider therefore rejects every
// result with "provider struct name cannot be empty". Confirm intended use.
func (pb *ProviderBuilder) BuildFromComment(comment string, context *BuilderContext) (Provider, error) {
	providerComment, err := pb.commentParser.ParseProviderComment(comment)
	if err != nil {
		return Provider{}, fmt.Errorf("failed to parse provider comment: %w", err)
	}

	provider := Provider{
		Mode:          pb.determineProviderMode(providerComment),
		ProviderGroup: pb.determineProviderGroup(providerComment),
		InjectParams:  make(map[string]InjectParam),
		Imports:       make(map[string]string),
	}

	// Copy file metadata from the context, matching BuildFromTypeSpec.
	if context != nil {
		provider.PkgName = context.PackageName
		provider.ProviderFile = context.FilePath
		provider.Location = SourceLocation{File: context.FilePath}
	}

	// determineProviderMode/Group tolerate a nil comment, so this lookup does
	// too instead of dereferencing it.
	switch {
	case providerComment != nil && providerComment.ReturnType != "":
		provider.ReturnType = providerComment.ReturnType
	case pb.config.AutoGenerateReturnTypes:
		provider.ReturnType = pb.generateDefaultReturnType(provider.Mode)
	}

	if err := pb.applyModeSpecificConfig(&provider, providerComment); err != nil {
		return Provider{}, err
	}

	if pb.config.EnableValidation {
		if err := pb.validateProvider(&provider); err != nil {
			return Provider{}, err
		}
	}

	return provider, nil
}
// parseProviderComment collects the raw text of every line in decl's doc
// comment and passes them to the comment parser. ast.Comment.Text keeps the
// comment markers.
//
// It returns an error when decl is nil or carries no doc comment, which makes
// BuildFromTypeSpec fail for undocumented declarations.
func (pb *ProviderBuilder) parseProviderComment(decl *ast.GenDecl) (*ProviderComment, error) {
	// Guard explicitly: previously a nil decl panicked on decl.Doc.
	if decl == nil {
		return nil, fmt.Errorf("declaration cannot be nil")
	}
	if decl.Doc == nil || len(decl.Doc.List) == 0 {
		return nil, fmt.Errorf("no documentation found for declaration")
	}

	commentLines := make([]string, 0, len(decl.Doc.List))
	for _, c := range decl.Doc.List {
		commentLines = append(commentLines, c.Text)
	}

	return pb.commentParser.ParseCommentBlock(commentLines)
}
// determineProviderMode returns the mode named in the provider comment. If the
// comment is nil or names no mode, it returns the configured
// DefaultProviderMode.
func (pb *ProviderBuilder) determineProviderMode(comment *ProviderComment) ProviderMode {
	if comment == nil || comment.Mode == "" {
		return pb.config.DefaultProviderMode
	}
	return comment.Mode
}
// determineProviderGroup returns the group named in the provider comment, or
// "" when the comment is nil or names no group.
func (pb *ProviderBuilder) determineProviderGroup(comment *ProviderComment) string {
	if comment == nil {
		return ""
	}
	// An unset Group is already "", so it can be returned directly.
	return comment.Group
}
// setReturnType picks the provider's return type. Precedence, highest first:
//   - the type named in the comment;
//   - a mode-based default when AutoGenerateReturnTypes is set;
//   - otherwise, a pointer to the annotated type ("*TypeName").
//
// It currently never fails.
func (pb *ProviderBuilder) setReturnType(provider *Provider, comment *ProviderComment, typeSpec *ast.TypeSpec) error {
	switch {
	case comment != nil && comment.ReturnType != "":
		provider.ReturnType = comment.ReturnType
	case pb.config.AutoGenerateReturnTypes:
		provider.ReturnType = pb.generateDefaultReturnType(provider.Mode)
	default:
		provider.ReturnType = "*" + typeSpec.Name.Name
	}
	return nil
}
// generateDefaultReturnType returns the return-type string used when none was
// specified:
//   - "func() error" for event providers;
//   - "func(ctx context.Context) error" for job and cron-job providers;
//   - "interface{}" for every other mode, unknown ones included.
func (pb *ProviderBuilder) generateDefaultReturnType(mode ProviderMode) string {
	switch mode {
	case ProviderModeEvent:
		return "func() error"
	case ProviderModeJob, ProviderModeCronJob:
		return "func(ctx context.Context) error"
	default:
		// Basic, gRPC and model modes share the catch-all type.
		return "interface{}"
	}
}
// processStructFields stores one InjectParam per named field of structType in
// provider.InjectParams, keyed by field name.
//
// Fields declared together ("a, b T") produce one entry each. Embedded fields
// have no name and are skipped. The first error from createInjectParam aborts
// the walk and is returned wrapped with the field name.
func (pb *ProviderBuilder) processStructFields(provider *Provider, structType *ast.StructType, context *BuilderContext) error {
	fieldList := structType.Fields
	if fieldList == nil {
		return nil
	}

	for _, f := range fieldList.List {
		// Embedded fields carry no Names, so this inner loop is a no-op for them.
		for _, ident := range f.Names {
			param, err := pb.createInjectParam(ident.Name, f, context)
			if err != nil {
				return fmt.Errorf("failed to create inject param for field %s: %w", ident.Name, err)
			}
			provider.InjectParams[ident.Name] = *param
		}
	}
	return nil
}
// createInjectParam describes one struct field as an InjectParam. The record
// holds the rendered type string and, when the type is package-qualified, the
// package alias and resolved import path.
//
// It fails when the field's type cannot be rendered.
func (pb *ProviderBuilder) createInjectParam(fieldName string, field *ast.Field, context *BuilderContext) (*InjectParam, error) {
	typ := pb.extractFieldType(field)
	if typ == "" {
		return nil, fmt.Errorf("cannot determine type for field %s", fieldName)
	}

	param := &InjectParam{Type: typ}
	param.Package, param.PackageAlias = pb.extractImportInfo(typ, context)
	return param, nil
}
// extractFieldType renders a struct field's type expression as Go source text.
// Examples: "string", "pkg.T", "*gorm.DB", "[]*pkg.Item", "[4]byte",
// "map[string][]int", "func() error". It returns "" only when the field or its
// type expression is nil.
//
// The previous hand-written switch only handled one level of nesting. Common
// field types such as *pkg.T, []*T and []pkg.T came back as "", so the whole
// build failed. Fixed-size arrays were also mis-rendered as slices.
// types.ExprString prints any expression recursively. The simple forms it
// shares with the old switch (T, pkg.T, *T, []T, map[K]V) render the same.
func (pb *ProviderBuilder) extractFieldType(field *ast.Field) string {
	if field == nil || field.Type == nil {
		return ""
	}
	return types.ExprString(field.Type)
}
// extractTypeFromExpr renders a plain identifier ("T") or a package-qualified
// identifier ("pkg.T"). Any other expression yields "".
func (pb *ProviderBuilder) extractTypeFromExpr(expr ast.Expr) string {
	if ident, ok := expr.(*ast.Ident); ok {
		return ident.Name
	}
	if sel, ok := expr.(*ast.SelectorExpr); ok {
		if qualifier, ok := sel.X.(*ast.Ident); ok {
			return qualifier.Name + "." + sel.Sel.Name
		}
	}
	return ""
}
// extractImportInfo finds the package qualifier in a rendered type string and
// resolves it through context.ImportContext.ImportPaths.
//
// The qualifier is the identifier directly before the first '.', so leading
// type decorations are skipped. "pkg.T", "*pkg.T", "[]pkg.T" and
// "map[string]pkg.T" all yield the alias "pkg". Previously a plain split on
// "." produced "*pkg" or "[]pkg", which never resolved.
//
// Results: ("", "") when typeStr has no qualifier, ("", alias) when the alias
// is not known to the import context, and (path, alias) when it is. Only the
// first qualified name in typeStr is considered.
func (pb *ProviderBuilder) extractImportInfo(typeStr string, context *BuilderContext) (packagePath, packageAlias string) {
	dot := strings.IndexByte(typeStr, '.')
	if dot < 0 {
		return "", ""
	}

	// Walk left from the dot across identifier bytes. Bytes >= 0x80 belong to
	// multi-byte UTF-8 runes and are treated as identifier characters.
	start := dot
	for start > 0 {
		c := typeStr[start-1]
		isIdent := c == '_' || c >= 0x80 ||
			('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9')
		if !isIdent {
			break
		}
		start--
	}

	packageAlias = typeStr[start:dot]
	if packageAlias == "" {
		return "", ""
	}

	if context != nil && context.ImportContext != nil {
		if path, exists := context.ImportContext.ImportPaths[packageAlias]; exists {
			return path, packageAlias
		}
	}
	return "", packageAlias
}
// resolveImportDependencies fills provider.Imports with what the generated
// provider code needs:
//   - the resolved package of every injection parameter, stored alias -> path;
//   - for gRPC providers, the package qualifying GrpcRegisterFunc, looked up
//     in the source file's imports (needs context.ImportContext);
//   - the fixed imports getModeSpecificImports returns for the mode. If a key
//     is already taken by a different path, the import is stored under an
//     alias from generateUniqueAlias instead.
//
// The mode-specific imports are constants. They are now added even when
// context.ImportContext is nil; previously an early return dropped them.
// The function currently never returns a non-nil error.
func (pb *ProviderBuilder) resolveImportDependencies(provider *Provider, context *BuilderContext) error {
	if provider.Imports == nil {
		provider.Imports = make(map[string]string)
	}

	// A param only carries a Package when extractImportInfo resolved it.
	for _, param := range provider.InjectParams {
		if param.Package != "" && param.PackageAlias != "" {
			provider.Imports[param.PackageAlias] = param.Package
		}
	}

	// Only this lookup depends on the source file's import context.
	if context != nil && context.ImportContext != nil &&
		provider.Mode == ProviderModeGrpc && provider.GrpcRegisterFunc != "" {
		// e.g. "userv1" from "userv1.RegisterUserServiceServer"
		if pkgAlias := getTypePkgName(provider.GrpcRegisterFunc); pkgAlias != "" {
			if importResolution, exists := context.ImportContext.FileImports[pkgAlias]; exists {
				// NOTE(review): this entry is keyed path -> alias, while the other
				// entries written here are keyed alias -> path. Confirm which
				// orientation consumers of Provider.Imports expect.
				provider.Imports[importResolution.Path] = pkgAlias
			}
		}
	}

	for alias, path := range pb.getModeSpecificImports(provider.Mode) {
		if existingPath, taken := provider.Imports[alias]; taken && existingPath != path {
			provider.Imports[pb.generateUniqueAlias(alias, provider.Imports)] = path
			continue
		}
		provider.Imports[alias] = path
	}

	return nil
}
// getModeSpecificImports returns a fresh map of the fixed imports a provider
// mode requires:
//   - grpc for gRPC providers;
//   - context for event, job and cron-job providers;
//   - encoding/json for model providers.
//
// Other modes get an empty, non-nil map.
func (pb *ProviderBuilder) getModeSpecificImports(mode ProviderMode) map[string]string {
	switch mode {
	case ProviderModeGrpc:
		return map[string]string{"grpc": "google.golang.org/grpc"}
	case ProviderModeEvent, ProviderModeJob, ProviderModeCronJob:
		return map[string]string{"context": "context"}
	case ProviderModeModel:
		return map[string]string{"encoding/json": "encoding/json"}
	default:
		return map[string]string{}
	}
}
// generateUniqueAlias returns the first of baseAlias1 … baseAlias999 that is
// not already a key of existingImports. If all 999 are taken, it gives up and
// returns baseAlias unchanged.
func (pb *ProviderBuilder) generateUniqueAlias(baseAlias string, existingImports map[string]string) string {
	const maxSuffix = 999
	for suffix := 1; suffix <= maxSuffix; suffix++ {
		candidate := fmt.Sprintf("%s%d", baseAlias, suffix)
		if _, taken := existingImports[candidate]; taken {
			continue
		}
		return candidate
	}
	return baseAlias
}
// applyModeSpecificConfig sets mode-dependent provider fields:
//   - gRPC providers get a generated GrpcRegisterFunc;
//   - every non-basic mode (gRPC, event, job, cron job, model) sets
//     NeedPrepareFunc;
//   - basic and unknown modes are left untouched.
//
// comment is currently unused. The function always returns nil.
func (pb *ProviderBuilder) applyModeSpecificConfig(provider *Provider, comment *ProviderComment) error {
	switch provider.Mode {
	case ProviderModeGrpc:
		provider.GrpcRegisterFunc = pb.generateGrpcRegisterFuncName(provider.StructName)
		provider.NeedPrepareFunc = true
	case ProviderModeEvent, ProviderModeJob, ProviderModeCronJob, ProviderModeModel:
		provider.NeedPrepareFunc = true
	}
	return nil
}
// generateGrpcRegisterFuncName wraps structName in the protoc-gen-go-grpc
// naming pattern, e.g. UserService -> RegisterUserServiceServer. The result
// carries no package qualifier.
func (pb *ProviderBuilder) generateGrpcRegisterFuncName(structName string) string {
	const prefix, suffix = "Register", "Server"
	return prefix + structName + suffix
}
// validateProvider checks that a built provider is usable. It returns the
// first problem found, checking in this order:
//   - the struct name is empty;
//   - the return type is empty;
//   - IsValidProviderMode rejects the mode;
//   - an injection parameter has an empty type.
//
// Map iteration order decides which parameter is reported when several have
// an empty type.
func (pb *ProviderBuilder) validateProvider(provider *Provider) error {
	switch {
	case provider.StructName == "":
		return fmt.Errorf("provider struct name cannot be empty")
	case provider.ReturnType == "":
		return fmt.Errorf("provider return type cannot be empty")
	case !IsValidProviderMode(string(provider.Mode)):
		return fmt.Errorf("invalid provider mode: %s", provider.Mode)
	}

	for paramName, param := range provider.InjectParams {
		if param.Type != "" {
			continue
		}
		return fmt.Errorf("injection parameter type cannot be empty for %s", paramName)
	}
	return nil
}
// GetConfig returns the configuration currently used by the builder. The
// pointer is shared, so changes made through it affect this builder.
func (pb *ProviderBuilder) GetConfig() *BuilderConfig {
	return pb.config
}

// SetConfig replaces the builder configuration. A nil config restores the
// NewProviderBuilder defaults, as NewProviderBuilderWithConfig does. Every
// build path dereferences pb.config, so a nil value would otherwise panic on
// the next build.
//
// NOTE(review): the AST walker keeps the StrictMode it was constructed with;
// SetConfig does not rebuild it.
func (pb *ProviderBuilder) SetConfig(config *BuilderConfig) {
	if config == nil {
		config = NewProviderBuilder().config
	}
	pb.config = config
}

// GetCommentParser returns the comment parser used to read provider
// annotations.
func (pb *ProviderBuilder) GetCommentParser() *CommentParser {
	return pb.commentParser
}

// GetImportResolver returns the import resolver held by the builder.
func (pb *ProviderBuilder) GetImportResolver() *ImportResolver {
	return pb.importResolver
}

// GetASTWalker returns the AST walker held by the builder.
func (pb *ProviderBuilder) GetASTWalker() *ASTWalker {
	return pb.astWalker
}