Files
atomctl/cmd/gen_provider.go
2024-12-19 17:34:12 +08:00

437 lines
8.9 KiB
Go

package cmd
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/fs"
"math/rand"
"os"
"path/filepath"
"strings"
"text/template"
"git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod"
"github.com/samber/lo"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/tools/imports"
)
// getTypePkgName returns the package qualifier of a type expression written
// in a @provider annotation. For example it returns "contracts" for
// "contracts.Initial" and for "*contracts.Initial", and "" when the type is
// unqualified (e.g. "*Foo").
//
// Leading pointer markers are ignored so that a pointer return type still
// resolves to the alias of its package import.
func getTypePkgName(typ string) string {
	pkg, _, found := strings.Cut(strings.TrimLeft(typ, "*"), ".")
	if !found {
		return ""
	}
	return pkg
}
// CommandGenProvider registers the `provider` sub-command on root.
//
// The sub-command looks for struct declarations whose doc comment ends with
// an annotation of the form
//
//	// @provider
//	// @provider:[except|only] [returnType] [group]
//
// In except mode (the default) every named field is injected unless it is
// tagged inject:"false". In only mode, only fields tagged inject:"true" are
// injected.
func CommandGenProvider(root *cobra.Command) {
	root.AddCommand(&cobra.Command{
		Use:   "provider",
		Short: "Generate providers",
		Long: `
// @provider
// @provider:[except|only] [returnType] [group]
// when except add tag: inject:"false"
// when only add tag: inject:"true"
`,
		RunE: commandGenProviderE,
	})
}
// scalarTypes lists built-in numeric and boolean type names. Struct fields
// whose type is one of these are never turned into injected parameters.
var scalarTypes = []string{
	// floating point
	"float32", "float64",
	// signed integers
	"int", "int8", "int16", "int32", "int64",
	// unsigned integers
	"uint", "uint8", "uint16", "uint32", "uint64",
	// everything else
	"bool", "uintptr", "complex64", "complex128",
}
// commandGenProviderE implements the `provider` sub-command.
//
// It scans the directory given as args[0] (or the working directory when no
// argument is given) for non-test .go files. It collects every
// @provider-annotated struct via astParseProviders and renders one
// provider.gen.go per source directory via renderFile.
//
// The directory must contain a go.mod file. Errors reported while walking
// the tree (for example an unreadable subdirectory) abort the run instead of
// being silently dropped.
func commandGenProviderE(cmd *cobra.Command, args []string) error {
	target := "."
	if len(args) > 0 {
		target = args[0]
	}
	// For "." this resolves to the current working directory.
	dir, err := filepath.Abs(target)
	if err != nil {
		return fmt.Errorf("resolve path %q: %w", target, err)
	}

	if err := gomod.Parse(filepath.Join(dir, "go.mod")); err != nil {
		return fmt.Errorf("parse go.mod in %s: %w", dir, err)
	}

	log.Infof("generate providers for dir: %s", dir)

	var providers []Provider
	walkErr := filepath.WalkDir(dir, func(p string, d fs.DirEntry, err error) error {
		// d may be nil when err is set (e.g. the root cannot be stat'ed),
		// so the error must be checked before d is used.
		if err != nil {
			return err
		}
		if d.IsDir() || !strings.HasSuffix(p, ".go") || strings.HasSuffix(p, "_test.go") {
			return nil
		}
		providers = append(providers, astParseProviders(p)...)
		return nil
	})
	if walkErr != nil {
		return fmt.Errorf("walk %s: %w", dir, walkErr)
	}

	// One generated file per directory: group providers by target file.
	groups := lo.GroupBy(providers, func(item Provider) string {
		return item.ProviderFile
	})
	for file, conf := range groups {
		if err := renderFile(file, conf); err != nil {
			return err
		}
	}
	return nil
}
// InjectParam describes one constructor parameter of a generated provider,
// i.e. one struct field whose value is supplied by the container.
//
// For a field of type *pkg.Type: Star is "*", Type is "Type", PackageAlias is
// the qualifier used in the source ("pkg") and Package is that import's path.
// For an unqualified type, Package and PackageAlias are both empty.
type InjectParam struct {
	Star, Type, Package, PackageAlias string
}
// Provider carries everything needed to render the container registration of
// one @provider-annotated struct.
type Provider struct {
	// StructName is the annotated struct type. ReturnType is the constructor's
	// result type. ProviderGroup, when non-empty, is passed as an extra
	// argument to container.Container.Provide.
	StructName, ReturnType, ProviderGroup string

	// NeedPrepareFunc makes the generated constructor call obj.Prepare().
	NeedPrepareFunc bool

	// InjectParams maps each injected field name to its parameter description.
	InjectParams map[string]InjectParam

	// Imports maps import path to alias for the generated file.
	Imports map[string]string

	// PkgName is the package of the source file. ProviderFile is the path of
	// the generated file the registration is written to.
	PkgName, ProviderFile string
}
// astParseProviders parses the Go file at source and returns one Provider for
// every struct type declaration whose doc comment ends with an @provider
// annotation (syntax handled by parseDoc).
//
// Test files and files named provider.go are skipped. A file that fails to
// parse is logged and yields nil.
//
// Named struct fields become constructor parameters, filtered by mode:
//   - except (default): every field except those tagged inject:"false"
//   - only: only fields tagged inject:"true"
//
// Fields of scalar types are skipped. So are fields whose type cannot be
// spelled by the generator (slices, maps, funcs, generics, nested pointers,
// ...) and qualified types whose package the file does not import.
// NeedPrepareFunc is set when any named field carries a struct tag.
func astParseProviders(source string) []Provider {
	if strings.HasSuffix(source, "_test.go") {
		return []Provider{}
	}
	if filepath.Base(source) == "provider.go" {
		return []Provider{}
	}

	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, source, nil, parser.ParseComments)
	if err != nil {
		log.Error("ERR: ", err)
		return nil
	}
	imports := astProviderImports(node)

	providers := []Provider{}
	// Walk the top-level type declarations looking for annotated structs.
	for _, decl := range node.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || len(genDecl.Specs) == 0 {
			continue
		}
		// Only the first spec is inspected, and the annotation is read from
		// the declaration's own doc comment.
		typeSpec, ok := genDecl.Specs[0].(*ast.TypeSpec)
		if !ok {
			continue
		}
		if genDecl.Doc == nil || len(genDecl.Doc.List) == 0 {
			continue
		}
		structType, ok := typeSpec.Type.(*ast.StructType)
		if !ok {
			continue
		}
		// The annotation must be the last line of the doc comment.
		docMark := strings.TrimLeft(genDecl.Doc.List[len(genDecl.Doc.List)-1].Text, "/ \t")
		if !strings.HasPrefix(docMark, "@provider") {
			continue
		}

		provider := Provider{
			StructName:   typeSpec.Name.Name,
			InjectParams: make(map[string]InjectParam),
			Imports:      make(map[string]string),
		}
		mode, returnType, group := parseDoc(docMark)
		provider.ProviderGroup = group
		// "#" means no return type was given: return a pointer to the struct.
		if returnType == "#" {
			provider.ReturnType = "*" + provider.StructName
		} else {
			provider.ReturnType = returnType
		}

		onlyMode := mode == "only"
		exceptMode := mode == "except"
		log.Infof("[%s] %s => ONLY: %+v, EXCEPT: %+v, Type: %s, Group: %s", source, typeSpec.Name.Name, onlyMode, exceptMode, provider.ReturnType, provider.ProviderGroup)

		for _, field := range structType.Fields.List {
			if field.Names == nil {
				continue // embedded fields are never injected
			}
			if field.Tag != nil {
				provider.NeedPrepareFunc = true
			}
			if onlyMode && (field.Tag == nil || !strings.Contains(field.Tag.Value, `inject:"true"`)) {
				continue
			}
			if exceptMode && field.Tag != nil && strings.Contains(field.Tag.Value, `inject:"false"`) {
				continue
			}

			param, ok := astInjectParam(field.Type, imports)
			if !ok || lo.Contains(scalarTypes, param.Type) {
				continue
			}
			for _, name := range field.Names {
				provider.InjectParams[name.Name] = param
			}
			if importPkg, ok := imports[param.PackageAlias]; ok {
				provider.Imports[importPkg] = param.PackageAlias
			}
		}

		// The return type and group come from the annotation text and may
		// reference imported packages too.
		for _, typ := range []string{provider.ReturnType, provider.ProviderGroup} {
			if pkgAlias := getTypePkgName(typ); pkgAlias != "" {
				if importPkg, ok := imports[pkgAlias]; ok {
					provider.Imports[importPkg] = pkgAlias
				}
			}
		}

		provider.PkgName = node.Name.Name
		provider.ProviderFile = filepath.Join(filepath.Dir(source), "provider.gen.go")
		providers = append(providers, provider)
	}
	return providers
}

// astProviderImports maps each import's local name to its import path.
//
// Explicit aliases are used as-is. Unaliased and blank ("_") imports are named
// via gomod.GetPackageModuleName (which, per the original comment, tries to
// obtain the real package name). Blank imports are kept so that annotation
// text can refer to them. A blank import whose resolved name is already taken
// is stored under a randomly suffixed name so it does not overwrite the
// earlier entry.
func astProviderImports(file *ast.File) map[string]string {
	imports := make(map[string]string, len(file.Imports))
	for _, imp := range file.Imports {
		pkgPath := strings.Trim(imp.Path.Value, "\"")
		var name string
		if imp.Name != nil {
			name = imp.Name.Name
		} else {
			name = gomod.GetPackageModuleName(pkgPath)
		}
		if name == "_" {
			name = gomod.GetPackageModuleName(pkgPath)
			if _, taken := imports[name]; taken {
				name = fmt.Sprintf("%s%d", name, rand.Intn(100))
			}
		}
		imports[name] = pkgPath
	}
	return imports
}

// astInjectParam decodes a struct field's type expression into an
// InjectParam.
//
// Supported forms are T, *T, pkg.T and *pkg.T, where pkg must be a key of
// imports. ok is false for every other form, and for a qualifier that is not
// imported, so the caller can skip the field. Previously such fields either
// panicked or produced a parameter with an empty type.
func astInjectParam(expr ast.Expr, imports map[string]string) (InjectParam, bool) {
	var param InjectParam
	if ptr, isPtr := expr.(*ast.StarExpr); isPtr {
		param.Star = "*"
		expr = ptr.X
	}
	switch t := expr.(type) {
	case *ast.Ident:
		param.Type = t.Name
	case *ast.SelectorExpr:
		pkgIdent, ok := t.X.(*ast.Ident)
		if !ok {
			return InjectParam{}, false
		}
		pkgPath, ok := imports[pkgIdent.Name]
		if !ok {
			return InjectParam{}, false
		}
		param.Package = pkgPath
		param.PackageAlias = pkgIdent.Name
		param.Type = t.Sel.Name
	default:
		return InjectParam{}, false
	}
	return param, true
}
// parseDoc splits an annotation of the form
//
//	@provider:[except|only] [returnType] [group]
//
// into its mode, return type and group. doc must begin with "@provider".
//
// The mode is recognised only when it directly follows "@provider" or its
// colon. Otherwise "except" is assumed and the first word is taken as the
// return type. A missing return type is reported as "#" (astParseProviders
// turns that into *StructName), and a missing group as "". Words are
// separated by spaces or tabs; anything after the third word is ignored.
func parseDoc(doc string) (string, string, string) {
	rest := strings.TrimLeft(doc[len("@provider"):], ":")
	if !(strings.HasPrefix(rest, "except") || strings.HasPrefix(rest, "only")) {
		rest = "except " + rest
	}

	var words []string
	for _, w := range strings.Split(strings.ReplaceAll(rest, "\t", " "), " ") {
		if strings.TrimSpace(w) != "" {
			words = append(words, w)
		}
	}

	mode, returnType, group := "except", "#", ""
	if len(words) > 0 {
		mode = words[0]
	}
	if len(words) > 1 {
		returnType = words[1]
	}
	if len(words) > 2 {
		group = words[2]
	}
	return mode, returnType, group
}
// renderFile writes the container registration code for one package to
// filename and runs goimports over it.
//
// conf holds the providers discovered in the package directory. The package
// name is taken from conf[0], so an empty conf is a no-op. Template and write
// failures are returned.
//
// A goimports failure is not fatal: the unformatted source is still written,
// so it can be inspected, and a warning is logged. Generated files are
// created with mode 0644.
func renderFile(filename string, conf []Provider) error {
	if len(conf) == 0 {
		return nil
	}

	// The template always references the container and opt packages.
	pkgImports := map[string]string{
		"git.ipao.vip/rogeecn/atom/container": "",
		"git.ipao.vip/rogeecn/atom/utils/opt": "",
	}
	for _, item := range conf {
		for pkg, alias := range item.Imports {
			// An alias equal to the last path element is redundant; emit a
			// plain import instead.
			if strings.HasSuffix(pkg, "/"+alias) {
				alias = ""
			}
			pkgImports[pkg] = alias
		}
	}

	tmpl := `package {{.PkgName}}

import (
{{- range $pkg, $alias := .Imports }}
	{{- if eq $alias "" }}
	"{{$pkg}}"
	{{- else }}
	{{$alias}} "{{$pkg}}"
	{{- end }}
{{- end }}
)

func Provide(opts ...opt.Option) error {
	{{- range .Providers }}
	if err := container.Container.Provide(func(
		{{- range $key, $param := .InjectParams }}
		{{$key}} {{$param.Star}}{{if eq $param.Package ""}}{{$param.Type}}{{else}}{{$param.PackageAlias}}.{{$param.Type}}{{end}},
		{{- end }}
	) ({{.ReturnType}}, error) {
		obj := &{{.StructName}}{
			{{- range $key, $param := .InjectParams }}
			{{$key}}: {{$key}},
			{{- end }}
		}
		{{- if .NeedPrepareFunc }}
		if err := obj.Prepare(); err != nil {
			return nil, err
		}
		{{- end }}
		return obj, nil
	}{{if .ProviderGroup}}, {{.ProviderGroup}}{{end}}); err != nil {
		return err
	}
	{{- end }}
	return nil
}
`
	t := template.Must(template.New("provider").Parse(tmpl))
	data := struct {
		PkgName   string
		Imports   map[string]string
		Providers []Provider
	}{
		PkgName:   conf[0].PkgName,
		Imports:   pkgImports,
		Providers: conf,
	}

	// Render into memory first so a template error never leaves a
	// truncated file behind.
	var out strings.Builder
	if err := t.Execute(&out, data); err != nil {
		return fmt.Errorf("render providers for %s: %w", filename, err)
	}

	src := []byte(out.String())
	// filename is passed so goimports can resolve imports relative to the
	// package directory.
	if formatted, err := imports.Process(filename, src, nil); err != nil {
		log.Warnf("goimports %s: %v (writing unformatted source)", filename, err)
	} else {
		src = formatted
	}

	if err := os.WriteFile(filename, src, 0o644); err != nil {
		return fmt.Errorf("write %s: %w", filename, err)
	}
	return nil
}