From 2ec0c73ba3afb0b5847fe52be38ea339dfb6f54c Mon Sep 17 00:00:00 2001 From: Rogee Date: Wed, 25 Dec 2024 18:15:27 +0800 Subject: [PATCH] fix: model gen issues --- cmd/gen_model.go | 12 +++++- templates/project/database/database.go.tpl | 36 +++++++++++++++++ .../project/database/fields/common.go.tpl | 32 +++++++++++++++ templates/project/pkg/pg/db.go.tpl | 40 ------------------- 4 files changed, 79 insertions(+), 41 deletions(-) create mode 100644 templates/project/database/fields/common.go.tpl delete mode 100644 templates/project/pkg/pg/db.go.tpl diff --git a/cmd/gen_model.go b/cmd/gen_model.go index a6fde0c..d20c7bc 100644 --- a/cmd/gen_model.go +++ b/cmd/gen_model.go @@ -5,6 +5,7 @@ import ( "strings" pgDatabase "git.ipao.vip/rogeecn/atomctl/pkg/postgres" + "git.ipao.vip/rogeecn/atomctl/pkg/utils/gomod" "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/template" @@ -29,6 +30,10 @@ func CommandGenModel(root *cobra.Command) { } func commandGenModelE(cmd *cobra.Command, args []string) error { + if err := gomod.Parse("go.mod"); err != nil { + return errors.Wrap(err, "parse go.mod") + } + _, dbConf, err := pgDatabase.GetDB(cmd.Flag("config").Value.String()) if err != nil { return errors.Wrap(err, "get db") @@ -91,13 +96,18 @@ func commandGenModelE(cmd *cobra.Command, args []string) error { splits := strings.Split(toType, ".") typeName := splits[len(splits)-1] + pkg := splits[0] + if strings.HasPrefix(pkg, "/") { + pkg = gomod.GetModuleName() + pkg + } + pkgSplits := strings.Split(splits[0], "/") typePkg := pkgSplits[len(pkgSplits)-1] defaultTableModelField = defaultTableModelField. UseType(template.Type{ Name: fmt.Sprintf("%s.%s", typePkg, typeName), - ImportPath: splits[0], + ImportPath: pkg, }) log.Infof("Convert table %s field %s type to : %s", table.Name, column.Name, toType) diff --git a/templates/project/database/database.go.tpl b/templates/project/database/database.go.tpl index 77c3404..b42098e 100644 --- a/templates/project/database/database.go.tpl +++ b/templates/project/database/database.go.tpl @@ -1,8 +1,44 @@ package database import ( + "context" + "database/sql" "embed" + "fmt" + + "github.com/go-jet/jet/v2/qrm" ) //go:embed migrations/* var MigrationFS embed.FS + +type CtxDB struct{} + +func FromContext(ctx context.Context, db *sql.DB) qrm.DB { + if tx, ok := ctx.Value(CtxDB{}).(*sql.Tx); ok { + return tx + } + return db +} + +func TruncateAllTables(ctx context.Context, db *sql.DB, tableName ...string) error { + for _, name := range tableName { + sql := fmt.Sprintf("TRUNCATE TABLE %s RESTART IDENTITY", name) + if _, err := db.ExecContext(ctx, sql); err != nil { + return err + } + } + return nil +} + +func WrapLike(v string) string { + return "%" + v + "%" +} + +func WrapLikeLeft(v string) string { + return "%" + v +} + +func WrapLikeRight(v string) string { + return "%" + v +} diff --git a/templates/project/database/fields/common.go.tpl b/templates/project/database/fields/common.go.tpl new file mode 100644 index 0000000..e790807 --- /dev/null +++ b/templates/project/database/fields/common.go.tpl @@ -0,0 +1,32 @@ +package fields + +import ( + "database/sql/driver" + "encoding/json" + "errors" + + "github.com/samber/lo" +) + +// implement sql.Scanner interface +type field struct{} + +func (x *field) Scan(value interface{}) (err error) { + switch v := value.(type) { + case string: + return json.Unmarshal([]byte(v), &x) + case []byte: + return json.Unmarshal(v, &x) + case *string: + return json.Unmarshal([]byte(*v), &x) + } + return errors.New("Unknown type for ") +} + +func (x field) Value() (driver.Value, error) { + return json.Marshal(x) +} + +func (x field) MustValue() driver.Value { + return lo.Must(json.Marshal(x)) +} diff --git a/templates/project/pkg/pg/db.go.tpl b/templates/project/pkg/pg/db.go.tpl deleted file mode 100644 index 8bda8af..0000000 --- a/templates/project/pkg/pg/db.go.tpl +++ /dev/null @@ -1,40 +0,0 @@ -package db - -import ( - "context" - "database/sql" - "fmt" - - "github.com/go-jet/jet/v2/qrm" -) - -const CtxDB = "__db__tx:" - -func FromContext(ctx context.Context, db *sql.DB) qrm.DB { - if tx, ok := ctx.Value(CtxDB).(*sql.Tx); ok { - return tx - } - return db -} - -func TruncateAllTables(ctx context.Context, db *sql.DB, tableName ...string) error { - for _, name := range tableName { - sql := fmt.Sprintf("TRUNCATE TABLE %s RESTART IDENTITY", name) - if _, err := db.ExecContext(ctx, sql); err != nil { - return err - } - } - return nil -} - -func WrapLike(v string) string { - return "%" + v + "%" -} - -func WrapLikeLeft(v string) string { - return "%" + v -} - -func WrapLikeRight(v string) string { - return "%" + v -}