feat: update render provider
This commit is contained in:
@@ -1,31 +1,18 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"go/ast"
|
|
||||||
"go/parser"
|
|
||||||
"go/token"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"math/rand"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
|
||||||
|
|
||||||
|
"git.ipao.vip/rogeecn/atomctl/pkg/ast/provider"
|
||||||
"git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod"
|
"git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/tools/imports"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func getTypePkgName(typ string) string {
|
|
||||||
if strings.Contains(typ, ".") {
|
|
||||||
return strings.Split(typ, ".")[0]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func CommandGenProvider(root *cobra.Command) {
|
func CommandGenProvider(root *cobra.Command) {
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "provider",
|
Use: "provider",
|
||||||
@@ -43,25 +30,6 @@ func CommandGenProvider(root *cobra.Command) {
|
|||||||
root.AddCommand(cmd)
|
root.AddCommand(cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
var scalarTypes = []string{
|
|
||||||
"float32",
|
|
||||||
"float64",
|
|
||||||
"int",
|
|
||||||
"int8",
|
|
||||||
"int16",
|
|
||||||
"int32",
|
|
||||||
"int64",
|
|
||||||
"uint",
|
|
||||||
"uint8",
|
|
||||||
"uint16",
|
|
||||||
"uint32",
|
|
||||||
"uint64",
|
|
||||||
"bool",
|
|
||||||
"uintptr",
|
|
||||||
"complex64",
|
|
||||||
"complex128",
|
|
||||||
}
|
|
||||||
|
|
||||||
func commandGenProviderE(cmd *cobra.Command, args []string) error {
|
func commandGenProviderE(cmd *cobra.Command, args []string) error {
|
||||||
var err error
|
var err error
|
||||||
var path string
|
var path string
|
||||||
@@ -81,7 +49,7 @@ func commandGenProviderE(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
providers := []Provider{}
|
providers := []provider.Provider{}
|
||||||
|
|
||||||
// if path is file, then get the dir
|
// if path is file, then get the dir
|
||||||
log.Infof("generate providers for dir: %s", path)
|
log.Infof("generate providers for dir: %s", path)
|
||||||
@@ -99,346 +67,19 @@ func commandGenProviderE(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
providers = append(providers, astParseProviders(filepath)...)
|
providers = append(providers, provider.Parse(filepath)...)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
// generate files
|
// generate files
|
||||||
groups := lo.GroupBy(providers, func(item Provider) string {
|
groups := lo.GroupBy(providers, func(item provider.Provider) string {
|
||||||
return item.ProviderFile
|
return item.ProviderFile
|
||||||
})
|
})
|
||||||
|
|
||||||
for file, conf := range groups {
|
for file, conf := range groups {
|
||||||
if err := renderFile(file, conf); err != nil {
|
if err := provider.Render(file, conf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type InjectParam struct {
|
|
||||||
Star string
|
|
||||||
Type string
|
|
||||||
Package string
|
|
||||||
PackageAlias string
|
|
||||||
}
|
|
||||||
type Provider struct {
|
|
||||||
StructName string
|
|
||||||
ReturnType string
|
|
||||||
ProviderGroup string
|
|
||||||
NeedPrepareFunc bool
|
|
||||||
InjectParams map[string]InjectParam
|
|
||||||
Imports map[string]string
|
|
||||||
PkgName string
|
|
||||||
ProviderFile string
|
|
||||||
}
|
|
||||||
|
|
||||||
func astParseProviders(source string) []Provider {
|
|
||||||
if strings.HasSuffix(source, "_test.go") {
|
|
||||||
return []Provider{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasSuffix(source, "/provider.go") {
|
|
||||||
return []Provider{}
|
|
||||||
}
|
|
||||||
|
|
||||||
providers := []Provider{}
|
|
||||||
|
|
||||||
fset := token.NewFileSet()
|
|
||||||
node, err := parser.ParseFile(fset, source, nil, parser.ParseComments)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("ERR: ", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
imports := make(map[string]string)
|
|
||||||
for _, imp := range node.Imports {
|
|
||||||
name := ""
|
|
||||||
pkgPath := strings.Trim(imp.Path.Value, "\"")
|
|
||||||
|
|
||||||
if imp.Name != nil {
|
|
||||||
// 如果有显式指定包名,直接使用
|
|
||||||
name = imp.Name.Name
|
|
||||||
} else {
|
|
||||||
// 尝试从go.mod中获取真实包名
|
|
||||||
name = gomod.GetPackageModuleName(pkgPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理匿名导入的情况
|
|
||||||
if name == "_" {
|
|
||||||
name = gomod.GetPackageModuleName(pkgPath)
|
|
||||||
|
|
||||||
// 处理重名
|
|
||||||
if _, ok := imports[name]; ok {
|
|
||||||
name = fmt.Sprintf("%s%d", name, rand.Intn(100))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
imports[name] = pkgPath
|
|
||||||
}
|
|
||||||
|
|
||||||
// 再去遍历 struct 的方法去
|
|
||||||
for _, decl := range node.Decls {
|
|
||||||
provider := Provider{
|
|
||||||
InjectParams: make(map[string]InjectParam),
|
|
||||||
Imports: make(map[string]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
decl, ok := decl.(*ast.GenDecl)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(decl.Specs) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
declType, ok := decl.Specs[0].(*ast.TypeSpec)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 必须包含注释 // @provider:only/except
|
|
||||||
if decl.Doc == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(decl.Doc.List) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
structType, ok := declType.Type.(*ast.StructType)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
provider.StructName = declType.Name.Name
|
|
||||||
|
|
||||||
docMark := strings.TrimLeft(decl.Doc.List[len(decl.Doc.List)-1].Text, "/ \t")
|
|
||||||
if !strings.HasPrefix(docMark, "@provider") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mode, returnType, group := parseDoc(docMark)
|
|
||||||
if group != "" {
|
|
||||||
provider.ProviderGroup = group
|
|
||||||
}
|
|
||||||
// log.Infof("mode: %s, returnType: %s, group: %s", mode, returnType, group)
|
|
||||||
|
|
||||||
if returnType == "#" {
|
|
||||||
provider.ReturnType = "*" + provider.StructName
|
|
||||||
} else {
|
|
||||||
provider.ReturnType = returnType
|
|
||||||
}
|
|
||||||
onlyMode := mode == "only"
|
|
||||||
exceptMode := mode == "except"
|
|
||||||
log.Infof("[%s] %s => ONLY: %+v, EXCEPT: %+v, Type: %s, Group: %s", source, declType.Name.Name, onlyMode, exceptMode, provider.ReturnType, provider.ProviderGroup)
|
|
||||||
|
|
||||||
for _, field := range structType.Fields.List {
|
|
||||||
if field.Names == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.Tag != nil {
|
|
||||||
provider.NeedPrepareFunc = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if onlyMode {
|
|
||||||
if field.Tag == nil || !strings.Contains(field.Tag.Value, `inject:"true"`) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if exceptMode {
|
|
||||||
if field.Tag != nil && strings.Contains(field.Tag.Value, `inject:"false"`) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var star string
|
|
||||||
var pkg string
|
|
||||||
var pkgAlias string
|
|
||||||
var typ string
|
|
||||||
switch field.Type.(type) {
|
|
||||||
case *ast.Ident:
|
|
||||||
typ = field.Type.(*ast.Ident).Name
|
|
||||||
case *ast.StarExpr:
|
|
||||||
star = "*"
|
|
||||||
paramsType := field.Type.(*ast.StarExpr)
|
|
||||||
switch paramsType.X.(type) {
|
|
||||||
case *ast.SelectorExpr:
|
|
||||||
X := paramsType.X.(*ast.SelectorExpr)
|
|
||||||
|
|
||||||
pkgAlias = X.X.(*ast.Ident).Name
|
|
||||||
p, ok := imports[pkgAlias]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pkg = p
|
|
||||||
|
|
||||||
typ = X.Sel.Name
|
|
||||||
default:
|
|
||||||
typ = paramsType.X.(*ast.Ident).Name
|
|
||||||
}
|
|
||||||
case *ast.SelectorExpr:
|
|
||||||
pkgAlias = field.Type.(*ast.SelectorExpr).X.(*ast.Ident).Name
|
|
||||||
p, ok := imports[pkgAlias]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pkg = p
|
|
||||||
typ = field.Type.(*ast.SelectorExpr).Sel.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
if lo.Contains(scalarTypes, typ) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, name := range field.Names {
|
|
||||||
provider.InjectParams[name.Name] = InjectParam{
|
|
||||||
Star: star,
|
|
||||||
Type: typ,
|
|
||||||
Package: pkg,
|
|
||||||
PackageAlias: pkgAlias,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if importPkg, ok := imports[pkgAlias]; ok {
|
|
||||||
provider.Imports[importPkg] = pkgAlias
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if pkgAlias := getTypePkgName(provider.ReturnType); pkgAlias != "" {
|
|
||||||
if importPkg, ok := imports[pkgAlias]; ok {
|
|
||||||
provider.Imports[importPkg] = pkgAlias
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if pkgAlias := getTypePkgName(provider.ProviderGroup); pkgAlias != "" {
|
|
||||||
if importPkg, ok := imports[pkgAlias]; ok {
|
|
||||||
provider.Imports[importPkg] = pkgAlias
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
provider.PkgName = node.Name.Name
|
|
||||||
provider.ProviderFile = filepath.Join(filepath.Dir(source), "provider.gen.go")
|
|
||||||
|
|
||||||
providers = append(providers, provider)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return providers
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseDoc(doc string) (string, string, string) {
|
|
||||||
// @provider:[except|only] [returnType] [group]
|
|
||||||
doc = strings.TrimLeft(doc[len("@provider"):], ":")
|
|
||||||
if !strings.HasPrefix(doc, "except") && !strings.HasPrefix(doc, "only") {
|
|
||||||
doc = "except " + doc
|
|
||||||
}
|
|
||||||
|
|
||||||
doc = strings.ReplaceAll(doc, "\t", " ")
|
|
||||||
cmds := strings.Split(doc, " ")
|
|
||||||
cmds = lo.Filter(cmds, func(item string, idx int) bool {
|
|
||||||
return strings.TrimSpace(item) != ""
|
|
||||||
})
|
|
||||||
|
|
||||||
if len(cmds) == 0 {
|
|
||||||
return "except", "#", ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cmds) == 1 {
|
|
||||||
return cmds[0], "#", ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cmds) == 2 {
|
|
||||||
return cmds[0], cmds[1], ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmds[0], cmds[1], cmds[2]
|
|
||||||
}
|
|
||||||
|
|
||||||
func renderFile(filename string, conf []Provider) error {
|
|
||||||
defer func() {
|
|
||||||
result, err := imports.Process(filename, nil, nil)
|
|
||||||
if err == nil {
|
|
||||||
os.WriteFile(filename, result, os.ModePerm)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
imports := map[string]string{
|
|
||||||
"git.ipao.vip/rogeecn/atom/container": "",
|
|
||||||
"git.ipao.vip/rogeecn/atom/utils/opt": "",
|
|
||||||
}
|
|
||||||
lo.ForEach(conf, func(item Provider, _ int) {
|
|
||||||
for k, v := range item.Imports {
|
|
||||||
// 如果是当前包的引用,直接使用包名
|
|
||||||
if strings.HasSuffix(k, "/"+v) {
|
|
||||||
imports[k] = ""
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if gomod.GetPackageModuleName(k) == v {
|
|
||||||
imports[k] = ""
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
imports[k] = v
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
tmpl := `package {{.PkgName}}
|
|
||||||
|
|
||||||
import (
|
|
||||||
{{- range $pkg, $alias := .Imports }}
|
|
||||||
{{- if eq $alias "" }}
|
|
||||||
"{{$pkg}}"
|
|
||||||
{{- else }}
|
|
||||||
{{$alias}} "{{$pkg}}"
|
|
||||||
{{- end }}
|
|
||||||
{{- end }}
|
|
||||||
)
|
|
||||||
|
|
||||||
func Provide(opts ...opt.Option) error {
|
|
||||||
{{- range .Providers }}
|
|
||||||
if err := container.Container.Provide(func(
|
|
||||||
{{- range $key, $param := .InjectParams }}
|
|
||||||
{{$key}} {{$param.Star}}{{if eq $param.Package ""}}{{$param.Type}}{{else}}{{$param.PackageAlias}}.{{$param.Type}}{{end}},
|
|
||||||
{{- end }}
|
|
||||||
) ({{.ReturnType}}, error) {
|
|
||||||
obj := &{{.StructName}}{
|
|
||||||
{{- range $key, $param := .InjectParams }}
|
|
||||||
{{$key}}: {{$key}},
|
|
||||||
{{- end }}
|
|
||||||
}
|
|
||||||
{{- if .NeedPrepareFunc }}
|
|
||||||
if err := obj.Prepare(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
{{- end }}
|
|
||||||
return obj, nil
|
|
||||||
}{{if .ProviderGroup}}, {{.ProviderGroup}}{{end}}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
{{- end }}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
`
|
|
||||||
|
|
||||||
t := template.Must(template.New("provider").Parse(tmpl))
|
|
||||||
|
|
||||||
data := struct {
|
|
||||||
PkgName string
|
|
||||||
Imports map[string]string
|
|
||||||
Providers []Provider
|
|
||||||
}{
|
|
||||||
PkgName: conf[0].PkgName,
|
|
||||||
Imports: imports,
|
|
||||||
Providers: conf,
|
|
||||||
}
|
|
||||||
|
|
||||||
fd, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, os.ModePerm)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer fd.Close()
|
|
||||||
|
|
||||||
return t.Execute(fd, data)
|
|
||||||
}
|
|
||||||
|
|||||||
280
pkg/ast/provider/provider.go
Normal file
280
pkg/ast/provider/provider.go
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"go/ast"
|
||||||
|
"go/parser"
|
||||||
|
"go/token"
|
||||||
|
"math/rand"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getTypePkgName(typ string) string {
|
||||||
|
if strings.Contains(typ, ".") {
|
||||||
|
return strings.Split(typ, ".")[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var scalarTypes = []string{
|
||||||
|
"float32",
|
||||||
|
"float64",
|
||||||
|
"int",
|
||||||
|
"int8",
|
||||||
|
"int16",
|
||||||
|
"int32",
|
||||||
|
"int64",
|
||||||
|
"uint",
|
||||||
|
"uint8",
|
||||||
|
"uint16",
|
||||||
|
"uint32",
|
||||||
|
"uint64",
|
||||||
|
"bool",
|
||||||
|
"uintptr",
|
||||||
|
"complex64",
|
||||||
|
"complex128",
|
||||||
|
}
|
||||||
|
|
||||||
|
type InjectParam struct {
|
||||||
|
Star string
|
||||||
|
Type string
|
||||||
|
Package string
|
||||||
|
PackageAlias string
|
||||||
|
}
|
||||||
|
type Provider struct {
|
||||||
|
StructName string
|
||||||
|
ReturnType string
|
||||||
|
ProviderGroup string
|
||||||
|
NeedPrepareFunc bool
|
||||||
|
InjectParams map[string]InjectParam
|
||||||
|
Imports map[string]string
|
||||||
|
PkgName string
|
||||||
|
ProviderFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
func Parse(source string) []Provider {
|
||||||
|
if strings.HasSuffix(source, "_test.go") {
|
||||||
|
return []Provider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(source, "/provider.go") {
|
||||||
|
return []Provider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
providers := []Provider{}
|
||||||
|
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
node, err := parser.ParseFile(fset, source, nil, parser.ParseComments)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("ERR: ", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
imports := make(map[string]string)
|
||||||
|
for _, imp := range node.Imports {
|
||||||
|
name := ""
|
||||||
|
pkgPath := strings.Trim(imp.Path.Value, "\"")
|
||||||
|
|
||||||
|
if imp.Name != nil {
|
||||||
|
// 如果有显式指定包名,直接使用
|
||||||
|
name = imp.Name.Name
|
||||||
|
} else {
|
||||||
|
// 尝试从go.mod中获取真实包名
|
||||||
|
name = gomod.GetPackageModuleName(pkgPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理匿名导入的情况
|
||||||
|
if name == "_" {
|
||||||
|
name = gomod.GetPackageModuleName(pkgPath)
|
||||||
|
|
||||||
|
// 处理重名
|
||||||
|
if _, ok := imports[name]; ok {
|
||||||
|
name = fmt.Sprintf("%s%d", name, rand.Intn(100))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
imports[name] = pkgPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// 再去遍历 struct 的方法去
|
||||||
|
for _, decl := range node.Decls {
|
||||||
|
provider := Provider{
|
||||||
|
InjectParams: make(map[string]InjectParam),
|
||||||
|
Imports: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
decl, ok := decl.(*ast.GenDecl)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decl.Specs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
declType, ok := decl.Specs[0].(*ast.TypeSpec)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 必须包含注释 // @provider:only/except
|
||||||
|
if decl.Doc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decl.Doc.List) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
structType, ok := declType.Type.(*ast.StructType)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
provider.StructName = declType.Name.Name
|
||||||
|
|
||||||
|
docMark := strings.TrimLeft(decl.Doc.List[len(decl.Doc.List)-1].Text, "/ \t")
|
||||||
|
if !strings.HasPrefix(docMark, "@provider") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mode, returnType, group := parseDoc(docMark)
|
||||||
|
if group != "" {
|
||||||
|
provider.ProviderGroup = group
|
||||||
|
}
|
||||||
|
// log.Infof("mode: %s, returnType: %s, group: %s", mode, returnType, group)
|
||||||
|
|
||||||
|
if returnType == "#" {
|
||||||
|
provider.ReturnType = "*" + provider.StructName
|
||||||
|
} else {
|
||||||
|
provider.ReturnType = returnType
|
||||||
|
}
|
||||||
|
onlyMode := mode == "only"
|
||||||
|
exceptMode := mode == "except"
|
||||||
|
log.Infof("[%s] %s => ONLY: %+v, EXCEPT: %+v, Type: %s, Group: %s", source, declType.Name.Name, onlyMode, exceptMode, provider.ReturnType, provider.ProviderGroup)
|
||||||
|
|
||||||
|
for _, field := range structType.Fields.List {
|
||||||
|
if field.Names == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.Tag != nil {
|
||||||
|
provider.NeedPrepareFunc = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if onlyMode {
|
||||||
|
if field.Tag == nil || !strings.Contains(field.Tag.Value, `inject:"true"`) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if exceptMode {
|
||||||
|
if field.Tag != nil && strings.Contains(field.Tag.Value, `inject:"false"`) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var star string
|
||||||
|
var pkg string
|
||||||
|
var pkgAlias string
|
||||||
|
var typ string
|
||||||
|
switch field.Type.(type) {
|
||||||
|
case *ast.Ident:
|
||||||
|
typ = field.Type.(*ast.Ident).Name
|
||||||
|
case *ast.StarExpr:
|
||||||
|
star = "*"
|
||||||
|
paramsType := field.Type.(*ast.StarExpr)
|
||||||
|
switch paramsType.X.(type) {
|
||||||
|
case *ast.SelectorExpr:
|
||||||
|
X := paramsType.X.(*ast.SelectorExpr)
|
||||||
|
|
||||||
|
pkgAlias = X.X.(*ast.Ident).Name
|
||||||
|
p, ok := imports[pkgAlias]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pkg = p
|
||||||
|
|
||||||
|
typ = X.Sel.Name
|
||||||
|
default:
|
||||||
|
typ = paramsType.X.(*ast.Ident).Name
|
||||||
|
}
|
||||||
|
case *ast.SelectorExpr:
|
||||||
|
pkgAlias = field.Type.(*ast.SelectorExpr).X.(*ast.Ident).Name
|
||||||
|
p, ok := imports[pkgAlias]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pkg = p
|
||||||
|
typ = field.Type.(*ast.SelectorExpr).Sel.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
if lo.Contains(scalarTypes, typ) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range field.Names {
|
||||||
|
provider.InjectParams[name.Name] = InjectParam{
|
||||||
|
Star: star,
|
||||||
|
Type: typ,
|
||||||
|
Package: pkg,
|
||||||
|
PackageAlias: pkgAlias,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if importPkg, ok := imports[pkgAlias]; ok {
|
||||||
|
provider.Imports[importPkg] = pkgAlias
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkgAlias := getTypePkgName(provider.ReturnType); pkgAlias != "" {
|
||||||
|
if importPkg, ok := imports[pkgAlias]; ok {
|
||||||
|
provider.Imports[importPkg] = pkgAlias
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkgAlias := getTypePkgName(provider.ProviderGroup); pkgAlias != "" {
|
||||||
|
if importPkg, ok := imports[pkgAlias]; ok {
|
||||||
|
provider.Imports[importPkg] = pkgAlias
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
provider.PkgName = node.Name.Name
|
||||||
|
provider.ProviderFile = filepath.Join(filepath.Dir(source), "provider.gen.go")
|
||||||
|
|
||||||
|
providers = append(providers, provider)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return providers
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDoc(doc string) (string, string, string) {
|
||||||
|
// @provider:[except|only] [returnType] [group]
|
||||||
|
doc = strings.TrimLeft(doc[len("@provider"):], ":")
|
||||||
|
if !strings.HasPrefix(doc, "except") && !strings.HasPrefix(doc, "only") {
|
||||||
|
doc = "except " + doc
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = strings.ReplaceAll(doc, "\t", " ")
|
||||||
|
cmds := strings.Split(doc, " ")
|
||||||
|
cmds = lo.Filter(cmds, func(item string, idx int) bool {
|
||||||
|
return strings.TrimSpace(item) != ""
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(cmds) == 0 {
|
||||||
|
return "except", "#", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cmds) == 1 {
|
||||||
|
return cmds[0], "#", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cmds) == 2 {
|
||||||
|
return cmds[0], cmds[1], ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmds[0], cmds[1], cmds[2]
|
||||||
|
}
|
||||||
36
pkg/ast/provider/provider.go.tpl
Normal file
36
pkg/ast/provider/provider.go.tpl
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package {{.PkgName}}
|
||||||
|
|
||||||
|
import (
|
||||||
|
{{- range $pkg, $alias := .Imports }}
|
||||||
|
{{- if eq $alias "" }}
|
||||||
|
"{{$pkg}}"
|
||||||
|
{{- else }}
|
||||||
|
{{$alias}} "{{$pkg}}"
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
)
|
||||||
|
|
||||||
|
func Provide(opts ...opt.Option) error {
|
||||||
|
{{- range .Providers }}
|
||||||
|
if err := container.Container.Provide(func(
|
||||||
|
{{- range $key, $param := .InjectParams }}
|
||||||
|
{{$key}} {{$param.Star}}{{if eq $param.Package ""}}{{$param.Type}}{{else}}{{$param.PackageAlias}}.{{$param.Type}}{{end}},
|
||||||
|
{{- end }}
|
||||||
|
) ({{.ReturnType}}, error) {
|
||||||
|
obj := &{{.StructName}}{
|
||||||
|
{{- range $key, $param := .InjectParams }}
|
||||||
|
{{$key}}: {{$key}},
|
||||||
|
{{- end }}
|
||||||
|
}
|
||||||
|
{{- if .NeedPrepareFunc }}
|
||||||
|
if err := obj.Prepare(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
{{- end }}
|
||||||
|
return obj, nil
|
||||||
|
}{{if .ProviderGroup}}, {{.ProviderGroup}}{{end}}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
{{- end }}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
65
pkg/ast/provider/render.go
Normal file
65
pkg/ast/provider/render.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"html/template"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
"golang.org/x/tools/imports"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed provider.go.tpl
|
||||||
|
var providerTpl string
|
||||||
|
|
||||||
|
func Render(filename string, conf []Provider) error {
|
||||||
|
defer func() {
|
||||||
|
result, err := imports.Process(filename, nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
os.WriteFile(filename, result, os.ModePerm)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
imports := map[string]string{
|
||||||
|
"git.ipao.vip/rogeecn/atom/container": "",
|
||||||
|
"git.ipao.vip/rogeecn/atom/utils/opt": "",
|
||||||
|
}
|
||||||
|
lo.ForEach(conf, func(item Provider, _ int) {
|
||||||
|
for k, v := range item.Imports {
|
||||||
|
// 如果是当前包的引用,直接使用包名
|
||||||
|
if strings.HasSuffix(k, "/"+v) {
|
||||||
|
imports[k] = ""
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if gomod.GetPackageModuleName(k) == v {
|
||||||
|
imports[k] = ""
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
imports[k] = v
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t := template.Must(template.New("provider").Parse(providerTpl))
|
||||||
|
|
||||||
|
data := struct {
|
||||||
|
PkgName string
|
||||||
|
Imports map[string]string
|
||||||
|
Providers []Provider
|
||||||
|
}{
|
||||||
|
PkgName: conf[0].PkgName,
|
||||||
|
Imports: imports,
|
||||||
|
Providers: conf,
|
||||||
|
}
|
||||||
|
|
||||||
|
fd, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fd.Close()
|
||||||
|
|
||||||
|
return t.Execute(fd, data)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user