From 0cfc5739602ac46ea47d52ab111d7216f7539359 Mon Sep 17 00:00:00 2001 From: Rogee Date: Mon, 22 Sep 2025 10:00:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=20gRPC=20=E6=94=AF?= =?UTF-8?q?=E6=8C=81=EF=BC=8C=E4=BC=98=E5=8C=96=E6=B3=A8=E5=86=8C=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E5=92=8C=E5=AF=BC=E5=85=A5=E5=A4=84=E7=90=86=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/ast/provider/builder.go | 12 ++++++++++++ pkg/ast/provider/parser.go | 20 +++++++++++++++++--- pkg/ast/provider/parser_interface.go | 27 ++++++++++++++++++++++++--- pkg/ast/provider/provider.go | 18 +++++++++++++++++- 4 files changed, 70 insertions(+), 7 deletions(-) diff --git a/pkg/ast/provider/builder.go b/pkg/ast/provider/builder.go index 8e78bef..6cc3d75 100644 --- a/pkg/ast/provider/builder.go +++ b/pkg/ast/provider/builder.go @@ -380,6 +380,18 @@ func (pb *ProviderBuilder) resolveImportDependencies(provider *Provider, context } } + // For gRPC mode, extract and add imports from the original file's imports + if provider.Mode == ProviderModeGrpc && provider.GrpcRegisterFunc != "" { + // Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer") + if pkgAlias := getTypePkgName(provider.GrpcRegisterFunc); pkgAlias != "" { + // Look for this package in the original file's imports + if importResolution, exists := context.ImportContext.FileImports[pkgAlias]; exists { + // Add the import from the original file + provider.Imports[importResolution.Path] = pkgAlias + } + } + } + // Add mode-specific imports modeImports := pb.getModeSpecificImports(provider.Mode) for alias, path := range modeImports { diff --git a/pkg/ast/provider/parser.go b/pkg/ast/provider/parser.go index 318b06c..e5ff47c 100644 --- a/pkg/ast/provider/parser.go +++ b/pkg/ast/provider/parser.go @@ -368,11 +368,25 @@ func (p *MainParser) applyGrpcCompatibility(provider *Provider) { } // Set return type and register function - if provider.GrpcRegisterFunc == "" { - provider.GrpcRegisterFunc = provider.ReturnType - } + // Important: Save the original return type before setting it to contracts.Initial + originalReturnType := provider.ReturnType provider.ReturnType = "contracts.Initial" + // Set gRPC register function name if not already set + if provider.GrpcRegisterFunc == "" { + if originalReturnType != "" && strings.Contains(originalReturnType, ".") { + // User specified a complete register function name, like userv1.RegisterUserServiceServer + provider.GrpcRegisterFunc = originalReturnType + } else { + // Generate default gRPC register function name + // Example: UserService -> RegisterUserServiceServer + provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(originalReturnType, "*") + "Server" + } + } + + // Note: Package import handling for gRPC register functions is now done + // in resolveImportDependencies to ensure access to original file imports + // Add gRPC injection parameter provider.InjectParams["__grpc"] = InjectParam{ Star: "*", diff --git a/pkg/ast/provider/parser_interface.go b/pkg/ast/provider/parser_interface.go index f8cf656..5ad8f62 100644 --- a/pkg/ast/provider/parser_interface.go +++ b/pkg/ast/provider/parser_interface.go @@ -498,7 +498,7 @@ func (p *GoParser) parseProviderDecl(filePath string, fileNode *ast.File, decl a } // Handle special provider modes - p.handleProviderModes(provider, providerDoc.Mode) + p.handleProviderModes(provider, providerDoc.Mode, imports) // Add source location if enabled if p.config.SourceLocations { @@ -593,20 +593,41 @@ func (p *GoParser) parseFieldType(expr ast.Expr, imports map[string]string) (sta } // handleProviderModes applies special handling for different provider modes -func (p *GoParser) handleProviderModes(provider *Provider, mode string) { +func (p *GoParser) handleProviderModes(provider *Provider, mode string, imports map[string]string) { moduleName := gomod.GetModuleName() switch provider.Mode { case ProviderModeGrpc: modePkg := moduleName + "/providers/grpc" provider.ProviderGroup = "atom.GroupInitial" - provider.GrpcRegisterFunc = provider.ReturnType + // Save the original return type before changing it + originalReturnType := provider.ReturnType + provider.GrpcRegisterFunc = originalReturnType provider.ReturnType = "contracts.Initial" provider.Imports[atomPackage("")] = "" provider.Imports[atomPackage("contracts")] = "" provider.Imports[modePkg] = "" + // Extract and add gRPC service package import + if originalReturnType != "" && strings.Contains(originalReturnType, ".") { + // Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer") + if pkgAlias := getTypePkgName(originalReturnType); pkgAlias != "" { + // Look for this package in the original file's imports + if importPath, exists := imports[pkgAlias]; exists { + // Use the exact import path from the original file + provider.Imports[importPath] = pkgAlias + } else { + // Fallback: try to infer the common pattern + if moduleName != "" { + // Common pattern: {module}/pkg/proto/{service}/v1 + servicePkg := moduleName + "/pkg/proto/" + strings.ToLower(pkgAlias) + provider.Imports[servicePkg] = pkgAlias + } + } + } + } + provider.InjectParams["__grpc"] = InjectParam{ Star: "*", Type: "Grpc", diff --git a/pkg/ast/provider/provider.go b/pkg/ast/provider/provider.go index 696f510..470aa1a 100644 --- a/pkg/ast/provider/provider.go +++ b/pkg/ast/provider/provider.go @@ -311,7 +311,23 @@ func Parse(source string) []Provider { provider.Imports[modePkg] = "" provider.ProviderGroup = "atom.GroupInitial" - provider.GrpcRegisterFunc = provider.ReturnType + + // Handle gRPC register function correctly + if providerDoc.ReturnType != "" && strings.Contains(providerDoc.ReturnType, ".") { + // User specified a complete register function name, like userv1.RegisterUserServiceServer + provider.GrpcRegisterFunc = providerDoc.ReturnType + // Extract package information and add import + if pkgAlias := getTypePkgName(providerDoc.ReturnType); pkgAlias != "" { + if importPkg, ok := imports[pkgAlias]; ok { + provider.Imports[importPkg] = pkgAlias + } + } + } else { + // Generate default gRPC register function name + // Example: UserService -> RegisterUserServiceServer + provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(provider.ReturnType, "*") + "Server" + } + provider.ReturnType = "contracts.Initial" provider.InjectParams["__grpc"] = InjectParam{