This commit is contained in:
2025-11-14 12:11:44 +08:00
commit 39ebf61572
88 changed files with 9999 additions and 0 deletions

View File

@@ -0,0 +1,109 @@
package config
import (
"testing"
"time"
)
func TestLoadWithDefaults(t *testing.T) {
cfgPath := testConfigPath(t, "valid.toml")
cfg, err := Load(cfgPath)
if err != nil {
t.Fatalf("Load 返回错误: %v", err)
}
if cfg.Global.CacheTTL.DurationValue() == 0 {
t.Fatalf("CacheTTL 应该自动填充默认值")
}
if cfg.Global.StoragePath == "" {
t.Fatalf("StoragePath 应该被保留")
}
if cfg.Global.ListenPort == 0 {
t.Fatalf("ListenPort 应当被解析")
}
if cfg.EffectiveCacheTTL(cfg.Hubs[0]) != cfg.Global.CacheTTL.DurationValue() {
t.Fatalf("Hub 未设置 TTL 时应退回全局 TTL")
}
}
func TestValidateRejectsBadHub(t *testing.T) {
cfgPath := testConfigPath(t, "missing.toml")
if _, err := Load(cfgPath); err == nil {
t.Fatalf("不合法的配置应返回错误")
}
}
func TestEffectiveCacheTTLOverrides(t *testing.T) {
cfg := &Config{Global: GlobalConfig{CacheTTL: Duration(time.Hour)}}
hub := HubConfig{CacheTTL: Duration(2 * time.Hour)}
if ttl := cfg.EffectiveCacheTTL(hub); ttl != 2*time.Hour {
t.Fatalf("覆盖 TTL 应该优先生效")
}
}
func TestValidateEnforcesListenPortRange(t *testing.T) {
cfg := validConfig()
cfg.Global.ListenPort = 70000
if err := cfg.Validate(); err == nil {
t.Fatalf("ListenPort 超出范围应当报错")
}
}
func TestHubTypeValidation(t *testing.T) {
testCases := []struct {
name string
hubType string
shouldErr bool
}{
{"docker ok", "docker", false},
{"npm ok", "npm", false},
{"go ok", "go", false},
{"missing type", "", true},
{"unsupported type", "rubygems", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cfg := validConfig()
cfg.Hubs[0].Type = tc.hubType
err := cfg.Validate()
if tc.shouldErr && err == nil {
t.Fatalf("expected error for type %q", tc.hubType)
}
if !tc.shouldErr && err != nil {
t.Fatalf("unexpected error for type %q: %v", tc.hubType, err)
}
})
}
}
func TestValidateRequiresCredentialPairs(t *testing.T) {
cfg := validConfig()
cfg.Hubs[0].Username = "foo"
if err := cfg.Validate(); err == nil {
t.Fatalf("仅提供 Username 时应报错")
}
}
func validConfig() *Config {
return &Config{
Global: GlobalConfig{
ListenPort: 5000,
StoragePath: "./data",
CacheTTL: Duration(time.Hour),
MaxMemoryCache: 1,
MaxRetries: 1,
InitialBackoff: Duration(time.Second),
UpstreamTimeout: Duration(time.Second),
},
Hubs: []HubConfig{
{
Name: "npm",
Domain: "npm.local",
Type: "npm",
Upstream: "https://registry.npmjs.org",
},
},
}
}

26
internal/config/errors.go Normal file
View File

@@ -0,0 +1,26 @@
package config
import "fmt"
// FieldError 提供字段路径与错误原因,便于 CLI 向用户反馈。
type FieldError struct {
Field string
Reason string
}
func (e FieldError) Error() string {
return fmt.Sprintf("%s: %s", e.Field, e.Reason)
}
// newFieldError 创建包含字段路径与原因的 error便于 CLI 定位。
func newFieldError(field, reason string) error {
return FieldError{Field: field, Reason: reason}
}
// hubField 用于拼接 Hub 级字段路径,方便输出 Hub[xxx].Field 形式。
func hubField(name, field string) string {
if name == "" {
return fmt.Sprintf("Hub[].%s", field)
}
return fmt.Sprintf("Hub[%s].%s", name, field)
}

149
internal/config/loader.go Normal file
View File

@@ -0,0 +1,149 @@
package config
import (
"fmt"
"path/filepath"
"reflect"
"strconv"
"time"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
)
// Load 读取并解析 TOML 配置文件,同时注入默认值与校验逻辑。
func Load(path string) (*Config, error) {
if path == "" {
path = "config.toml"
}
v := viper.New()
v.SetConfigFile(path)
setDefaults(v)
if err := v.ReadInConfig(); err != nil {
return nil, fmt.Errorf("读取配置失败: %w", err)
}
if err := rejectHubLevelPorts(v); err != nil {
return nil, err
}
var cfg Config
if err := v.Unmarshal(&cfg, viper.DecodeHook(durationDecodeHook())); err != nil {
return nil, fmt.Errorf("解析配置失败: %w", err)
}
applyGlobalDefaults(&cfg.Global)
for i := range cfg.Hubs {
applyHubDefaults(&cfg.Hubs[i])
}
if err := cfg.Validate(); err != nil {
return nil, err
}
absStorage, err := filepath.Abs(cfg.Global.StoragePath)
if err != nil {
return nil, fmt.Errorf("无法解析缓存目录: %w", err)
}
cfg.Global.StoragePath = absStorage
return &cfg, nil
}
func setDefaults(v *viper.Viper) {
v.SetDefault("ListenPort", 5000)
v.SetDefault("LogLevel", "info")
v.SetDefault("LogFilePath", "")
v.SetDefault("LogMaxSize", 100)
v.SetDefault("LogMaxBackups", 10)
v.SetDefault("LogCompress", true)
v.SetDefault("StoragePath", "./storage")
v.SetDefault("CacheTTL", 86400)
v.SetDefault("MaxMemoryCacheSize", 256*1024*1024)
v.SetDefault("MaxRetries", 3)
v.SetDefault("InitialBackoff", "1s")
v.SetDefault("UpstreamTimeout", "30s")
}
func applyGlobalDefaults(g *GlobalConfig) {
if g.ListenPort == 0 {
g.ListenPort = 5000
}
if g.CacheTTL.DurationValue() == 0 {
g.CacheTTL = Duration(24 * time.Hour)
}
if g.InitialBackoff.DurationValue() == 0 {
g.InitialBackoff = Duration(time.Second)
}
if g.UpstreamTimeout.DurationValue() == 0 {
g.UpstreamTimeout = Duration(30 * time.Second)
}
}
func applyHubDefaults(h *HubConfig) {
if h.CacheTTL.DurationValue() < 0 {
h.CacheTTL = Duration(0)
}
}
func durationDecodeHook() mapstructure.DecodeHookFunc {
targetType := reflect.TypeOf(Duration(0))
return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) {
if to != targetType {
return data, nil
}
switch v := data.(type) {
case string:
if v == "" {
return Duration(0), nil
}
if parsed, err := time.ParseDuration(v); err == nil {
return Duration(parsed), nil
}
if seconds, err := strconv.ParseFloat(v, 64); err == nil {
return Duration(time.Duration(seconds * float64(time.Second))), nil
}
return nil, fmt.Errorf("无法解析 Duration 字段: %s", v)
case int:
return Duration(time.Duration(v) * time.Second), nil
case int64:
return Duration(time.Duration(v) * time.Second), nil
case float64:
return Duration(time.Duration(v * float64(time.Second))), nil
case time.Duration:
return Duration(v), nil
case Duration:
return v, nil
default:
return nil, fmt.Errorf("不支持的 Duration 类型: %T", v)
}
}
}
func rejectHubLevelPorts(v *viper.Viper) error {
raw := v.Get("Hub")
hubs, ok := raw.([]interface{})
if !ok {
return nil
}
for idx, entry := range hubs {
m, ok := entry.(map[string]interface{})
if !ok {
continue
}
if _, exists := m["Port"]; exists {
name := fmt.Sprintf("#%d", idx)
if rawName, ok := m["Name"].(string); ok && rawName != "" {
name = rawName
}
return newFieldError(hubField(name, "Port"), "字段已弃用,请移除并使用全局 ListenPort")
}
}
return nil
}

View File

@@ -0,0 +1,27 @@
package config
import "testing"
func TestLoadFailsWithMissingFields(t *testing.T) {
if _, err := Load(testConfigPath(t, "missing.toml")); err == nil {
t.Fatalf("缺失字段的配置应返回错误")
}
}
func TestLoadRejectsInvalidDuration(t *testing.T) {
cfg := `
LogLevel = "info"
StoragePath = "./data"
CacheTTL = "boom"
[[Hub]]
Name = "docker"
Domain = "docker.local"
Type = "docker"
Upstream = "https://registry-1.docker.io"
`
path := writeTempConfig(t, cfg)
if _, err := Load(path); err == nil {
t.Fatalf("无效 Duration 应失败")
}
}

View File

@@ -0,0 +1,22 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func testConfigPath(t *testing.T, name string) string {
t.Helper()
return filepath.Join("testdata", name)
}
func writeTempConfig(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "config.toml")
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
t.Fatalf("写入临时配置失败: %v", err)
}
return path
}

9
internal/config/testdata/missing.toml vendored Normal file
View File

@@ -0,0 +1,9 @@
ListenPort = 0
LogLevel = "info"
StoragePath = "./data"
[[Hub]]
Name = "docker"
Domain = "docker.local"
Upstream = ""
Type = ""

9
internal/config/testdata/valid.toml vendored Normal file
View File

@@ -0,0 +1,9 @@
ListenPort = 5050
LogLevel = "debug"
StoragePath = "./data"
[[Hub]]
Name = "docker"
Domain = "docker.local"
Upstream = "https://registry-1.docker.io"
Type = "docker"

105
internal/config/types.go Normal file
View File

@@ -0,0 +1,105 @@
package config
import (
"fmt"
"strconv"
"strings"
"time"
)
// Duration 提供更灵活的反序列化能力,同时兼容纯秒整数与 Go Duration 字符串。
type Duration time.Duration
// UnmarshalText 使 Viper 可以识别诸如 "30s"、"5m" 或纯数字秒值等配置写法。
func (d *Duration) UnmarshalText(text []byte) error {
raw := strings.TrimSpace(string(text))
if raw == "" {
*d = Duration(0)
return nil
}
if seconds, err := time.ParseDuration(raw); err == nil {
*d = Duration(seconds)
return nil
}
if intVal, err := parseInt(raw); err == nil {
*d = Duration(time.Duration(intVal) * time.Second)
return nil
}
return fmt.Errorf("invalid duration value: %s", raw)
}
// DurationValue 返回真实的 time.Duration便于调用方计算。
func (d Duration) DurationValue() time.Duration {
return time.Duration(d)
}
// parseInt 支持十进制或 0x 前缀的十六进制字符串解析。
func parseInt(value string) (int64, error) {
if strings.HasPrefix(value, "0x") || strings.HasPrefix(value, "0X") {
return strconv.ParseInt(value, 0, 64)
}
return strconv.ParseInt(value, 10, 64)
}
// GlobalConfig 描述全局运行时行为,所有 Hub 共享同一份参数。
type GlobalConfig struct {
ListenPort int `mapstructure:"ListenPort"`
LogLevel string `mapstructure:"LogLevel"`
LogFilePath string `mapstructure:"LogFilePath"`
LogMaxSize int `mapstructure:"LogMaxSize"`
LogMaxBackups int `mapstructure:"LogMaxBackups"`
LogCompress bool `mapstructure:"LogCompress"`
StoragePath string `mapstructure:"StoragePath"`
CacheTTL Duration `mapstructure:"CacheTTL"`
MaxMemoryCache int64 `mapstructure:"MaxMemoryCacheSize"`
MaxRetries int `mapstructure:"MaxRetries"`
InitialBackoff Duration `mapstructure:"InitialBackoff"`
UpstreamTimeout Duration `mapstructure:"UpstreamTimeout"`
}
// HubConfig 决定单个代理实例如何与下游/上游交互。
type HubConfig struct {
Name string `mapstructure:"Name"`
Domain string `mapstructure:"Domain"`
Upstream string `mapstructure:"Upstream"`
Proxy string `mapstructure:"Proxy"`
Type string `mapstructure:"Type"`
Username string `mapstructure:"Username"`
Password string `mapstructure:"Password"`
CacheTTL Duration `mapstructure:"CacheTTL"`
EnableHeadCheck bool `mapstructure:"EnableHeadCheck"`
}
// Config 是 TOML 文件映射的整体结构。
type Config struct {
Global GlobalConfig `mapstructure:",squash"`
Hubs []HubConfig `mapstructure:"Hub"`
}
// HasCredentials 表示当前 Hub 是否配置了完整的上游凭证。
func (h HubConfig) HasCredentials() bool {
return h.Username != "" && h.Password != ""
}
// AuthMode 输出 `credentialed` 或 `anonymous`,供日志字段使用。
func (h HubConfig) AuthMode() string {
if h.HasCredentials() {
return "credentialed"
}
return "anonymous"
}
// CredentialModes 返回所有 Hub 的鉴权模式摘要,例如 secure:credentialed。
func CredentialModes(hubs []HubConfig) []string {
if len(hubs) == 0 {
return nil
}
result := make([]string, len(hubs))
for i, hub := range hubs {
result[i] = fmt.Sprintf("%s:%s", hub.Name, hub.AuthMode())
}
return result
}

View File

@@ -0,0 +1,131 @@
package config
import (
"errors"
"fmt"
"net/url"
"strings"
"time"
)
var supportedHubTypes = map[string]struct{}{
"docker": {},
"npm": {},
"go": {},
}
const supportedHubTypeList = "docker|npm|go"
// Validate 针对语义级别做进一步校验,防止非法配置启动服务。
func (c *Config) Validate() error {
if c == nil {
return errors.New("配置为空")
}
g := c.Global
if g.ListenPort <= 0 || g.ListenPort > 65535 {
return newFieldError("Global.ListenPort", "必须在 1-65535")
}
if g.StoragePath == "" {
return newFieldError("Global.StoragePath", "不能为空")
}
if g.CacheTTL.DurationValue() <= 0 {
return newFieldError("Global.CacheTTL", "必须大于 0")
}
if g.MaxMemoryCache <= 0 {
return newFieldError("Global.MaxMemoryCacheSize", "必须大于 0")
}
if g.MaxRetries < 0 {
return newFieldError("Global.MaxRetries", "不能为负数")
}
if g.InitialBackoff.DurationValue() <= 0 {
return newFieldError("Global.InitialBackoff", "必须大于 0")
}
if g.UpstreamTimeout.DurationValue() <= 0 {
return newFieldError("Global.UpstreamTimeout", "必须大于 0")
}
if len(c.Hubs) == 0 {
return errors.New("至少需要配置一个 Hub")
}
seenNames := map[string]struct{}{}
for i := range c.Hubs {
hub := &c.Hubs[i]
if hub.Name == "" {
return newFieldError("Hub[].Name", "不能为空")
}
if _, exists := seenNames[hub.Name]; exists {
return newFieldError(hubField(hub.Name, "Name"), "重复")
}
seenNames[hub.Name] = struct{}{}
if err := validateDomain(hub.Domain); err != nil {
return fmt.Errorf("%s: %w", hubField(hub.Name, "Domain"), err)
}
normalizedType := strings.ToLower(strings.TrimSpace(hub.Type))
if normalizedType == "" {
return newFieldError(hubField(hub.Name, "Type"), "不能为空")
}
if _, ok := supportedHubTypes[normalizedType]; !ok {
return newFieldError(hubField(hub.Name, "Type"), "仅支持 "+supportedHubTypeList)
}
hub.Type = normalizedType
if (hub.Username == "") != (hub.Password == "") {
return newFieldError(hubField(hub.Name, "Username/Password"), "必须同时提供或同时留空")
}
if err := validateUpstream(hub.Upstream); err != nil {
return fmt.Errorf("%s: %w", hubField(hub.Name, "Upstream"), err)
}
if hub.Proxy != "" {
if err := validateUpstream(hub.Proxy); err != nil {
return fmt.Errorf("%s: %w", hubField(hub.Name, "Proxy"), err)
}
}
}
return nil
}
func validateDomain(domain string) error {
if domain == "" {
return errors.New("Domain 不能为空")
}
if strings.Contains(domain, "/") {
return errors.New("Domain 不允许包含路径")
}
if strings.Contains(domain, " ") {
return errors.New("Domain 不允许包含空格")
}
if strings.HasPrefix(domain, "http") {
return errors.New("Domain 不应包含协议头")
}
return nil
}
func validateUpstream(raw string) error {
if raw == "" {
return errors.New("缺少上游地址")
}
parsed, err := url.Parse(raw)
if err != nil {
return err
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return fmt.Errorf("仅支持 http/https上游: %s", raw)
}
if parsed.Host == "" {
return fmt.Errorf("上游缺少 Host: %s", raw)
}
return nil
}
// EffectiveCacheTTL 返回特定 Hub 生效的 TTL未覆盖时回退至全局值。
func (c *Config) EffectiveCacheTTL(h HubConfig) time.Duration {
if h.CacheTTL.DurationValue() > 0 {
return h.CacheTTL.DurationValue()
}
return c.Global.CacheTTL.DurationValue()
}