feat: 增强 gRPC 支持,优化注册函数和导入处理逻辑
This commit is contained in:
@@ -380,6 +380,18 @@ func (pb *ProviderBuilder) resolveImportDependencies(provider *Provider, context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For gRPC mode, extract and add imports from the original file's imports
|
||||||
|
if provider.Mode == ProviderModeGrpc && provider.GrpcRegisterFunc != "" {
|
||||||
|
// Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer")
|
||||||
|
if pkgAlias := getTypePkgName(provider.GrpcRegisterFunc); pkgAlias != "" {
|
||||||
|
// Look for this package in the original file's imports
|
||||||
|
if importResolution, exists := context.ImportContext.FileImports[pkgAlias]; exists {
|
||||||
|
// Add the import from the original file
|
||||||
|
provider.Imports[importResolution.Path] = pkgAlias
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add mode-specific imports
|
// Add mode-specific imports
|
||||||
modeImports := pb.getModeSpecificImports(provider.Mode)
|
modeImports := pb.getModeSpecificImports(provider.Mode)
|
||||||
for alias, path := range modeImports {
|
for alias, path := range modeImports {
|
||||||
|
|||||||
@@ -368,11 +368,25 @@ func (p *MainParser) applyGrpcCompatibility(provider *Provider) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set return type and register function
|
// Set return type and register function
|
||||||
if provider.GrpcRegisterFunc == "" {
|
// Important: Save the original return type before setting it to contracts.Initial
|
||||||
provider.GrpcRegisterFunc = provider.ReturnType
|
originalReturnType := provider.ReturnType
|
||||||
}
|
|
||||||
provider.ReturnType = "contracts.Initial"
|
provider.ReturnType = "contracts.Initial"
|
||||||
|
|
||||||
|
// Set gRPC register function name if not already set
|
||||||
|
if provider.GrpcRegisterFunc == "" {
|
||||||
|
if originalReturnType != "" && strings.Contains(originalReturnType, ".") {
|
||||||
|
// User specified a complete register function name, like userv1.RegisterUserServiceServer
|
||||||
|
provider.GrpcRegisterFunc = originalReturnType
|
||||||
|
} else {
|
||||||
|
// Generate default gRPC register function name
|
||||||
|
// Example: UserService -> RegisterUserServiceServer
|
||||||
|
provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(originalReturnType, "*") + "Server"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Package import handling for gRPC register functions is now done
|
||||||
|
// in resolveImportDependencies to ensure access to original file imports
|
||||||
|
|
||||||
// Add gRPC injection parameter
|
// Add gRPC injection parameter
|
||||||
provider.InjectParams["__grpc"] = InjectParam{
|
provider.InjectParams["__grpc"] = InjectParam{
|
||||||
Star: "*",
|
Star: "*",
|
||||||
|
|||||||
@@ -498,7 +498,7 @@ func (p *GoParser) parseProviderDecl(filePath string, fileNode *ast.File, decl a
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle special provider modes
|
// Handle special provider modes
|
||||||
p.handleProviderModes(provider, providerDoc.Mode)
|
p.handleProviderModes(provider, providerDoc.Mode, imports)
|
||||||
|
|
||||||
// Add source location if enabled
|
// Add source location if enabled
|
||||||
if p.config.SourceLocations {
|
if p.config.SourceLocations {
|
||||||
@@ -593,20 +593,41 @@ func (p *GoParser) parseFieldType(expr ast.Expr, imports map[string]string) (sta
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleProviderModes applies special handling for different provider modes
|
// handleProviderModes applies special handling for different provider modes
|
||||||
func (p *GoParser) handleProviderModes(provider *Provider, mode string) {
|
func (p *GoParser) handleProviderModes(provider *Provider, mode string, imports map[string]string) {
|
||||||
moduleName := gomod.GetModuleName()
|
moduleName := gomod.GetModuleName()
|
||||||
|
|
||||||
switch provider.Mode {
|
switch provider.Mode {
|
||||||
case ProviderModeGrpc:
|
case ProviderModeGrpc:
|
||||||
modePkg := moduleName + "/providers/grpc"
|
modePkg := moduleName + "/providers/grpc"
|
||||||
provider.ProviderGroup = "atom.GroupInitial"
|
provider.ProviderGroup = "atom.GroupInitial"
|
||||||
provider.GrpcRegisterFunc = provider.ReturnType
|
// Save the original return type before changing it
|
||||||
|
originalReturnType := provider.ReturnType
|
||||||
|
provider.GrpcRegisterFunc = originalReturnType
|
||||||
provider.ReturnType = "contracts.Initial"
|
provider.ReturnType = "contracts.Initial"
|
||||||
|
|
||||||
provider.Imports[atomPackage("")] = ""
|
provider.Imports[atomPackage("")] = ""
|
||||||
provider.Imports[atomPackage("contracts")] = ""
|
provider.Imports[atomPackage("contracts")] = ""
|
||||||
provider.Imports[modePkg] = ""
|
provider.Imports[modePkg] = ""
|
||||||
|
|
||||||
|
// Extract and add gRPC service package import
|
||||||
|
if originalReturnType != "" && strings.Contains(originalReturnType, ".") {
|
||||||
|
// Extract package alias from gRPC register function name (e.g., "userv1" from "userv1.RegisterUserServiceServer")
|
||||||
|
if pkgAlias := getTypePkgName(originalReturnType); pkgAlias != "" {
|
||||||
|
// Look for this package in the original file's imports
|
||||||
|
if importPath, exists := imports[pkgAlias]; exists {
|
||||||
|
// Use the exact import path from the original file
|
||||||
|
provider.Imports[importPath] = pkgAlias
|
||||||
|
} else {
|
||||||
|
// Fallback: try to infer the common pattern
|
||||||
|
if moduleName != "" {
|
||||||
|
// Common pattern: {module}/pkg/proto/{service}/v1
|
||||||
|
servicePkg := moduleName + "/pkg/proto/" + strings.ToLower(pkgAlias)
|
||||||
|
provider.Imports[servicePkg] = pkgAlias
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
provider.InjectParams["__grpc"] = InjectParam{
|
provider.InjectParams["__grpc"] = InjectParam{
|
||||||
Star: "*",
|
Star: "*",
|
||||||
Type: "Grpc",
|
Type: "Grpc",
|
||||||
|
|||||||
@@ -311,7 +311,23 @@ func Parse(source string) []Provider {
|
|||||||
provider.Imports[modePkg] = ""
|
provider.Imports[modePkg] = ""
|
||||||
|
|
||||||
provider.ProviderGroup = "atom.GroupInitial"
|
provider.ProviderGroup = "atom.GroupInitial"
|
||||||
provider.GrpcRegisterFunc = provider.ReturnType
|
|
||||||
|
// Handle gRPC register function correctly
|
||||||
|
if providerDoc.ReturnType != "" && strings.Contains(providerDoc.ReturnType, ".") {
|
||||||
|
// User specified a complete register function name, like userv1.RegisterUserServiceServer
|
||||||
|
provider.GrpcRegisterFunc = providerDoc.ReturnType
|
||||||
|
// Extract package information and add import
|
||||||
|
if pkgAlias := getTypePkgName(providerDoc.ReturnType); pkgAlias != "" {
|
||||||
|
if importPkg, ok := imports[pkgAlias]; ok {
|
||||||
|
provider.Imports[importPkg] = pkgAlias
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Generate default gRPC register function name
|
||||||
|
// Example: UserService -> RegisterUserServiceServer
|
||||||
|
provider.GrpcRegisterFunc = "Register" + strings.TrimPrefix(provider.ReturnType, "*") + "Server"
|
||||||
|
}
|
||||||
|
|
||||||
provider.ReturnType = "contracts.Initial"
|
provider.ReturnType = "contracts.Initial"
|
||||||
|
|
||||||
provider.InjectParams["__grpc"] = InjectParam{
|
provider.InjectParams["__grpc"] = InjectParam{
|
||||||
|
|||||||
Reference in New Issue
Block a user