Some checks failed
CI/CD Pipeline / Test (push) Failing after 22m19s
CI/CD Pipeline / Security Scan (push) Failing after 5m57s
CI/CD Pipeline / Build (amd64, darwin) (push) Has been skipped
CI/CD Pipeline / Build (amd64, linux) (push) Has been skipped
CI/CD Pipeline / Build (amd64, windows) (push) Has been skipped
CI/CD Pipeline / Build (arm64, darwin) (push) Has been skipped
CI/CD Pipeline / Build (arm64, linux) (push) Has been skipped
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Create Release (push) Has been skipped
468 lines
12 KiB
Go
468 lines
12 KiB
Go
package config
|
||
|
||
import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"strings"

	"github.com/fsnotify/fsnotify"
	"github.com/spf13/viper"
	"github.com/subconverter-go/internal/logging"
)
|
||
|
||
// ConfigManager 配置管理器
|
||
// 封装viper功能,提供统一的配置管理接口
|
||
type ConfigManager struct {
|
||
viper *viper.Viper
|
||
config *Configuration
|
||
logger *logging.Logger
|
||
filePath string
|
||
basePath string
|
||
templateManager *TemplateManager
|
||
}
|
||
|
||
// NewConfigManager 创建新的配置管理器
|
||
// 返回初始化好的ConfigManager实例
|
||
func NewConfigManager(configPath string) (*ConfigManager, error) {
|
||
logger, err := logging.NewDefaultLogger()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create logger: %v", err)
|
||
}
|
||
|
||
// 设置base路径
|
||
basePath := "./base"
|
||
if _, err := os.Stat(basePath); os.IsNotExist(err) {
|
||
basePath = "../base"
|
||
}
|
||
|
||
manager := &ConfigManager{
|
||
viper: viper.New(),
|
||
logger: logger,
|
||
basePath: basePath,
|
||
}
|
||
|
||
// 设置默认配置
|
||
manager.setDefaults()
|
||
|
||
// 配置viper
|
||
manager.viper.SetConfigName("config")
|
||
manager.viper.SetConfigType("yaml")
|
||
|
||
// 设置配置文件路径
|
||
if configPath != "" {
|
||
manager.viper.SetConfigFile(configPath)
|
||
manager.filePath = configPath
|
||
} else {
|
||
// 默认配置文件路径
|
||
manager.viper.AddConfigPath(".")
|
||
manager.viper.AddConfigPath("./config")
|
||
manager.viper.AddConfigPath("./configs")
|
||
manager.viper.AddConfigPath("/etc/subconverter-go")
|
||
}
|
||
|
||
// 读取配置文件
|
||
if err := manager.viper.ReadInConfig(); err != nil {
|
||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||
// 配置文件不存在,使用默认配置
|
||
logger.Info("Configuration file not found, using defaults")
|
||
} else {
|
||
return nil, fmt.Errorf("failed to read config file: %v", err)
|
||
}
|
||
} else {
|
||
logger.Infof("Configuration loaded from: %s", manager.viper.ConfigFileUsed())
|
||
manager.filePath = manager.viper.ConfigFileUsed()
|
||
}
|
||
|
||
// 绑定环境变量
|
||
manager.bindEnvironmentVariables()
|
||
|
||
// 解析配置到结构体
|
||
if err := manager.parseConfig(); err != nil {
|
||
return nil, fmt.Errorf("failed to parse config: %v", err)
|
||
}
|
||
|
||
// 初始化模板管理器
|
||
templateManager, err := NewTemplateManager(basePath, logger)
|
||
if err != nil {
|
||
logger.Warnf("Failed to initialize template manager: %v, template features will be disabled", err)
|
||
} else {
|
||
manager.templateManager = templateManager
|
||
logger.Info("Template manager initialized successfully")
|
||
}
|
||
|
||
return manager, nil
|
||
}
|
||
|
||
// setDefaults 设置默认配置值
|
||
func (cm *ConfigManager) setDefaults() {
|
||
// 服务器默认配置
|
||
cm.viper.SetDefault("server.host", "0.0.0.0")
|
||
cm.viper.SetDefault("server.port", 25500)
|
||
cm.viper.SetDefault("server.read_timeout", 30)
|
||
cm.viper.SetDefault("server.write_timeout", 30)
|
||
cm.viper.SetDefault("server.max_request_size", 10*1024*1024) // 10MB
|
||
|
||
// 日志默认配置
|
||
cm.viper.SetDefault("logging.level", "info")
|
||
cm.viper.SetDefault("logging.format", "json")
|
||
cm.viper.SetDefault("logging.output", "stdout")
|
||
cm.viper.SetDefault("logging.file", "")
|
||
|
||
// 安全默认配置
|
||
cm.viper.SetDefault("security.access_tokens", []string{})
|
||
cm.viper.SetDefault("security.cors_origins", []string{"*"})
|
||
cm.viper.SetDefault("security.rate_limit", 0)
|
||
cm.viper.SetDefault("security.timeout", 60)
|
||
|
||
// 转换默认配置
|
||
cm.viper.SetDefault("conversion.default_target", "clash")
|
||
cm.viper.SetDefault("conversion.default_emoji", false)
|
||
cm.viper.SetDefault("conversion.default_udp", false)
|
||
cm.viper.SetDefault("conversion.max_nodes", 0)
|
||
cm.viper.SetDefault("conversion.cache_timeout", 60)
|
||
cm.viper.SetDefault("conversion.supported_targets", []string{
|
||
"clash", "surge", "quanx", "loon", "surfboard", "v2ray",
|
||
})
|
||
}
|
||
|
||
// bindEnvironmentVariables 绑定环境变量
|
||
func (cm *ConfigManager) bindEnvironmentVariables() {
|
||
// 服务器配置
|
||
cm.viper.BindEnv("server.host", "SUBCONVERTER_HOST")
|
||
cm.viper.BindEnv("server.port", "SUBCONVERTER_PORT")
|
||
cm.viper.BindEnv("server.read_timeout", "SUBCONVERTER_READ_TIMEOUT")
|
||
cm.viper.BindEnv("server.write_timeout", "SUBCONVERTER_WRITE_TIMEOUT")
|
||
cm.viper.BindEnv("server.max_request_size", "SUBCONVERTER_MAX_REQUEST_SIZE")
|
||
|
||
// 日志配置
|
||
cm.viper.BindEnv("logging.level", "SUBCONVERTER_LOG_LEVEL")
|
||
cm.viper.BindEnv("logging.format", "SUBCONVERTER_LOG_FORMAT")
|
||
cm.viper.BindEnv("logging.output", "SUBCONVERTER_LOG_OUTPUT")
|
||
cm.viper.BindEnv("logging.file", "SUBCONVERTER_LOG_FILE")
|
||
|
||
// 安全配置
|
||
cm.viper.BindEnv("security.access_tokens", "SUBCONVERTER_ACCESS_TOKENS")
|
||
cm.viper.BindEnv("security.cors_origins", "SUBCONVERTER_CORS_ORIGINS")
|
||
cm.viper.BindEnv("security.rate_limit", "SUBCONVERTER_RATE_LIMIT")
|
||
cm.viper.BindEnv("security.timeout", "SUBCONVERTER_TIMEOUT")
|
||
|
||
// 转换配置
|
||
cm.viper.BindEnv("conversion.default_target", "SUBCONVERTER_DEFAULT_TARGET")
|
||
cm.viper.BindEnv("conversion.default_emoji", "SUBCONVERTER_DEFAULT_EMOJI")
|
||
cm.viper.BindEnv("conversion.default_udp", "SUBCONVERTER_DEFAULT_UDP")
|
||
cm.viper.BindEnv("conversion.max_nodes", "SUBCONVERTER_MAX_NODES")
|
||
cm.viper.BindEnv("conversion.cache_timeout", "SUBCONVERTER_CACHE_TIMEOUT")
|
||
}
|
||
|
||
// parseConfig 解析配置到结构体
|
||
func (cm *ConfigManager) parseConfig() error {
|
||
cm.config = &Configuration{}
|
||
|
||
// 手动解析配置,因为viper的直接映射可能不适用复杂结构
|
||
if err := cm.viper.Unmarshal(cm.config); err != nil {
|
||
return fmt.Errorf("failed to unmarshal config: %v", err)
|
||
}
|
||
|
||
// 后处理配置
|
||
cm.postProcessConfig()
|
||
|
||
return nil
|
||
}
|
||
|
||
// postProcessConfig 后处理配置
|
||
func (cm *ConfigManager) postProcessConfig() {
|
||
// 处理安全配置
|
||
if cm.config.Security.CorsOrigins == nil {
|
||
cm.config.Security.CorsOrigins = []string{"*"}
|
||
}
|
||
|
||
// 处理转换配置
|
||
if cm.config.Conversion.SupportedTargets == nil {
|
||
cm.config.Conversion.SupportedTargets = []string{
|
||
"clash", "clashr", "surge", "quanx", "loon", "surfboard", "v2ray",
|
||
}
|
||
}
|
||
}
|
||
|
||
// GetConfig 获取配置
|
||
// 返回当前配置的副本
|
||
func (cm *ConfigManager) GetConfig() *Configuration {
|
||
if cm.config == nil {
|
||
return nil
|
||
}
|
||
return cm.config.Clone()
|
||
}
|
||
|
||
// GetServerConfig 获取服务器配置
|
||
func (cm *ConfigManager) GetServerConfig() *ServerConfig {
|
||
if cm.config == nil {
|
||
return nil
|
||
}
|
||
clone := cm.config.Server.Clone()
|
||
return &clone
|
||
}
|
||
|
||
// GetLoggingConfig 获取日志配置
|
||
func (cm *ConfigManager) GetLoggingConfig() *LoggingConfig {
|
||
if cm.config == nil {
|
||
return nil
|
||
}
|
||
clone := cm.config.Logging.Clone()
|
||
return &clone
|
||
}
|
||
|
||
// GetSecurityConfig 获取安全配置
|
||
func (cm *ConfigManager) GetSecurityConfig() *SecurityConfig {
|
||
if cm.config == nil {
|
||
return nil
|
||
}
|
||
clone := cm.config.Security.Clone()
|
||
return &clone
|
||
}
|
||
|
||
// GetConversionConfig 获取转换配置
|
||
func (cm *ConfigManager) GetConversionConfig() *ConversionConfig {
|
||
if cm.config == nil {
|
||
return nil
|
||
}
|
||
clone := cm.config.Conversion.Clone()
|
||
return &clone
|
||
}
|
||
|
||
// UpdateConfig 更新配置
|
||
func (cm *ConfigManager) UpdateConfig(config *Configuration) error {
|
||
if config == nil {
|
||
return fmt.Errorf("config cannot be nil")
|
||
}
|
||
|
||
// 验证配置
|
||
if err := config.Validate(); err != nil {
|
||
return fmt.Errorf("invalid config: %v", err)
|
||
}
|
||
|
||
// 更新配置
|
||
cm.config = config.Clone()
|
||
|
||
// 同步到viper
|
||
cm.syncToViper()
|
||
|
||
// 保存到文件(如果配置了文件路径)
|
||
if cm.filePath != "" {
|
||
if err := cm.SaveConfig(); err != nil {
|
||
cm.logger.WithError(err).Warn("Failed to save config to file")
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// syncToViper 将配置同步到viper
|
||
func (cm *ConfigManager) syncToViper() {
|
||
// 使用反射同步配置
|
||
v := reflect.ValueOf(cm.config).Elem()
|
||
t := v.Type()
|
||
|
||
for i := 0; i < v.NumField(); i++ {
|
||
field := v.Field(i)
|
||
fieldType := t.Field(i)
|
||
|
||
// 获取字段名(考虑yaml标签)
|
||
fieldName := getFieldName(fieldType)
|
||
|
||
// 递归设置嵌套结构
|
||
cm.setViperValue(cm.viper, fieldName, field)
|
||
}
|
||
}
|
||
|
||
// setViperValue 递归设置viper值
|
||
func (cm *ConfigManager) setViperValue(v *viper.Viper, key string, value reflect.Value) {
|
||
switch value.Kind() {
|
||
case reflect.Struct:
|
||
// 处理嵌套结构
|
||
for i := 0; i < value.NumField(); i++ {
|
||
field := value.Field(i)
|
||
fieldType := value.Type().Field(i)
|
||
subKey := key + "." + getFieldName(fieldType)
|
||
cm.setViperValue(v, subKey, field)
|
||
}
|
||
case reflect.Slice, reflect.Array:
|
||
// 处理切片和数组
|
||
if value.Len() > 0 {
|
||
slice := make([]interface{}, value.Len())
|
||
for i := 0; i < value.Len(); i++ {
|
||
slice[i] = value.Index(i).Interface()
|
||
}
|
||
v.Set(key, slice)
|
||
}
|
||
case reflect.Map:
|
||
// 处理map
|
||
mapValue := make(map[string]interface{})
|
||
for _, key := range value.MapKeys() {
|
||
mapValue[key.String()] = value.MapIndex(key).Interface()
|
||
}
|
||
v.Set(key, mapValue)
|
||
default:
|
||
// 处理基本类型
|
||
v.Set(key, value.Interface())
|
||
}
|
||
}
|
||
|
||
// getFieldName returns the viper key segment used for a struct field.
//
// Tags are checked in order: yaml, then json. A tag is used only when it is
// present, is not "-", and has a non-empty name before any options. So
// `yaml:",omitempty"` falls through instead of producing an empty key
// segment such as "server.". Without a usable tag, the lower-cased Go field
// name is returned. The result is therefore never empty.
func getFieldName(field reflect.StructField) string {
	for _, tagKey := range []string{"yaml", "json"} {
		tag := field.Tag.Get(tagKey)
		if tag == "" || tag == "-" {
			continue
		}
		if name := strings.Split(tag, ",")[0]; name != "" {
			return name
		}
	}

	return strings.ToLower(field.Name)
}
|
||
|
||
// SaveConfig 保存配置到文件
|
||
func (cm *ConfigManager) SaveConfig() error {
|
||
if cm.filePath == "" {
|
||
return fmt.Errorf("no config file path configured")
|
||
}
|
||
|
||
// 确保目录存在
|
||
dir := filepath.Dir(cm.filePath)
|
||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||
return fmt.Errorf("failed to create config directory: %v", err)
|
||
}
|
||
|
||
// 设置配置文件路径
|
||
cm.viper.SetConfigFile(cm.filePath)
|
||
|
||
// 保存配置
|
||
if err := cm.viper.WriteConfig(); err != nil {
|
||
return fmt.Errorf("failed to write config: %v", err)
|
||
}
|
||
|
||
cm.logger.Infof("Configuration saved to: %s", cm.filePath)
|
||
return nil
|
||
}
|
||
|
||
// ReloadConfig 从磁盘重新加载配置
|
||
func (cm *ConfigManager) ReloadConfig() error {
|
||
if cm.filePath == "" {
|
||
return fmt.Errorf("no config file path configured")
|
||
}
|
||
|
||
cm.logger.Infof("Reloading configuration from: %s", cm.filePath)
|
||
cm.viper.SetConfigFile(cm.filePath)
|
||
if err := cm.viper.ReadInConfig(); err != nil {
|
||
return fmt.Errorf("failed to read config file: %v", err)
|
||
}
|
||
|
||
if err := cm.parseConfig(); err != nil {
|
||
return fmt.Errorf("failed to parse config: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// WatchConfig 监听配置文件变化
|
||
func (cm *ConfigManager) WatchConfig(callback func(*Configuration)) {
|
||
cm.viper.WatchConfig()
|
||
|
||
cm.viper.OnConfigChange(func(e fsnotify.Event) {
|
||
cm.logger.Infof("Configuration file changed: %s", e.Name)
|
||
|
||
// 重新解析配置
|
||
if err := cm.parseConfig(); err != nil {
|
||
cm.logger.WithError(err).Error("Failed to parse changed config")
|
||
return
|
||
}
|
||
|
||
// 调用回调函数
|
||
if callback != nil {
|
||
callback(cm.config)
|
||
}
|
||
})
|
||
}
|
||
|
||
// GetString 获取字符串配置值
|
||
func (cm *ConfigManager) GetString(key string) string {
|
||
return cm.viper.GetString(key)
|
||
}
|
||
|
||
// GetInt 获取整型配置值
|
||
func (cm *ConfigManager) GetInt(key string) int {
|
||
return cm.viper.GetInt(key)
|
||
}
|
||
|
||
// GetBool 获取布尔配置值
|
||
func (cm *ConfigManager) GetBool(key string) bool {
|
||
return cm.viper.GetBool(key)
|
||
}
|
||
|
||
// GetStringSlice 获取字符串切片配置值
|
||
func (cm *ConfigManager) GetStringSlice(key string) []string {
|
||
return cm.viper.GetStringSlice(key)
|
||
}
|
||
|
||
// IsSet 检查配置是否设置
|
||
func (cm *ConfigManager) IsSet(key string) bool {
|
||
return cm.viper.IsSet(key)
|
||
}
|
||
|
||
// Set 设置配置值
|
||
func (cm *ConfigManager) Set(key string, value interface{}) {
|
||
cm.viper.Set(key, value)
|
||
}
|
||
|
||
// GetFilePath 获取配置文件路径
|
||
func (cm *ConfigManager) GetFilePath() string {
|
||
return cm.filePath
|
||
}
|
||
|
||
// Close 关闭配置管理器
|
||
func (cm *ConfigManager) Close() error {
|
||
// 停止监听配置变化
|
||
cm.viper.OnConfigChange(nil)
|
||
|
||
// 保存配置(如果需要)
|
||
if cm.filePath != "" {
|
||
if err := cm.SaveConfig(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetTemplateManager 获取模板管理器
|
||
func (cm *ConfigManager) GetTemplateManager() *TemplateManager {
|
||
return cm.templateManager
|
||
}
|
||
|
||
// GetBasePath 获取base路径
|
||
func (cm *ConfigManager) GetBasePath() string {
|
||
return cm.basePath
|
||
}
|
||
|
||
// HasTemplateManager 检查是否有模板管理器
|
||
func (cm *ConfigManager) HasTemplateManager() bool {
|
||
return cm.templateManager != nil
|
||
}
|
||
|
||
// ReloadTemplates 重新加载模板
|
||
func (cm *ConfigManager) ReloadTemplates() error {
|
||
if cm.templateManager == nil {
|
||
return fmt.Errorf("template manager not initialized")
|
||
}
|
||
|
||
return cm.templateManager.Reload()
|
||
}
|