feat: 增强 gRPC 支持,优化注册函数和导入处理逻辑

This commit is contained in:
Rogee
2025-09-22 10:00:25 +08:00
parent 1e98d0eaff
commit 0cfc573960
4 changed files with 70 additions and 7 deletions

View File

@@ -380,6 +380,18 @@ func (pb *ProviderBuilder) resolveImportDependencies(provider *Provider, context
}
}
// For gRPC mode, extract and add imports from the original file's imports
if provider.Mode == ProviderModeGrpc && provider.GrpcRegisterFunc != "" {
// Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer")
if pkgAlias := getTypePkgName(provider.GrpcRegisterFunc); pkgAlias != "" {
// Look for this package in the original file's imports
if importResolution, exists := context.ImportContext.FileImports[pkgAlias]; exists {
// Add the import from the original file
provider.Imports[importResolution.Path] = pkgAlias
}
}
}
// Add mode-specific imports
modeImports := pb.getModeSpecificImports(provider.Mode)
for alias, path := range modeImports {

View File

@@ -368,11 +368,25 @@ func (p *MainParser) applyGrpcCompatibility(provider *Provider) {
}
// Set return type and register function
if provider.GrpcRegisterFunc == "" {
provider.GrpcRegisterFunc = provider.ReturnType
}
// Important: Save the original return type before setting it to contracts.Initial
originalReturnType := provider.ReturnType
provider.ReturnType = "contracts.Initial"
// Set gRPC register function name if not already set
if provider.GrpcRegisterFunc == "" {
if originalReturnType != "" && strings.Contains(originalReturnType, ".") {
// User specified a complete register function name, like userv1.RegisterUserServiceServer
provider.GrpcRegisterFunc = originalReturnType
} else {
// Generate default gRPC register function name
// Example: UserService -> RegisterUserServiceServer
provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(originalReturnType, "*") + "Server"
}
}
// Note: Package import handling for gRPC register functions is now done
// in resolveImportDependencies to ensure access to original file imports
// Add gRPC injection parameter
provider.InjectParams["__grpc"] = InjectParam{
Star: "*",

View File

@@ -498,7 +498,7 @@ func (p *GoParser) parseProviderDecl(filePath string, fileNode *ast.File, decl a
}
// Handle special provider modes
p.handleProviderModes(provider, providerDoc.Mode)
p.handleProviderModes(provider, providerDoc.Mode, imports)
// Add source location if enabled
if p.config.SourceLocations {
@@ -593,20 +593,41 @@ func (p *GoParser) parseFieldType(expr ast.Expr, imports map[string]string) (sta
}
// handleProviderModes applies special handling for different provider modes
func (p *GoParser) handleProviderModes(provider *Provider, mode string) {
func (p *GoParser) handleProviderModes(provider *Provider, mode string, imports map[string]string) {
moduleName := gomod.GetModuleName()
switch provider.Mode {
case ProviderModeGrpc:
modePkg := moduleName + "/providers/grpc"
provider.ProviderGroup = "atom.GroupInitial"
provider.GrpcRegisterFunc = provider.ReturnType
// Save the original return type before changing it
originalReturnType := provider.ReturnType
provider.GrpcRegisterFunc = originalReturnType
provider.ReturnType = "contracts.Initial"
provider.Imports[atomPackage("")] = ""
provider.Imports[atomPackage("contracts")] = ""
provider.Imports[modePkg] = ""
// Extract and add gRPC service package import
if originalReturnType != "" && strings.Contains(originalReturnType, ".") {
// Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer")
if pkgAlias := getTypePkgName(originalReturnType); pkgAlias != "" {
// Look for this package in the original file's imports
if importPath, exists := imports[pkgAlias]; exists {
// Use the exact import path from the original file
provider.Imports[importPath] = pkgAlias
} else {
// Fallback: try to infer the common pattern
if moduleName != "" {
// Common pattern: {module}/pkg/proto/{service}/v1
servicePkg := moduleName + "/pkg/proto/" + strings.ToLower(pkgAlias)
provider.Imports[servicePkg] = pkgAlias
}
}
}
}
provider.InjectParams["__grpc"] = InjectParam{
Star: "*",
Type: "Grpc",

View File

@@ -311,7 +311,23 @@ func Parse(source string) []Provider {
provider.Imports[modePkg] = ""
provider.ProviderGroup = "atom.GroupInitial"
provider.GrpcRegisterFunc = provider.ReturnType
// Handle gRPC register function correctly
if providerDoc.ReturnType != "" && strings.Contains(providerDoc.ReturnType, ".") {
// User specified a complete register function name, like userv1.RegisterUserServiceServer
provider.GrpcRegisterFunc = providerDoc.ReturnType
// Extract package information and add import
if pkgAlias := getTypePkgName(providerDoc.ReturnType); pkgAlias != "" {
if importPkg, ok := imports[pkgAlias]; ok {
provider.Imports[importPkg] = pkgAlias
}
}
} else {
// Generate default gRPC register function name
// Example: UserService -> RegisterUserServiceServer
provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(provider.ReturnType, "*") + "Server"
}
provider.ReturnType = "contracts.Initial"
provider.InjectParams["__grpc"] = InjectParam{