diff --git a/cmd/gen_model.go b/cmd/gen_model.go index b0161f4..e8c6f72 100644 --- a/cmd/gen_model.go +++ b/cmd/gen_model.go @@ -6,6 +6,8 @@ import ( "regexp" "strings" + astModel "go.ipao.vip/atomctl/pkg/ast/model" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/template" @@ -18,7 +20,6 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" "go.ipao.vip/atomctl/pkg/ast/model" - astModel "go.ipao.vip/atomctl/pkg/ast/model" pgDatabase "go.ipao.vip/atomctl/pkg/postgres" "go.ipao.vip/atomctl/pkg/utils/gomod" ) @@ -99,6 +100,14 @@ func commandGenModelE(cmd *cobra.Command, args []string) error { UseModel( template. DefaultModel(). + UseEnum(func(meta metadata.Enum) template.EnumModel { + enum := template.DefaultEnumModel(meta) + if lo.Contains(transformer.Ignores.Jet, meta.Name) { + enum.Skip = true + log.Infof("Skip enum %s", meta.Name) + } + return enum + }). UseTable(func(table metadata.Table) template.TableModel { tbl := template.DefaultTableModel(table) if lo.Contains(transformer.Ignores.Jet, table.Name) { @@ -165,5 +174,9 @@ func commandGenModelE(cmd *cobra.Command, args []string) error { if err := os.Rename(dataPath, "database/schemas"); err != nil { return err } - return astModel.Generate(generatedTables, transformer) + + if err := astModel.Generate(generatedTables, transformer); err != nil { + return err + } + return nil } diff --git a/pkg/ast/model/generage.go b/pkg/ast/model/generage.go index a49de52..a14cfbe 100644 --- a/pkg/ast/model/generage.go +++ b/pkg/ast/model/generage.go @@ -4,12 +4,12 @@ import ( _ "embed" "fmt" "html/template" - "log" "os" "path/filepath" "strings" "github.com/samber/lo" + log "github.com/sirupsen/logrus" "go.ipao.vip/atomctl/pkg/utils/gomod" ) @@ -19,8 +19,8 @@ var tableTpl string //go:embed table_test.go.tpl var tableTestTpl string -//go:embed models.gen.go.tpl -var modelTpl string +//go:embed provider.gen.go.tpl +var providerTplStr string type TableModelParam struct { PkgName string @@ -29,11 +29,50 @@ type TableModelParam struct { } func Generate(tables []string, transformer Transformer) error { - baseDir := "app/models" + baseDir := "app/model" + modelDir := "database/schemas/public/model" + // move database/schemas/public/model files to app/model + + // remove all files in app/model with ext .gen.go + files, err := os.ReadDir(baseDir) + if err != nil { + return err + } + + for _, file := range files { + if strings.HasSuffix(file.Name(), ".gen.go") { + if err := os.RemoveAll(filepath.Join(baseDir, file.Name())); err != nil { + return err + } + } + } + + // move files remove ext .go to .gen.go + files, err = os.ReadDir(modelDir) + if err != nil { + return err + } + + for _, file := range files { + // get filename without ext + name := strings.TrimSuffix(file.Name(), filepath.Ext(file.Name())) + + from := filepath.Join(modelDir, file.Name()) + to := filepath.Join(baseDir, name+".gen.go") + log.Infof("Move %s to %s", from, to) + if err := os.Rename(from, to); err != nil { + return err + } + } + + // remove database/schemas/public/model + if err := os.RemoveAll(modelDir); err != nil { + return err + } tableTpl := template.Must(template.New("model").Parse(string(tableTpl))) tableTestTpl := template.Must(template.New("model").Parse(string(tableTestTpl))) - modelTpl := template.Must(template.New("modelGen").Parse(string(modelTpl))) + providerTpl := template.Must(template.New("modelGen").Parse(string(providerTplStr))) items := []TableModelParam{} for _, table := range tables { @@ -86,51 +125,53 @@ func Generate(tables []string, transformer Transformer) error { } } - // 遍历 baseDir 下的所有文件,将不在 tables 中的文件名(不带扩展名)加入 - files, err := os.ReadDir(baseDir) + // 渲染总的 provider 文件 + providerFile := fmt.Sprintf("%s/provider.gen.go", baseDir) + os.Remove(providerFile) + fd, err := os.Create(providerFile) if err != nil { - return fmt.Errorf("遍历目录 %s 失败: %w", baseDir, err) - } - for _, file := range files { - if file.IsDir() { - continue - } - name := file.Name() - if strings.HasSuffix(name, ".gen.go") { - continue - } - - if strings.HasSuffix(name, "_test.go") { - continue - } - - baseName := strings.TrimSuffix(name, filepath.Ext(name)) - if lo.Contains(transformer.Ignores.Model, baseName) { - log.Printf("[WARN] skip model %s\n", baseName) - continue - } - - if !lo.Contains(tables, baseName) { - items = append(items, TableModelParam{ - CamelTable: lo.CamelCase(baseName), - PascalTable: lo.PascalCase(baseName), - }) - } - } - - // 渲染总的 model 文件 - - modelFile := fmt.Sprintf("%s/models.gen.go", baseDir) - os.Remove(modelFile) - fd, err := os.Create(modelFile) - if err != nil { - return fmt.Errorf("failed to create model file %s: %w", baseDir, err) + return fmt.Errorf("failed to create provider file %s: %w", providerFile, err) } defer fd.Close() - if err := modelTpl.Execute(fd, items); err != nil { + if err := providerTpl.Execute(fd, items); err != nil { return fmt.Errorf("failed to render model template: %w", err) } return nil } + +func addProviderComment(filePath string) error { + file, err := os.OpenFile(filePath, os.O_RDWR, 0o644) + if err != nil { + return err + } + defer file.Close() + + content, err := os.ReadFile(filePath) + if err != nil { + return err + } + + if strings.Contains(string(content), "// @provider") { + return nil + } + + // Write this comment to the up line of the type xxx struct + newLines := []string{} + lines := strings.Split(string(content), "\n") + for i, line := range lines { + if strings.Contains(line, "type ") && strings.Contains(line, "struct") { + newLines = append(newLines, "// @provider") + // append rest lines + newLines = append(newLines, lines[i:]...) + break + } + newLines = append(newLines, line) + } + newContent := strings.Join(newLines, "\n") + if _, err := file.WriteAt([]byte(newContent), 0); err != nil { + return err + } + return nil +} diff --git a/pkg/ast/model/models.gen.go.tpl b/pkg/ast/model/models.gen.go.tpl deleted file mode 100644 index 2a22f48..0000000 --- a/pkg/ast/model/models.gen.go.tpl +++ /dev/null @@ -1,30 +0,0 @@ -// Code generated by the atomctl ; DO NOT EDIT. -// Code generated by the atomctl ; DO NOT EDIT. -// Code generated by the atomctl ; DO NOT EDIT. -package models - -import ( - "database/sql" -) - -var db *sql.DB -{{- range . }} -var {{.PascalTable}} *{{.CamelTable}}Model -{{- end }} - -// @provider(model) -type models struct { - db *sql.DB - -{{- range . }} - {{.CamelTable}} *{{.CamelTable}}Model -{{- end }} -} - -func (m *models) Prepare() error { - db = m.db -{{- range . }} - {{.PascalTable}} = m.{{.CamelTable}} -{{- end }} - return nil -} diff --git a/pkg/ast/model/provider.gen.go.tpl b/pkg/ast/model/provider.gen.go.tpl new file mode 100644 index 0000000..f55de68 --- /dev/null +++ b/pkg/ast/model/provider.gen.go.tpl @@ -0,0 +1,53 @@ +// Code generated by the atomctl ; DO NOT EDIT. +// Code generated by the atomctl ; DO NOT EDIT. +// Code generated by the atomctl ; DO NOT EDIT. +package model +import ( + "context" + "database/sql" + + "go.ipao.vip/atom" + "go.ipao.vip/atom/container" + "go.ipao.vip/atom/contracts" + "go.ipao.vip/atom/opt" +) + +var db *sql.DB +{{- range . }} +var {{.PascalTable}}Model *{{.PascalTable}} +{{- end }} + +func Transaction(ctx context.Context) (*sql.Tx, error) { + return db.Begin() +} +func DB() *sql.DB { + return db +} + +func Provide(opts ...opt.Option) error { +{{- range . }} + if err := container.Container.Provide(func() (*{{.PascalTable}}, error) { + obj := &{{.PascalTable}}{} + return obj, nil + }); err != nil { + return err + } +{{ end }} + + if err := container.Container.Provide(func( + _db *sql.DB, +{{- range . }} + {{.CamelTable}} *{{.PascalTable}}, +{{- end }} + ) (contracts.Initial, error) { + db = _db +{{- range . }} + {{.PascalTable}}Model = {{.CamelTable}} +{{- end }} + + return nil, nil + }, atom.GroupInitial); err != nil { + return err + } + return nil +} diff --git a/pkg/ast/model/table.go.tpl b/pkg/ast/model/table.go.tpl index 4435bf7..f14317b 100644 --- a/pkg/ast/model/table.go.tpl +++ b/pkg/ast/model/table.go.tpl @@ -1,14 +1,9 @@ -package models +package model import ( - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) -// @provider -type {{.CamelTable}}Model struct { - log *logrus.Entry `inject:"false"` -} -func (m *{{.CamelTable}}Model) Prepare() error { - m.log = logrus.WithField("model", "{{.CamelTable}}Model") - return nil +func (m *{{.PascalTable}}) log() *log.Entry { + return log.WithField("model", "{{.PascalTable}}Model") } \ No newline at end of file diff --git a/pkg/ast/model/table_test.go.tpl b/pkg/ast/model/table_test.go.tpl index ecb9f6e..f8d2301 100644 --- a/pkg/ast/model/table_test.go.tpl +++ b/pkg/ast/model/table_test.go.tpl @@ -1,4 +1,4 @@ -package models +package model import ( "context"