diff --git a/cmd/gen_model.go b/cmd/gen_model.go index c9aa2e8..e54d23f 100644 --- a/cmd/gen_model.go +++ b/cmd/gen_model.go @@ -1,16 +1,14 @@ package cmd import ( - "fmt" - - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "github.com/spf13/viper" - "go.ipao.vip/atomctl/v2/pkg/utils/gomod" - "go.ipao.vip/gen" - "gorm.io/driver/postgres" - "gorm.io/gorm" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + apg "go.ipao.vip/atomctl/v2/pkg/postgres" + "go.ipao.vip/atomctl/v2/pkg/utils/gomod" + "go.ipao.vip/gen" + "gorm.io/driver/postgres" + "gorm.io/gorm" ) func CommandGenModel(root *cobra.Command) { @@ -38,83 +36,31 @@ func CommandGenModel(root *cobra.Command) { } func commandGenModelE(cmd *cobra.Command, args []string) error { - if err := gomod.Parse("go.mod"); err != nil { - return errors.Wrap(err, "parse go.mod") - } + if err := gomod.Parse("go.mod"); err != nil { + return errors.Wrap(err, "parse go.mod") + } - cfgFile := cmd.Flag("config").Value.String() - if cfgFile == "" { - cfgFile = "config.toml" - } + cfgFile := cmd.Flag("config").Value.String() + if cfgFile == "" { + cfgFile = "config.toml" + } - v := viper.New() - v.SetConfigType("toml") - v.SetConfigFile(cfgFile) - v.AddConfigPath(".") - if err := v.ReadInConfig(); err != nil { - return errors.Wrap(err, "read config") - } + sqlDB, conf, err := apg.GetDB(cfgFile) + if err != nil { + return errors.Wrap(err, "load database config") + } + defer sqlDB.Close() - var dbc modelDBConfig - if err := v.UnmarshalKey("Database", &dbc); err != nil { - return errors.Wrap(err, "unmarshal Database config") - } - dsn := dbc.DSN() + dsn := conf.DSN() + log.Infof("parsed DSN: %s (schema=%s)", dsn, conf.Schema) - log.Infof("parsed DSN: %s (schema=%s)", dsn, dbc.Schema) - - db, err := gorm.Open(postgres.New(postgres.Config{DSN: dsn})) - if err != nil { - return errors.Wrapf(err, "open database with dsn: %s", dsn) - } + db, err := gorm.Open(postgres.New(postgres.Config{DSN: dsn})) + if err != nil { + return errors.Wrapf(err, "open database with dsn: %s", dsn) + } // 默认同包同目录生成到 ./database gen.GenerateWithDefault(db, "./database/.transform.yaml") - return nil -} - -// local config and helpers - -type modelDBConfig struct { - Username string - Password string - Database string - Schema string - Host string - Port uint - SslMode string - TimeZone string -} - -func (c *modelDBConfig) applyDefaults() { - if c.Username == "" { - c.Username = "postgres" - } - if c.SslMode == "" { - c.SslMode = "disable" - } - if c.TimeZone == "" { - c.TimeZone = "Asia/Shanghai" - } - if c.Port == 0 { - c.Port = 5432 - } - if c.Schema == "" { - c.Schema = "public" - } -} - -func (c *modelDBConfig) DSN() string { - c.applyDefaults() - return fmt.Sprintf( - "host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=%s", - c.Host, - c.Username, - c.Password, - c.Database, - c.Port, - c.SslMode, - c.TimeZone, - ) + return nil }