diff --git a/cmd_root.go b/cmd_root.go index e9c803e..ac502fd 100644 --- a/cmd_root.go +++ b/cmd_root.go @@ -1,8 +1,6 @@ package atom import ( - "fmt" - "git.ipao.vip/rogeecn/atom/config" "git.ipao.vip/rogeecn/atom/container" "github.com/pkg/errors" @@ -10,8 +8,6 @@ import ( "go.uber.org/dig" ) -var cfgFile string - var ( GroupInitialName = "initials" GroupRoutesName = "routes" @@ -26,7 +22,7 @@ var ( GroupQueue = dig.Group(GroupQueueName) ) -func Serve(providers container.Providers, opts ...Option) error { +func Serve(opts ...Option) error { rootCmd := &cobra.Command{Use: "app"} for _, opt := range opts { opt(rootCmd) @@ -40,21 +36,15 @@ func Serve(providers container.Providers, opts ...Option) error { return err }) - defaultCfgFile := fmt.Sprintf(".%s.toml", rootCmd.Use) - rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file path, lookup in dir: $HOME, $PWD, /etc, /usr/local/etc, filename: "+defaultCfgFile) - - rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { - return LoadProviders(cfgFile, rootCmd.Use, providers) - } + rootCmd.PersistentFlags().StringP("config", "c", "config.toml", "config file") return rootCmd.Execute() } -func LoadProviders(cfgFile, appName string, providers container.Providers) error { - // parse config files - configure, err := config.Load(cfgFile, appName) +func LoadProviders(configFile string, providers container.Providers) error { + configure, err := config.Load(configFile) if err != nil { - return errors.Wrapf(err, "load config file: %s", cfgFile) + return errors.Wrapf(err, "load config file: %s", configFile) } if err := providers.Provide(configure); err != nil { @@ -70,6 +60,24 @@ var ( AppVersion string ) +func Providers(providers container.Providers) Option { + return func(cmd *cobra.Command) { + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { + return LoadProviders(cmd.Flag("config").Value.String(), providers) + } + } +} + +func Command(opt ...Option) Option { + return func(parentCmd *cobra.Command) { + cmd := &cobra.Command{} + for _, o := range opt { + o(cmd) + } + parentCmd.AddCommand(cmd) + } +} + func Version(ver string) Option { return func(cmd *cobra.Command) { cmd.Version = ver diff --git a/config/config.go b/config/config.go index f4c432c..ca7c190 100644 --- a/config/config.go +++ b/config/config.go @@ -2,7 +2,6 @@ package config import ( "log" - "os" "path/filepath" "github.com/pkg/errors" @@ -10,38 +9,23 @@ import ( "github.com/spf13/viper" ) -func Load(file, app string) (*viper.Viper, error) { +func Load(file string) (*viper.Viper, error) { v := viper.NewWithOptions(viper.KeyDelimiter("_")) v.AutomaticEnv() - if file == "" { + ext := filepath.Ext(file) + if ext == "" { v.SetConfigType("toml") - v.SetConfigName(app + ".toml") - - paths := []string{"."} - // execute path - execPath, err := os.Executable() - if err == nil { - paths = append(paths, filepath.Dir(execPath)) - } - - // home path - homePath, err := os.UserHomeDir() - if err == nil { - paths = append(paths, homePath, homePath+"/"+app, homePath+"/.config", homePath+"/.config/"+app) - } - paths = append(paths, "/etc", "/etc/"+app, "/usr/local/etc", "/usr/local/etc/"+app) - - log.Println("try load config from paths:", paths) - for _, path := range paths { - v.AddConfigPath(path) - } + v.SetConfigFile(file) } else { + v.SetConfigType(ext[1:]) v.SetConfigFile(file) } + v.AddConfigPath(".") + err := v.ReadInConfig() - log.Println("use config file:", v.ConfigFileUsed()) + log.Println("config file:", v.ConfigFileUsed()) if err != nil { return nil, errors.Wrap(err, "config file read error") }