feat: 增强 gRPC 支持,优化注册函数和导入处理逻辑
This commit is contained in:
@@ -380,6 +380,18 @@ func (pb *ProviderBuilder) resolveImportDependencies(provider *Provider, context
|
||||
}
|
||||
}
|
||||
|
||||
// For gRPC mode, extract and add imports from the original file's imports
|
||||
if provider.Mode == ProviderModeGrpc && provider.GrpcRegisterFunc != "" {
|
||||
// Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer")
|
||||
if pkgAlias := getTypePkgName(provider.GrpcRegisterFunc); pkgAlias != "" {
|
||||
// Look for this package in the original file's imports
|
||||
if importResolution, exists := context.ImportContext.FileImports[pkgAlias]; exists {
|
||||
// Add the import from the original file
|
||||
provider.Imports[importResolution.Path] = pkgAlias
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add mode-specific imports
|
||||
modeImports := pb.getModeSpecificImports(provider.Mode)
|
||||
for alias, path := range modeImports {
|
||||
|
||||
@@ -368,11 +368,25 @@ func (p *MainParser) applyGrpcCompatibility(provider *Provider) {
|
||||
}
|
||||
|
||||
// Set return type and register function
|
||||
if provider.GrpcRegisterFunc == "" {
|
||||
provider.GrpcRegisterFunc = provider.ReturnType
|
||||
}
|
||||
// Important: Save the original return type before setting it to contracts.Initial
|
||||
originalReturnType := provider.ReturnType
|
||||
provider.ReturnType = "contracts.Initial"
|
||||
|
||||
// Set gRPC register function name if not already set
|
||||
if provider.GrpcRegisterFunc == "" {
|
||||
if originalReturnType != "" && strings.Contains(originalReturnType, ".") {
|
||||
// User specified a complete register function name, like userv1.RegisterUserServiceServer
|
||||
provider.GrpcRegisterFunc = originalReturnType
|
||||
} else {
|
||||
// Generate default gRPC register function name
|
||||
// Example: UserService -> RegisterUserServiceServer
|
||||
provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(originalReturnType, "*") + "Server"
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Package import handling for gRPC register functions is now done
|
||||
// in resolveImportDependencies to ensure access to original file imports
|
||||
|
||||
// Add gRPC injection parameter
|
||||
provider.InjectParams["__grpc"] = InjectParam{
|
||||
Star: "*",
|
||||
|
||||
@@ -498,7 +498,7 @@ func (p *GoParser) parseProviderDecl(filePath string, fileNode *ast.File, decl a
|
||||
}
|
||||
|
||||
// Handle special provider modes
|
||||
p.handleProviderModes(provider, providerDoc.Mode)
|
||||
p.handleProviderModes(provider, providerDoc.Mode, imports)
|
||||
|
||||
// Add source location if enabled
|
||||
if p.config.SourceLocations {
|
||||
@@ -593,20 +593,41 @@ func (p *GoParser) parseFieldType(expr ast.Expr, imports map[string]string) (sta
|
||||
}
|
||||
|
||||
// handleProviderModes applies special handling for different provider modes
|
||||
func (p *GoParser) handleProviderModes(provider *Provider, mode string) {
|
||||
func (p *GoParser) handleProviderModes(provider *Provider, mode string, imports map[string]string) {
|
||||
moduleName := gomod.GetModuleName()
|
||||
|
||||
switch provider.Mode {
|
||||
case ProviderModeGrpc:
|
||||
modePkg := moduleName + "/providers/grpc"
|
||||
provider.ProviderGroup = "atom.GroupInitial"
|
||||
provider.GrpcRegisterFunc = provider.ReturnType
|
||||
// Save the original return type before changing it
|
||||
originalReturnType := provider.ReturnType
|
||||
provider.GrpcRegisterFunc = originalReturnType
|
||||
provider.ReturnType = "contracts.Initial"
|
||||
|
||||
provider.Imports[atomPackage("")] = ""
|
||||
provider.Imports[atomPackage("contracts")] = ""
|
||||
provider.Imports[modePkg] = ""
|
||||
|
||||
// Extract and add gRPC service package import
|
||||
if originalReturnType != "" && strings.Contains(originalReturnType, ".") {
|
||||
// Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer")
|
||||
if pkgAlias := getTypePkgName(originalReturnType); pkgAlias != "" {
|
||||
// Look for this package in the original file's imports
|
||||
if importPath, exists := imports[pkgAlias]; exists {
|
||||
// Use the exact import path from the original file
|
||||
provider.Imports[importPath] = pkgAlias
|
||||
} else {
|
||||
// Fallback: try to infer the common pattern
|
||||
if moduleName != "" {
|
||||
// Common pattern: {module}/pkg/proto/{service}/v1
|
||||
servicePkg := moduleName + "/pkg/proto/" + strings.ToLower(pkgAlias)
|
||||
provider.Imports[servicePkg] = pkgAlias
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
provider.InjectParams["__grpc"] = InjectParam{
|
||||
Star: "*",
|
||||
Type: "Grpc",
|
||||
|
||||
@@ -311,7 +311,23 @@ func Parse(source string) []Provider {
|
||||
provider.Imports[modePkg] = ""
|
||||
|
||||
provider.ProviderGroup = "atom.GroupInitial"
|
||||
provider.GrpcRegisterFunc = provider.ReturnType
|
||||
|
||||
// Handle gRPC register function correctly
|
||||
if providerDoc.ReturnType != "" && strings.Contains(providerDoc.ReturnType, ".") {
|
||||
// User specified a complete register function name, like userv1.RegisterUserServiceServer
|
||||
provider.GrpcRegisterFunc = providerDoc.ReturnType
|
||||
// Extract package information and add import
|
||||
if pkgAlias := getTypePkgName(providerDoc.ReturnType); pkgAlias != "" {
|
||||
if importPkg, ok := imports[pkgAlias]; ok {
|
||||
provider.Imports[importPkg] = pkgAlias
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Generate default gRPC register function name
|
||||
// Example: UserService -> RegisterUserServiceServer
|
||||
provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(provider.ReturnType, "*") + "Server"
|
||||
}
|
||||
|
||||
provider.ReturnType = "contracts.Initial"
|
||||
|
||||
provider.InjectParams["__grpc"] = InjectParam{
|
||||
|
||||
Reference in New Issue
Block a user