From 34c05d52a29c8c152f49c2553f26964ffc5cb68a Mon Sep 17 00:00:00 2001 From: rogeecn Date: Sat, 22 Mar 2025 19:27:42 +0800 Subject: [PATCH] fix: model --- pkg/ast/model/generage.go | 32 +++++++++++++++++++++--- pkg/ast/model/table.go.tpl | 5 ++++ pkg/ast/model/table_test.go.tpl | 43 +++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 pkg/ast/model/table_test.go.tpl diff --git a/pkg/ast/model/generage.go b/pkg/ast/model/generage.go index ab955a8..a49de52 100644 --- a/pkg/ast/model/generage.go +++ b/pkg/ast/model/generage.go @@ -10,15 +10,20 @@ import ( "strings" "github.com/samber/lo" + "go.ipao.vip/atomctl/pkg/utils/gomod" ) //go:embed table.go.tpl var tableTpl string +//go:embed table_test.go.tpl +var tableTestTpl string + //go:embed models.gen.go.tpl var modelTpl string type TableModelParam struct { + PkgName string CamelTable string // user PascalTable string // User } @@ -27,6 +32,7 @@ func Generate(tables []string, transformer Transformer) error { baseDir := "app/models" tableTpl := template.Must(template.New("model").Parse(string(tableTpl))) + tableTestTpl := template.Must(template.New("model").Parse(string(tableTestTpl))) modelTpl := template.Must(template.New("modelGen").Parse(string(modelTpl))) items := []TableModelParam{} @@ -36,10 +42,12 @@ func Generate(tables []string, transformer Transformer) error { continue } - items = append(items, TableModelParam{ + tableInfo := TableModelParam{ CamelTable: lo.CamelCase(table), PascalTable: lo.PascalCase(table), - }) + PkgName: gomod.GetModuleName(), + } + items = append(items, tableInfo) modelFile := fmt.Sprintf("%s/%s.go", baseDir, table) // 如果 modelFile 已存在,则跳过 @@ -55,9 +63,27 @@ func Generate(tables []string, transformer Transformer) error { } defer fd.Close() - if err := tableTpl.Execute(fd, map[string]string{"CamelTable": lo.CamelCase(table)}); err != nil { + if err := tableTpl.Execute(fd, tableInfo); err != nil { return fmt.Errorf("failed to render model template: %w", err) } + + modelTestFile := fmt.Sprintf("%s/%s_test.go", baseDir, table) + // 如果 modelTestFile 已存在,则跳过 + if _, err := os.Stat(modelTestFile); err == nil { + fmt.Printf("Model test file %s already exists. Skipping...\n", modelTestFile) + continue + } + + // 如果 modelTestFile 不存在,则创建 + fd, err = os.Create(modelTestFile) + if err != nil { + return fmt.Errorf("failed to create model test file %s: %w", modelTestFile, err) + } + defer fd.Close() + + if err := tableTestTpl.Execute(fd, tableInfo); err != nil { + return fmt.Errorf("failed to render model test template: %w", err) + } } // 遍历 baseDir 下的所有文件,将不在 tables 中的文件名(不带扩展名)加入 diff --git a/pkg/ast/model/table.go.tpl b/pkg/ast/model/table.go.tpl index 637fb8d..4435bf7 100644 --- a/pkg/ast/model/table.go.tpl +++ b/pkg/ast/model/table.go.tpl @@ -1,9 +1,14 @@ package models +import ( + "github.com/sirupsen/logrus" +) // @provider type {{.CamelTable}}Model struct { + log *logrus.Entry `inject:"false"` } func (m *{{.CamelTable}}Model) Prepare() error { + m.log = logrus.WithField("model", "{{.CamelTable}}Model") return nil } \ No newline at end of file diff --git a/pkg/ast/model/table_test.go.tpl b/pkg/ast/model/table_test.go.tpl new file mode 100644 index 0000000..ecb9f6e --- /dev/null +++ b/pkg/ast/model/table_test.go.tpl @@ -0,0 +1,43 @@ +package models + +import ( + "context" + "testing" + + "{{ .PkgName }}/app/service/testx" + "{{ .PkgName }}/database" + "{{ .PkgName }}/database/schemas/public/table" + + . "github.com/smartystreets/goconvey/convey" + "go.ipao.vip/atom/contracts" + + // . "github.com/go-jet/jet/v2/postgres" + "github.com/stretchr/testify/suite" + "go.uber.org/dig" +) + +type {{ .PascalTable }}InjectParams struct { + dig.In + Initials []contracts.Initial `group:"initials"` +} + +type {{ .PascalTable }}TestSuite struct { + suite.Suite + + {{ .PascalTable }}InjectParams +} + +func Test_{{ .PascalTable }}(t *testing.T) { + providers := testx.Default().With(Provide) + testx.Serve(providers, t, func(params {{ .PascalTable }}InjectParams) { + suite.Run(t, &{{ .PascalTable }}TestSuite{ + {{ .PascalTable }}InjectParams: params, + }) + }) +} + +func (s *{{ .PascalTable }}TestSuite) Test_Demo() { + Convey("Test_Demo", s.T(), func() { + database.Truncate(context.Background(), db, table.{{ .PascalTable }}.TableName()) + }) +}