This commit is contained in:
2025-11-14 12:11:44 +08:00
commit 39ebf61572
88 changed files with 9999 additions and 0 deletions

7
internal/cache/doc.go vendored Normal file
View File

@@ -0,0 +1,7 @@
// Package cache defines the disk-backed store responsible for translating hub
// requests into StoragePath/<hub>/<path> files. The store exposes read/write
// primitives with safe semantics (temp file + rename) and surfaces file info
// (size, modtime) for higher layers to implement conditional revalidation.
// Proxy handlers depend on this package to stream cached responses or trigger
// upstream fetches without duplicating filesystem logic.
package cache

240
internal/cache/fs_store.go vendored Normal file
View File

@@ -0,0 +1,240 @@
package cache
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
)
// NewStore 以 basePath 为根目录构建磁盘缓存,整站复用一份实例。
func NewStore(basePath string) (Store, error) {
if basePath == "" {
return nil, errors.New("storage path required")
}
abs, err := filepath.Abs(basePath)
if err != nil {
return nil, fmt.Errorf("resolve storage path: %w", err)
}
if err := os.MkdirAll(abs, 0o755); err != nil {
return nil, fmt.Errorf("create storage path: %w", err)
}
return &fileStore{
basePath: abs,
locks: make(map[string]*entryLock),
}, nil
}
// fileStore 通过 entryLock 避免同一 Locator 并发写入,同时复用 basePath。
type fileStore struct {
basePath string
mu sync.Mutex
locks map[string]*entryLock
}
type entryLock struct {
mu sync.Mutex
refs int
}
func (s *fileStore) Get(ctx context.Context, locator Locator) (*ReadResult, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
filePath, err := s.path(locator)
if err != nil {
return nil, err
}
info, err := os.Stat(filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, ErrNotFound
}
return nil, err
}
if info.IsDir() {
return nil, ErrNotFound
}
f, err := os.Open(filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, ErrNotFound
}
return nil, err
}
entry := Entry{
Locator: locator,
FilePath: filePath,
SizeBytes: info.Size(),
ModTime: info.ModTime(),
}
return &ReadResult{
Entry: entry,
Reader: f,
}, nil
}
func (s *fileStore) Put(ctx context.Context, locator Locator, body io.Reader, opts PutOptions) (*Entry, error) {
unlock, err := s.lockEntry(locator)
if err != nil {
return nil, err
}
defer unlock()
filePath, err := s.path(locator)
if err != nil {
return nil, err
}
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
return nil, err
}
tempFile, err := os.CreateTemp(filepath.Dir(filePath), ".cache-*")
if err != nil {
return nil, err
}
tempName := tempFile.Name()
written, err := copyWithContext(ctx, tempFile, body)
closeErr := tempFile.Close()
if err == nil {
err = closeErr
}
if err != nil {
os.Remove(tempName)
return nil, err
}
if err := os.Rename(tempName, filePath); err != nil {
os.Remove(tempName)
return nil, err
}
modTime := opts.ModTime
if modTime.IsZero() {
modTime = time.Now().UTC()
}
if err := os.Chtimes(filePath, modTime, modTime); err != nil {
return nil, err
}
entry := Entry{
Locator: locator,
FilePath: filePath,
SizeBytes: written,
ModTime: modTime,
}
return &entry, nil
}
func (s *fileStore) Remove(ctx context.Context, locator Locator) error {
unlock, err := s.lockEntry(locator)
if err != nil {
return err
}
defer unlock()
filePath, err := s.path(locator)
if err != nil {
return err
}
if err := os.Remove(filePath); err != nil && !errors.Is(err, fs.ErrNotExist) {
return err
}
return nil
}
func (s *fileStore) lockEntry(locator Locator) (func(), error) {
key := locatorKey(locator)
s.mu.Lock()
lock := s.locks[key]
if lock == nil {
lock = &entryLock{}
s.locks[key] = lock
}
lock.refs++
s.mu.Unlock()
lock.mu.Lock()
return func() {
lock.mu.Unlock()
s.mu.Lock()
lock.refs--
if lock.refs == 0 {
delete(s.locks, key)
}
s.mu.Unlock()
}, nil
}
func (s *fileStore) path(locator Locator) (string, error) {
if locator.HubName == "" {
return "", errors.New("hub name required")
}
rel := locator.Path
if rel == "" || rel == "/" {
rel = "root"
}
rel = path.Clean("/" + rel)
rel = strings.TrimPrefix(rel, "/")
if rel == "" {
rel = "root"
}
filePath := filepath.Join(s.basePath, locator.HubName, filepath.FromSlash(rel))
if !strings.HasPrefix(filePath, filepath.Join(s.basePath, locator.HubName)) {
return "", errors.New("invalid cache path")
}
return filePath, nil
}
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
var copied int64
buf := make([]byte, 32*1024)
for {
if err := ctx.Err(); err != nil {
return copied, err
}
n, err := src.Read(buf)
if n > 0 {
w, wErr := dst.Write(buf[:n])
copied += int64(w)
if wErr != nil {
return copied, wErr
}
if w < n {
return copied, io.ErrShortWrite
}
}
if err != nil {
if errors.Is(err, io.EOF) {
return copied, nil
}
return copied, err
}
}
}
func locatorKey(locator Locator) string {
return locator.HubName + "::" + locator.Path
}

53
internal/cache/store.go vendored Normal file
View File

@@ -0,0 +1,53 @@
package cache
import (
"context"
"errors"
"io"
"time"
)
// Store 负责管理磁盘缓存的读写。磁盘布局遵循:
//
// <StoragePath>/<HubName>/<path> # 实际正文
//
// 每个条目仅由正文文件组成,文件的 ModTime/Size 由文件系统提供。
type Store interface {
// Get 返回一个可流式读取的缓存条目。若不存在则返回 ErrNotFound。
Get(ctx context.Context, locator Locator) (*ReadResult, error)
// Put 将上游响应写入缓存,并产出新的 Entry 描述。实现需通过临时文件 + rename
// 保证写入原子性,并在失败时清理临时文件。可选地根据 opts.ModTime 设置文件时间戳。
Put(ctx context.Context, locator Locator, body io.Reader, opts PutOptions) (*Entry, error)
// Remove 删除正文文件,通常用于上游错误或复合策略清理。
Remove(ctx context.Context, locator Locator) error
}
// PutOptions 控制写入过程中的可选属性。
type PutOptions struct {
ModTime time.Time
}
// Locator 唯一定位一个缓存条目Hub + 相对路径),所有路径均为 URL 路径风格。
type Locator struct {
HubName string
Path string
}
// Entry 表示一次缓存命中结果,包含绝对文件路径及文件信息。
type Entry struct {
Locator Locator `json:"locator"`
FilePath string `json:"file_path"`
SizeBytes int64 `json:"size_bytes"`
ModTime time.Time
}
// ReadResult 组合 Entry 与正文 Reader便于代理层直接将 Body 流式返回。
type ReadResult struct {
Entry Entry
Reader io.ReadSeekCloser
}
// ErrNotFound 表示缓存不存在。
var ErrNotFound = errors.New("cache entry not found")

95
internal/cache/store_test.go vendored Normal file
View File

@@ -0,0 +1,95 @@
package cache
import (
"bytes"
"context"
"io"
"os"
"testing"
"time"
)
func TestStorePutAndGet(t *testing.T) {
store := newTestStore(t)
locator := Locator{HubName: "docker", Path: "/v2/library/sample/manifests/latest"}
modTime := time.Now().Add(-time.Hour).UTC()
payload := []byte("payload")
if _, err := store.Put(context.Background(), locator, bytes.NewReader(payload), PutOptions{ModTime: modTime}); err != nil {
t.Fatalf("put error: %v", err)
}
result, err := store.Get(context.Background(), locator)
if err != nil {
t.Fatalf("get error: %v", err)
}
defer result.Reader.Close()
body, err := io.ReadAll(result.Reader)
if err != nil {
t.Fatalf("read cached body error: %v", err)
}
if string(body) != string(payload) {
t.Fatalf("cached payload mismatch: %s", string(body))
}
if result.Entry.SizeBytes != int64(len(payload)) {
t.Fatalf("size mismatch: %d", result.Entry.SizeBytes)
}
if !result.Entry.ModTime.Equal(modTime) {
t.Fatalf("modtime mismatch: expected %v got %v", modTime, result.Entry.ModTime)
}
}
func TestStoreGetMissing(t *testing.T) {
store := newTestStore(t)
_, err := store.Get(context.Background(), Locator{HubName: "docker", Path: "/missing"})
if err == nil || err != ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestStoreRemove(t *testing.T) {
store := newTestStore(t)
locator := Locator{HubName: "docker", Path: "/cache/remove"}
if _, err := store.Put(context.Background(), locator, bytes.NewReader([]byte("data")), PutOptions{}); err != nil {
t.Fatalf("put error: %v", err)
}
if err := store.Remove(context.Background(), locator); err != nil {
t.Fatalf("remove error: %v", err)
}
if _, err := store.Get(context.Background(), locator); err == nil || err != ErrNotFound {
t.Fatalf("expected not found after remove, got %v", err)
}
}
func TestStoreIgnoresDirectories(t *testing.T) {
store := newTestStore(t)
locator := Locator{HubName: "ghcr", Path: "/v2"}
fs, ok := store.(*fileStore)
if !ok {
t.Fatalf("unexpected store type %T", store)
}
filePath, err := fs.path(locator)
if err != nil {
t.Fatalf("path error: %v", err)
}
if err := os.MkdirAll(filePath, 0o755); err != nil {
t.Fatalf("mkdir error: %v", err)
}
if _, err := store.Get(context.Background(), locator); err == nil || err != ErrNotFound {
t.Fatalf("expected ErrNotFound for directory, got %v", err)
}
}
// newTestStore returns a Store backed by a temporary directory.
func newTestStore(t *testing.T) Store {
t.Helper()
store, err := NewStore(t.TempDir())
if err != nil {
t.Fatalf("failed to create store: %v", err)
}
return store
}

View File

@@ -0,0 +1,109 @@
package config
import (
"testing"
"time"
)
func TestLoadWithDefaults(t *testing.T) {
cfgPath := testConfigPath(t, "valid.toml")
cfg, err := Load(cfgPath)
if err != nil {
t.Fatalf("Load 返回错误: %v", err)
}
if cfg.Global.CacheTTL.DurationValue() == 0 {
t.Fatalf("CacheTTL 应该自动填充默认值")
}
if cfg.Global.StoragePath == "" {
t.Fatalf("StoragePath 应该被保留")
}
if cfg.Global.ListenPort == 0 {
t.Fatalf("ListenPort 应当被解析")
}
if cfg.EffectiveCacheTTL(cfg.Hubs[0]) != cfg.Global.CacheTTL.DurationValue() {
t.Fatalf("Hub 未设置 TTL 时应退回全局 TTL")
}
}
func TestValidateRejectsBadHub(t *testing.T) {
cfgPath := testConfigPath(t, "missing.toml")
if _, err := Load(cfgPath); err == nil {
t.Fatalf("不合法的配置应返回错误")
}
}
func TestEffectiveCacheTTLOverrides(t *testing.T) {
cfg := &Config{Global: GlobalConfig{CacheTTL: Duration(time.Hour)}}
hub := HubConfig{CacheTTL: Duration(2 * time.Hour)}
if ttl := cfg.EffectiveCacheTTL(hub); ttl != 2*time.Hour {
t.Fatalf("覆盖 TTL 应该优先生效")
}
}
func TestValidateEnforcesListenPortRange(t *testing.T) {
cfg := validConfig()
cfg.Global.ListenPort = 70000
if err := cfg.Validate(); err == nil {
t.Fatalf("ListenPort 超出范围应当报错")
}
}
func TestHubTypeValidation(t *testing.T) {
testCases := []struct {
name string
hubType string
shouldErr bool
}{
{"docker ok", "docker", false},
{"npm ok", "npm", false},
{"go ok", "go", false},
{"missing type", "", true},
{"unsupported type", "rubygems", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cfg := validConfig()
cfg.Hubs[0].Type = tc.hubType
err := cfg.Validate()
if tc.shouldErr && err == nil {
t.Fatalf("expected error for type %q", tc.hubType)
}
if !tc.shouldErr && err != nil {
t.Fatalf("unexpected error for type %q: %v", tc.hubType, err)
}
})
}
}
func TestValidateRequiresCredentialPairs(t *testing.T) {
cfg := validConfig()
cfg.Hubs[0].Username = "foo"
if err := cfg.Validate(); err == nil {
t.Fatalf("仅提供 Username 时应报错")
}
}
func validConfig() *Config {
return &Config{
Global: GlobalConfig{
ListenPort: 5000,
StoragePath: "./data",
CacheTTL: Duration(time.Hour),
MaxMemoryCache: 1,
MaxRetries: 1,
InitialBackoff: Duration(time.Second),
UpstreamTimeout: Duration(time.Second),
},
Hubs: []HubConfig{
{
Name: "npm",
Domain: "npm.local",
Type: "npm",
Upstream: "https://registry.npmjs.org",
},
},
}
}

26
internal/config/errors.go Normal file
View File

@@ -0,0 +1,26 @@
package config
import "fmt"
// FieldError 提供字段路径与错误原因,便于 CLI 向用户反馈。
type FieldError struct {
Field string
Reason string
}
func (e FieldError) Error() string {
return fmt.Sprintf("%s: %s", e.Field, e.Reason)
}
// newFieldError 创建包含字段路径与原因的 error便于 CLI 定位。
func newFieldError(field, reason string) error {
return FieldError{Field: field, Reason: reason}
}
// hubField 用于拼接 Hub 级字段路径,方便输出 Hub[xxx].Field 形式。
func hubField(name, field string) string {
if name == "" {
return fmt.Sprintf("Hub[].%s", field)
}
return fmt.Sprintf("Hub[%s].%s", name, field)
}

149
internal/config/loader.go Normal file
View File

@@ -0,0 +1,149 @@
package config
import (
"fmt"
"path/filepath"
"reflect"
"strconv"
"time"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
)
// Load 读取并解析 TOML 配置文件,同时注入默认值与校验逻辑。
func Load(path string) (*Config, error) {
if path == "" {
path = "config.toml"
}
v := viper.New()
v.SetConfigFile(path)
setDefaults(v)
if err := v.ReadInConfig(); err != nil {
return nil, fmt.Errorf("读取配置失败: %w", err)
}
if err := rejectHubLevelPorts(v); err != nil {
return nil, err
}
var cfg Config
if err := v.Unmarshal(&cfg, viper.DecodeHook(durationDecodeHook())); err != nil {
return nil, fmt.Errorf("解析配置失败: %w", err)
}
applyGlobalDefaults(&cfg.Global)
for i := range cfg.Hubs {
applyHubDefaults(&cfg.Hubs[i])
}
if err := cfg.Validate(); err != nil {
return nil, err
}
absStorage, err := filepath.Abs(cfg.Global.StoragePath)
if err != nil {
return nil, fmt.Errorf("无法解析缓存目录: %w", err)
}
cfg.Global.StoragePath = absStorage
return &cfg, nil
}
func setDefaults(v *viper.Viper) {
v.SetDefault("ListenPort", 5000)
v.SetDefault("LogLevel", "info")
v.SetDefault("LogFilePath", "")
v.SetDefault("LogMaxSize", 100)
v.SetDefault("LogMaxBackups", 10)
v.SetDefault("LogCompress", true)
v.SetDefault("StoragePath", "./storage")
v.SetDefault("CacheTTL", 86400)
v.SetDefault("MaxMemoryCacheSize", 256*1024*1024)
v.SetDefault("MaxRetries", 3)
v.SetDefault("InitialBackoff", "1s")
v.SetDefault("UpstreamTimeout", "30s")
}
func applyGlobalDefaults(g *GlobalConfig) {
if g.ListenPort == 0 {
g.ListenPort = 5000
}
if g.CacheTTL.DurationValue() == 0 {
g.CacheTTL = Duration(24 * time.Hour)
}
if g.InitialBackoff.DurationValue() == 0 {
g.InitialBackoff = Duration(time.Second)
}
if g.UpstreamTimeout.DurationValue() == 0 {
g.UpstreamTimeout = Duration(30 * time.Second)
}
}
func applyHubDefaults(h *HubConfig) {
if h.CacheTTL.DurationValue() < 0 {
h.CacheTTL = Duration(0)
}
}
func durationDecodeHook() mapstructure.DecodeHookFunc {
targetType := reflect.TypeOf(Duration(0))
return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) {
if to != targetType {
return data, nil
}
switch v := data.(type) {
case string:
if v == "" {
return Duration(0), nil
}
if parsed, err := time.ParseDuration(v); err == nil {
return Duration(parsed), nil
}
if seconds, err := strconv.ParseFloat(v, 64); err == nil {
return Duration(time.Duration(seconds * float64(time.Second))), nil
}
return nil, fmt.Errorf("无法解析 Duration 字段: %s", v)
case int:
return Duration(time.Duration(v) * time.Second), nil
case int64:
return Duration(time.Duration(v) * time.Second), nil
case float64:
return Duration(time.Duration(v * float64(time.Second))), nil
case time.Duration:
return Duration(v), nil
case Duration:
return v, nil
default:
return nil, fmt.Errorf("不支持的 Duration 类型: %T", v)
}
}
}
func rejectHubLevelPorts(v *viper.Viper) error {
raw := v.Get("Hub")
hubs, ok := raw.([]interface{})
if !ok {
return nil
}
for idx, entry := range hubs {
m, ok := entry.(map[string]interface{})
if !ok {
continue
}
if _, exists := m["Port"]; exists {
name := fmt.Sprintf("#%d", idx)
if rawName, ok := m["Name"].(string); ok && rawName != "" {
name = rawName
}
return newFieldError(hubField(name, "Port"), "字段已弃用,请移除并使用全局 ListenPort")
}
}
return nil
}

View File

@@ -0,0 +1,27 @@
package config
import "testing"
func TestLoadFailsWithMissingFields(t *testing.T) {
if _, err := Load(testConfigPath(t, "missing.toml")); err == nil {
t.Fatalf("缺失字段的配置应返回错误")
}
}
func TestLoadRejectsInvalidDuration(t *testing.T) {
cfg := `
LogLevel = "info"
StoragePath = "./data"
CacheTTL = "boom"
[[Hub]]
Name = "docker"
Domain = "docker.local"
Type = "docker"
Upstream = "https://registry-1.docker.io"
`
path := writeTempConfig(t, cfg)
if _, err := Load(path); err == nil {
t.Fatalf("无效 Duration 应失败")
}
}

View File

@@ -0,0 +1,22 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func testConfigPath(t *testing.T, name string) string {
t.Helper()
return filepath.Join("testdata", name)
}
func writeTempConfig(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "config.toml")
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
t.Fatalf("写入临时配置失败: %v", err)
}
return path
}

9
internal/config/testdata/missing.toml vendored Normal file
View File

@@ -0,0 +1,9 @@
ListenPort = 0
LogLevel = "info"
StoragePath = "./data"
[[Hub]]
Name = "docker"
Domain = "docker.local"
Upstream = ""
Type = ""

9
internal/config/testdata/valid.toml vendored Normal file
View File

@@ -0,0 +1,9 @@
ListenPort = 5050
LogLevel = "debug"
StoragePath = "./data"
[[Hub]]
Name = "docker"
Domain = "docker.local"
Upstream = "https://registry-1.docker.io"
Type = "docker"

105
internal/config/types.go Normal file
View File

@@ -0,0 +1,105 @@
package config
import (
"fmt"
"strconv"
"strings"
"time"
)
// Duration 提供更灵活的反序列化能力,同时兼容纯秒整数与 Go Duration 字符串。
type Duration time.Duration
// UnmarshalText 使 Viper 可以识别诸如 "30s"、"5m" 或纯数字秒值等配置写法。
func (d *Duration) UnmarshalText(text []byte) error {
raw := strings.TrimSpace(string(text))
if raw == "" {
*d = Duration(0)
return nil
}
if seconds, err := time.ParseDuration(raw); err == nil {
*d = Duration(seconds)
return nil
}
if intVal, err := parseInt(raw); err == nil {
*d = Duration(time.Duration(intVal) * time.Second)
return nil
}
return fmt.Errorf("invalid duration value: %s", raw)
}
// DurationValue 返回真实的 time.Duration便于调用方计算。
func (d Duration) DurationValue() time.Duration {
return time.Duration(d)
}
// parseInt 支持十进制或 0x 前缀的十六进制字符串解析。
func parseInt(value string) (int64, error) {
if strings.HasPrefix(value, "0x") || strings.HasPrefix(value, "0X") {
return strconv.ParseInt(value, 0, 64)
}
return strconv.ParseInt(value, 10, 64)
}
// GlobalConfig 描述全局运行时行为,所有 Hub 共享同一份参数。
type GlobalConfig struct {
ListenPort int `mapstructure:"ListenPort"`
LogLevel string `mapstructure:"LogLevel"`
LogFilePath string `mapstructure:"LogFilePath"`
LogMaxSize int `mapstructure:"LogMaxSize"`
LogMaxBackups int `mapstructure:"LogMaxBackups"`
LogCompress bool `mapstructure:"LogCompress"`
StoragePath string `mapstructure:"StoragePath"`
CacheTTL Duration `mapstructure:"CacheTTL"`
MaxMemoryCache int64 `mapstructure:"MaxMemoryCacheSize"`
MaxRetries int `mapstructure:"MaxRetries"`
InitialBackoff Duration `mapstructure:"InitialBackoff"`
UpstreamTimeout Duration `mapstructure:"UpstreamTimeout"`
}
// HubConfig 决定单个代理实例如何与下游/上游交互。
type HubConfig struct {
Name string `mapstructure:"Name"`
Domain string `mapstructure:"Domain"`
Upstream string `mapstructure:"Upstream"`
Proxy string `mapstructure:"Proxy"`
Type string `mapstructure:"Type"`
Username string `mapstructure:"Username"`
Password string `mapstructure:"Password"`
CacheTTL Duration `mapstructure:"CacheTTL"`
EnableHeadCheck bool `mapstructure:"EnableHeadCheck"`
}
// Config 是 TOML 文件映射的整体结构。
type Config struct {
Global GlobalConfig `mapstructure:",squash"`
Hubs []HubConfig `mapstructure:"Hub"`
}
// HasCredentials 表示当前 Hub 是否配置了完整的上游凭证。
func (h HubConfig) HasCredentials() bool {
return h.Username != "" && h.Password != ""
}
// AuthMode 输出 `credentialed` 或 `anonymous`,供日志字段使用。
func (h HubConfig) AuthMode() string {
if h.HasCredentials() {
return "credentialed"
}
return "anonymous"
}
// CredentialModes 返回所有 Hub 的鉴权模式摘要,例如 secure:credentialed。
func CredentialModes(hubs []HubConfig) []string {
if len(hubs) == 0 {
return nil
}
result := make([]string, len(hubs))
for i, hub := range hubs {
result[i] = fmt.Sprintf("%s:%s", hub.Name, hub.AuthMode())
}
return result
}

View File

@@ -0,0 +1,131 @@
package config
import (
"errors"
"fmt"
"net/url"
"strings"
"time"
)
var supportedHubTypes = map[string]struct{}{
"docker": {},
"npm": {},
"go": {},
}
const supportedHubTypeList = "docker|npm|go"
// Validate 针对语义级别做进一步校验,防止非法配置启动服务。
func (c *Config) Validate() error {
if c == nil {
return errors.New("配置为空")
}
g := c.Global
if g.ListenPort <= 0 || g.ListenPort > 65535 {
return newFieldError("Global.ListenPort", "必须在 1-65535")
}
if g.StoragePath == "" {
return newFieldError("Global.StoragePath", "不能为空")
}
if g.CacheTTL.DurationValue() <= 0 {
return newFieldError("Global.CacheTTL", "必须大于 0")
}
if g.MaxMemoryCache <= 0 {
return newFieldError("Global.MaxMemoryCacheSize", "必须大于 0")
}
if g.MaxRetries < 0 {
return newFieldError("Global.MaxRetries", "不能为负数")
}
if g.InitialBackoff.DurationValue() <= 0 {
return newFieldError("Global.InitialBackoff", "必须大于 0")
}
if g.UpstreamTimeout.DurationValue() <= 0 {
return newFieldError("Global.UpstreamTimeout", "必须大于 0")
}
if len(c.Hubs) == 0 {
return errors.New("至少需要配置一个 Hub")
}
seenNames := map[string]struct{}{}
for i := range c.Hubs {
hub := &c.Hubs[i]
if hub.Name == "" {
return newFieldError("Hub[].Name", "不能为空")
}
if _, exists := seenNames[hub.Name]; exists {
return newFieldError(hubField(hub.Name, "Name"), "重复")
}
seenNames[hub.Name] = struct{}{}
if err := validateDomain(hub.Domain); err != nil {
return fmt.Errorf("%s: %w", hubField(hub.Name, "Domain"), err)
}
normalizedType := strings.ToLower(strings.TrimSpace(hub.Type))
if normalizedType == "" {
return newFieldError(hubField(hub.Name, "Type"), "不能为空")
}
if _, ok := supportedHubTypes[normalizedType]; !ok {
return newFieldError(hubField(hub.Name, "Type"), "仅支持 "+supportedHubTypeList)
}
hub.Type = normalizedType
if (hub.Username == "") != (hub.Password == "") {
return newFieldError(hubField(hub.Name, "Username/Password"), "必须同时提供或同时留空")
}
if err := validateUpstream(hub.Upstream); err != nil {
return fmt.Errorf("%s: %w", hubField(hub.Name, "Upstream"), err)
}
if hub.Proxy != "" {
if err := validateUpstream(hub.Proxy); err != nil {
return fmt.Errorf("%s: %w", hubField(hub.Name, "Proxy"), err)
}
}
}
return nil
}
func validateDomain(domain string) error {
if domain == "" {
return errors.New("Domain 不能为空")
}
if strings.Contains(domain, "/") {
return errors.New("Domain 不允许包含路径")
}
if strings.Contains(domain, " ") {
return errors.New("Domain 不允许包含空格")
}
if strings.HasPrefix(domain, "http") {
return errors.New("Domain 不应包含协议头")
}
return nil
}
func validateUpstream(raw string) error {
if raw == "" {
return errors.New("缺少上游地址")
}
parsed, err := url.Parse(raw)
if err != nil {
return err
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return fmt.Errorf("仅支持 http/https上游: %s", raw)
}
if parsed.Host == "" {
return fmt.Errorf("上游缺少 Host: %s", raw)
}
return nil
}
// EffectiveCacheTTL 返回特定 Hub 生效的 TTL未覆盖时回退至全局值。
func (c *Config) EffectiveCacheTTL(h HubConfig) time.Duration {
if h.CacheTTL.DurationValue() > 0 {
return h.CacheTTL.DurationValue()
}
return c.Global.CacheTTL.DurationValue()
}

View File

@@ -0,0 +1,22 @@
package logging
import "github.com/sirupsen/logrus"
// BaseFields 构建 action + 配置路径等基础字段,便于不同入口复用。
func BaseFields(action, configPath string) logrus.Fields {
return logrus.Fields{
"action": action,
"configPath": configPath,
}
}
// RequestFields 提供 hub/domain/命中状态字段,供代理请求日志复用。
func RequestFields(hub, domain, hubType, authMode string, cacheHit bool) logrus.Fields {
return logrus.Fields{
"hub": hub,
"domain": domain,
"hub_type": hubType,
"auth_mode": authMode,
"cache_hit": cacheHit,
}
}

View File

@@ -0,0 +1,66 @@
package logging
import (
"fmt"
"io"
"os"
"path/filepath"
"time"
"github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
"github.com/any-hub/any-hub/internal/config"
)
// InitLogger 根据全局配置初始化 JSON 结构化日志,确保文件/控制台输出一致。
func InitLogger(cfg config.GlobalConfig) (*logrus.Logger, error) {
level, err := logrus.ParseLevel(cfg.LogLevel)
if err != nil {
return nil, fmt.Errorf("无法解析日志级别: %w", err)
}
output, outErr := buildOutput(cfg)
if outErr != nil {
fmt.Fprintf(os.Stderr, "logger_fallback: %v\n", outErr)
}
logger := logrus.New()
logger.SetLevel(level)
logger.SetOutput(output)
logger.SetFormatter(&logrus.JSONFormatter{TimestampFormat: time.RFC3339Nano})
logrus.SetFormatter(logger.Formatter)
logrus.SetOutput(logger.Out)
logrus.SetLevel(logger.GetLevel())
if outErr != nil {
logger.WithFields(logrus.Fields{
"action": "logger_fallback",
"path": cfg.LogFilePath,
}).Warn(outErr.Error())
}
return logger, nil
}
// buildOutput 根据配置创建日志输出 Writer失败时降级到 stdout 并返回错误。
func buildOutput(cfg config.GlobalConfig) (io.Writer, error) {
if cfg.LogFilePath == "" {
return os.Stdout, nil
}
dir := filepath.Dir(cfg.LogFilePath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return os.Stdout, fmt.Errorf("创建日志目录失败: %w", err)
}
rotator := &lumberjack.Logger{
Filename: cfg.LogFilePath,
MaxSize: cfg.LogMaxSize,
MaxBackups: cfg.LogMaxBackups,
Compress: cfg.LogCompress,
LocalTime: true,
}
return rotator, nil
}

View File

@@ -0,0 +1,57 @@
package logging
import (
"os"
"path/filepath"
"testing"
"github.com/any-hub/any-hub/internal/config"
)
func TestConfigureDefaultsToStdout(t *testing.T) {
logger, err := InitLogger(config.GlobalConfig{LogLevel: "info"})
if err != nil {
t.Fatalf("配置失败: %v", err)
}
if logger.Out != os.Stdout {
t.Fatalf("未指定文件时应输出到 stdout")
}
}
func TestInitLoggerFallbackOnPermissionDenied(t *testing.T) {
dir := t.TempDir()
blocked := filepath.Join(dir, "blocked")
if err := os.Mkdir(blocked, 0o755); err != nil {
t.Fatalf("创建目录失败: %v", err)
}
if err := os.Chmod(blocked, 0o000); err != nil {
t.Fatalf("设置目录权限失败: %v", err)
}
t.Cleanup(func() { _ = os.Chmod(blocked, 0o755) })
cfg := config.GlobalConfig{
LogLevel: "info",
LogFilePath: filepath.Join(blocked, "sub", "any-hub.log"),
}
logger, err := InitLogger(cfg)
if err != nil {
t.Fatalf("初始化不应失败: %v", err)
}
if logger.Out != os.Stdout {
t.Fatalf("fallback 时应退回 stdout")
}
}
func TestConfigureCreatesRotatingFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "any-hub.log")
cfg := config.GlobalConfig{LogLevel: "debug", LogFilePath: path}
logger, err := InitLogger(cfg)
if err != nil {
t.Fatalf("配置失败: %v", err)
}
logger.Info("test")
if _, err := os.Stat(path); err != nil {
t.Fatalf("预期创建日志文件: %v", err)
}
}

View File

@@ -0,0 +1,68 @@
package proxy
import (
"net/url"
"testing"
"github.com/any-hub/any-hub/internal/config"
"github.com/any-hub/any-hub/internal/server"
)
func TestApplyDockerHubNamespaceFallback(t *testing.T) {
route := dockerHubRoute(t, "https://registry-1.docker.io")
path, changed := applyDockerHubNamespaceFallback(route, "/v2/nginx/manifests/latest")
if !changed {
t.Fatalf("expected fallback to apply")
}
if path != "/v2/library/nginx/manifests/latest" {
t.Fatalf("unexpected normalized path: %s", path)
}
path, changed = applyDockerHubNamespaceFallback(route, "/v2/library/nginx/manifests/latest")
if changed {
t.Fatalf("expected no changes for already-namespaced repo")
}
path, changed = applyDockerHubNamespaceFallback(route, "/v2/rogee/nginx/manifests/latest")
if changed {
t.Fatalf("expected no changes for custom namespace")
}
path, changed = applyDockerHubNamespaceFallback(route, "/v2/_catalog")
if changed {
t.Fatalf("expected no changes for _catalog endpoint")
}
otherRoute := dockerHubRoute(t, "https://registry.example.com")
path, changed = applyDockerHubNamespaceFallback(otherRoute, "/v2/nginx/manifests/latest")
if changed || path != "/v2/nginx/manifests/latest" {
t.Fatalf("expected no changes for non-docker-hub upstream")
}
}
func TestSplitDockerRepoPath(t *testing.T) {
repo, rest, ok := splitDockerRepoPath("/v2/library/nginx/manifests/latest")
if !ok || repo != "library/nginx" || rest != "/manifests/latest" {
t.Fatalf("unexpected split result repo=%s rest=%s ok=%v", repo, rest, ok)
}
if _, _, ok := splitDockerRepoPath("/v2/_catalog"); ok {
t.Fatalf("expected catalog path to be ignored")
}
}
func dockerHubRoute(t *testing.T, upstream string) *server.HubRoute {
t.Helper()
parsed, err := url.Parse(upstream)
if err != nil {
t.Fatalf("invalid upstream: %v", err)
}
return &server.HubRoute{
Config: config.HubConfig{
Name: "docker",
Type: "docker",
},
UpstreamURL: parsed,
}
}

772
internal/proxy/handler.go Normal file
View File

@@ -0,0 +1,772 @@
package proxy
import (
"bytes"
"context"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/gofiber/fiber/v3"
"github.com/sirupsen/logrus"
"github.com/any-hub/any-hub/internal/cache"
"github.com/any-hub/any-hub/internal/logging"
"github.com/any-hub/any-hub/internal/server"
)
// Handler 负责 orchestrate “缓存命中 → revalidate → 回源写缓存” 的全流程,
// 对外暴露 Fiber handler内部复用共享 http.Client 与磁盘缓存。
type Handler struct {
client *http.Client
logger *logrus.Logger
store cache.Store
}
// NewHandler constructs a proxy handler with shared HTTP client/logger/store.
func NewHandler(client *http.Client, logger *logrus.Logger, store cache.Store) *Handler {
return &Handler{
client: client,
logger: logger,
store: store,
}
}
// Handle 执行缓存查找、条件回源和最终 streaming 逻辑,任何阶段出错都会输出结构化日志。
func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
started := time.Now()
requestID := server.RequestID(c)
locator := buildLocator(route, c)
policy := determineCachePolicy(route, locator, c.Method())
if err := ensureProxyHubType(route); err != nil {
h.logger.WithField("hub", route.Config.Name).WithError(err).Error("hub_type_unsupported")
return h.writeError(c, fiber.StatusNotImplemented, "hub_type_unsupported")
}
ctx := c.Context()
if ctx == nil {
ctx = context.Background()
}
var cached *cache.ReadResult
if h.store != nil && policy.allowCache {
result, err := h.store.Get(ctx, locator)
switch {
case err == nil:
cached = result
case errors.Is(err, cache.ErrNotFound):
// miss, continue
default:
h.logger.WithError(err).WithField("hub", route.Config.Name).Warn("cache_get_failed")
}
}
if cached != nil {
serve := true
if policy.requireRevalidate {
fresh, err := h.isCacheFresh(c, route, locator, cached.Entry)
if err != nil {
h.logger.WithError(err).WithField("hub", route.Config.Name).Warn("cache_revalidate_failed")
serve = false
} else if !fresh {
serve = false
}
}
if serve {
defer cached.Reader.Close()
return h.serveCache(c, route, cached, requestID, started)
}
cached.Reader.Close()
}
return h.fetchAndStream(c, route, locator, policy, requestID, started, ctx)
}
func (h *Handler) serveCache(c fiber.Ctx, route *server.HubRoute, result *cache.ReadResult, requestID string, started time.Time) error {
if seeker, ok := result.Reader.(io.Seeker); ok {
_, _ = seeker.Seek(0, io.SeekStart)
}
method := c.Method()
contentType := inferCachedContentType(route, result.Entry.Locator)
if contentType != "" {
c.Set("Content-Type", contentType)
} else {
c.Response().Header.Del("Content-Type")
}
length := result.Entry.SizeBytes
if length > 0 {
c.Response().Header.SetContentLength(int(length))
} else {
c.Response().Header.Del("Content-Length")
}
c.Set("X-Any-Hub-Upstream", route.UpstreamURL.String())
c.Set("X-Any-Hub-Cache-Hit", "true")
if requestID != "" {
c.Set("X-Request-ID", requestID)
}
status := fiber.StatusOK
c.Status(status)
if method == http.MethodHead {
result.Reader.Close()
h.logResult(route, route.UpstreamURL.String(), requestID, status, true, started, nil)
return nil
}
_, err := io.Copy(c.Response().BodyWriter(), result.Reader)
result.Reader.Close()
h.logResult(route, route.UpstreamURL.String(), requestID, status, true, started, err)
if err != nil {
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("read cache failed: %v", err))
}
return nil
}
func (h *Handler) fetchAndStream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, policy cachePolicy, requestID string, started time.Time, ctx context.Context) error {
resp, upstreamURL, err := h.executeRequest(c, route)
if err != nil {
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
}
resp, upstreamURL, err = h.retryOnAuthFailure(c, route, requestID, started, resp, upstreamURL)
if err != nil {
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
}
defer resp.Body.Close()
shouldStore := policy.allowStore && h.store != nil && isCacheableStatus(resp.StatusCode) && c.Method() == http.MethodGet
return h.consumeUpstream(c, route, locator, resp, shouldStore, requestID, started, ctx)
}
func (h *Handler) consumeUpstream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, resp *http.Response, shouldStore bool, requestID string, started time.Time, ctx context.Context) error {
upstreamURL := resp.Request.URL.String()
method := c.Method()
authFailure := isAuthFailure(resp.StatusCode) && route.Config.HasCredentials()
if shouldStore {
return h.cacheAndStream(c, route, locator, resp, requestID, started, ctx, upstreamURL)
}
copyResponseHeaders(c, resp.Header)
c.Set("X-Any-Hub-Upstream", upstreamURL)
c.Set("X-Any-Hub-Cache-Hit", "false")
if requestID != "" {
c.Set("X-Request-ID", requestID)
}
c.Status(resp.StatusCode)
if authFailure {
h.logAuthFailure(route, upstreamURL, requestID, resp.StatusCode)
}
if method == http.MethodHead {
h.logResult(route, upstreamURL, requestID, resp.StatusCode, false, started, nil)
return nil
}
_, err := io.Copy(c.Response().BodyWriter(), resp.Body)
h.logResult(route, upstreamURL, requestID, resp.StatusCode, false, started, err)
if err != nil {
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("proxy stream failed: %v", err))
}
return nil
}
func (h *Handler) cacheAndStream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, resp *http.Response, requestID string, started time.Time, ctx context.Context, upstreamURL string) error {
copyResponseHeaders(c, resp.Header)
c.Set("X-Any-Hub-Upstream", upstreamURL)
c.Set("X-Any-Hub-Cache-Hit", "false")
if requestID != "" {
c.Set("X-Request-ID", requestID)
}
c.Status(resp.StatusCode)
reader := io.TeeReader(resp.Body, c.Response().BodyWriter())
opts := cache.PutOptions{ModTime: extractModTime(resp.Header)}
entry, err := h.store.Put(ctx, locator, reader, opts)
h.logResult(route, upstreamURL, requestID, resp.StatusCode, false, started, err)
if err != nil {
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("cache_write_failed: %v", err))
}
_ = entry
return nil
}
func (h *Handler) retryOnAuthFailure(c fiber.Ctx, route *server.HubRoute, requestID string, started time.Time, resp *http.Response, upstreamURL *url.URL) (*http.Response, *url.URL, error) {
if !shouldRetryAuth(route, resp.StatusCode) {
return resp, upstreamURL, nil
}
challenge, ok := parseBearerChallenge(resp.Header.Values("Www-Authenticate"))
h.logAuthRetry(route, upstreamURL.String(), requestID, resp.StatusCode)
resp.Body.Close()
if ok {
ctx := c.Context()
if ctx == nil {
ctx = context.Background()
}
token, err := h.fetchBearerToken(ctx, challenge, route)
if err != nil {
return nil, upstreamURL, err
}
authHeader := "Bearer " + token
retryResp, retryURL, err := h.executeRequestWithAuth(c, route, authHeader)
if err != nil {
return nil, upstreamURL, err
}
return retryResp, retryURL, nil
}
retryResp, retryURL, err := h.executeRequest(c, route)
if err != nil {
return nil, upstreamURL, err
}
return retryResp, retryURL, nil
}
func (h *Handler) executeRequest(c fiber.Ctx, route *server.HubRoute) (*http.Response, *url.URL, error) {
return h.executeRequestWithAuth(c, route, "")
}
func (h *Handler) executeRequestWithAuth(c fiber.Ctx, route *server.HubRoute, authHeader string) (*http.Response, *url.URL, error) {
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
body := bytesReader(c.Body())
req, err := h.buildUpstreamRequest(c, upstreamURL, route, c.Method(), body, authHeader)
if err != nil {
return nil, upstreamURL, err
}
resp, err := h.doRequest(req, route)
return resp, upstreamURL, err
}
func (h *Handler) buildUpstreamRequest(c fiber.Ctx, upstream *url.URL, route *server.HubRoute, method string, body io.Reader, overrideAuth string) (*http.Request, error) {
ctx := c.Context()
if ctx == nil {
ctx = context.Background()
}
if body == nil {
body = http.NoBody
}
req, err := http.NewRequestWithContext(ctx, method, upstream.String(), body)
if err != nil {
return nil, err
}
requestHeaders := fiberHeadersAsHTTP(c)
server.CopyHeaders(req.Header, requestHeaders)
req.Host = upstream.Host
req.Header.Set("Host", upstream.Host)
req.Header.Set("X-Forwarded-Host", c.Hostname())
if ip := c.IP(); ip != "" {
if prior := req.Header.Get("X-Forwarded-For"); prior != "" {
req.Header.Set("X-Forwarded-For", prior+", "+ip)
} else {
req.Header.Set("X-Forwarded-For", ip)
}
}
req.Header.Set("X-Forwarded-Proto", c.Protocol())
req.Header.Set("X-Forwarded-Port", routePort(route))
if overrideAuth != "" {
req.Header.Set("Authorization", overrideAuth)
} else if authHeader := buildCredentialHeader(route.Config.Username, route.Config.Password); authHeader != "" {
req.Header.Set("Authorization", authHeader)
}
return req, nil
}
func (h *Handler) doRequest(req *http.Request, route *server.HubRoute) (*http.Response, error) {
if route.ProxyURL == nil {
return h.client.Do(req)
}
transport := http.Transport{}
if base, ok := h.client.Transport.(*http.Transport); ok && base != nil {
transport = *base.Clone()
}
transport.Proxy = http.ProxyURL(route.ProxyURL)
client := *h.client
client.Transport = &transport
return client.Do(req)
}
func (h *Handler) writeError(c fiber.Ctx, status int, code string) error {
return c.Status(status).JSON(fiber.Map{"error": code})
}
func (h *Handler) logResult(route *server.HubRoute, upstream string, requestID string, status int, cacheHit bool, started time.Time, err error) {
fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), cacheHit)
fields["action"] = "proxy"
fields["upstream"] = upstream
fields["upstream_status"] = status
fields["elapsed_ms"] = time.Since(started).Milliseconds()
if requestID != "" {
fields["request_id"] = requestID
}
if err != nil {
fields["error"] = err.Error()
h.logger.WithFields(fields).Error("proxy_failed")
return
}
h.logger.WithFields(fields).Info("proxy_complete")
}
func inferCachedContentType(route *server.HubRoute, locator cache.Locator) string {
clean := stripQueryMarker(locator.Path)
switch {
case strings.HasSuffix(clean, ".zip"):
return "application/zip"
case strings.HasSuffix(clean, ".mod"):
return "text/plain"
case strings.HasSuffix(clean, ".info"):
return "application/json"
case strings.HasSuffix(clean, ".tgz"):
return "application/octet-stream"
case strings.HasSuffix(clean, "/@v/list"):
return "text/plain"
}
if route != nil {
switch route.Config.Type {
case "docker":
if strings.Contains(clean, "/manifests/") {
return "application/vnd.docker.distribution.manifest.v2+json"
}
if strings.Contains(clean, "/tags/list") {
return "application/json"
}
if strings.Contains(clean, "/blobs/") {
return "application/octet-stream"
}
case "npm":
if strings.HasSuffix(clean, ".json") {
return "application/json"
}
}
}
return ""
}
func buildLocator(route *server.HubRoute, c fiber.Ctx) cache.Locator {
uri := c.Request().URI()
pathVal := string(uri.Path())
if pathVal == "" {
pathVal = "/"
}
clean := path.Clean("/" + pathVal)
if newPath, ok := applyDockerHubNamespaceFallback(route, clean); ok {
clean = newPath
}
query := uri.QueryString()
if len(query) > 0 {
sum := sha1.Sum(query)
clean = fmt.Sprintf("%s/__qs/%s", clean, hex.EncodeToString(sum[:]))
}
return cache.Locator{
HubName: route.Config.Name,
Path: clean,
}
}
func stripQueryMarker(p string) string {
if idx := strings.Index(p, "/__qs/"); idx >= 0 {
return p[:idx]
}
return p
}
func requestPath(c fiber.Ctx) string {
if c == nil {
return "/"
}
uri := c.Request().URI()
if uri == nil {
return "/"
}
pathVal := string(uri.Path())
if pathVal == "" {
return "/"
}
return pathVal
}
func bytesReader(b []byte) io.Reader {
if len(b) == 0 {
return http.NoBody
}
return bytes.NewReader(b)
}
func resolveUpstreamURL(route *server.HubRoute, base *url.URL, c fiber.Ctx) *url.URL {
uri := c.Request().URI()
pathVal := string(uri.Path())
relative := &url.URL{Path: pathVal, RawPath: pathVal}
if newPath, ok := applyDockerHubNamespaceFallback(route, relative.Path); ok {
relative.Path = newPath
relative.RawPath = newPath
}
if query := string(uri.QueryString()); query != "" {
relative.RawQuery = query
}
return base.ResolveReference(relative)
}
func fiberHeadersAsHTTP(c fiber.Ctx) http.Header {
header := http.Header{}
c.Request().Header.VisitAll(func(key, value []byte) {
header.Add(string(key), string(value))
})
return header
}
func copyResponseHeaders(c fiber.Ctx, headers http.Header) {
for key, values := range headers {
if server.IsHopByHopHeader(key) {
continue
}
for _, value := range values {
c.Set(key, value)
}
}
}
func routePort(route *server.HubRoute) string {
if route == nil || route.ListenPort <= 0 {
return "0"
}
return fmt.Sprintf("%d", route.ListenPort)
}
type cachePolicy struct {
allowCache bool
allowStore bool
requireRevalidate bool
}
func determineCachePolicy(route *server.HubRoute, locator cache.Locator, method string) cachePolicy {
if route == nil || method != http.MethodGet {
return cachePolicy{}
}
policy := cachePolicy{allowCache: true, allowStore: true}
path := stripQueryMarker(locator.Path)
switch route.Config.Type {
case "docker":
if path == "/v2" || path == "v2" || path == "/v2/" {
return cachePolicy{}
}
if strings.Contains(path, "/_catalog") {
return cachePolicy{}
}
if isDockerImmutablePath(path) {
return policy
}
policy.requireRevalidate = true
return policy
case "go":
if strings.Contains(path, "/@v/") && (strings.HasSuffix(path, ".zip") || strings.HasSuffix(path, ".mod") || strings.HasSuffix(path, ".info")) {
return policy
}
policy.requireRevalidate = true
return policy
case "npm":
if strings.Contains(path, "/-/") && strings.HasSuffix(path, ".tgz") {
return policy
}
policy.requireRevalidate = true
return policy
default:
return policy
}
}
func isDockerImmutablePath(path string) bool {
if strings.Contains(path, "/blobs/sha256:") {
return true
}
if strings.Contains(path, "/manifests/sha256:") {
return true
}
return false
}
func isCacheableStatus(status int) bool {
return status == http.StatusOK
}
func (h *Handler) isCacheFresh(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, entry cache.Entry) (bool, error) {
ctx := c.Context()
if ctx == nil {
ctx = context.Background()
}
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
req, err := h.buildUpstreamRequest(c, upstreamURL, route, http.MethodHead, http.NoBody, "")
if err != nil {
return false, err
}
resp, err := h.doRequest(req, route)
if err != nil {
return false, err
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
remote := extractModTime(resp.Header)
if !remote.After(entry.ModTime.Add(time.Second)) {
return true, nil
}
return false, nil
case http.StatusNotFound:
if h.store != nil {
_ = h.store.Remove(ctx, locator)
}
return false, nil
default:
return false, nil
}
}
func extractModTime(header http.Header) time.Time {
if last := header.Get("Last-Modified"); last != "" {
if parsed, err := http.ParseTime(last); err == nil {
return parsed.UTC()
}
}
return time.Now().UTC()
}
func applyDockerHubNamespaceFallback(route *server.HubRoute, path string) (string, bool) {
if !isDockerHubRoute(route) {
return path, false
}
repo, rest, ok := splitDockerRepoPath(path)
if !ok || repo == "" {
return path, false
}
if repo == "library" || strings.Contains(repo, "/") {
return path, false
}
normalized := "/v2/library/" + repo + rest
return normalized, true
}
func isDockerHubRoute(route *server.HubRoute) bool {
if route == nil || route.Config.Type != "docker" || route.UpstreamURL == nil {
return false
}
host := strings.ToLower(route.UpstreamURL.Hostname())
switch host {
case "registry-1.docker.io", "docker.io", "index.docker.io":
return true
default:
return false
}
}
func splitDockerRepoPath(path string) (string, string, bool) {
if !strings.HasPrefix(path, "/v2/") {
return "", "", false
}
suffix := strings.TrimPrefix(path, "/v2/")
if suffix == "" || suffix == "/" {
return "", "", false
}
segments := strings.Split(suffix, "/")
var repoSegments []string
for i, seg := range segments {
if seg == "" {
return "", "", false
}
switch seg {
case "manifests", "blobs", "tags", "referrers":
if len(repoSegments) == 0 {
return "", "", false
}
rest := "/" + strings.Join(segments[i:], "/")
return strings.Join(repoSegments, "/"), rest, true
case "_catalog":
return "", "", false
}
repoSegments = append(repoSegments, seg)
}
return "", "", false
}
type bearerChallenge struct {
Realm string
Service string
Scope string
}
func parseBearerChallenge(values []string) (bearerChallenge, bool) {
for _, raw := range values {
raw = strings.TrimSpace(raw)
if raw == "" {
continue
}
if !strings.HasPrefix(strings.ToLower(raw), "bearer ") {
continue
}
params := parseAuthParams(raw[len("Bearer "):])
challenge := bearerChallenge{
Realm: params["realm"],
Service: params["service"],
Scope: params["scope"],
}
if challenge.Realm == "" {
continue
}
return challenge, true
}
return bearerChallenge{}, false
}
func parseAuthParams(input string) map[string]string {
params := make(map[string]string)
parts := strings.Split(input, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
kv := strings.SplitN(part, "=", 2)
if len(kv) != 2 {
continue
}
key := strings.ToLower(strings.TrimSpace(kv[0]))
value := strings.Trim(strings.TrimSpace(kv[1]), `"`)
params[key] = value
}
return params
}
func (h *Handler) fetchBearerToken(ctx context.Context, challenge bearerChallenge, route *server.HubRoute) (string, error) {
if challenge.Realm == "" {
return "", errors.New("bearer realm missing")
}
tokenURL, err := url.Parse(challenge.Realm)
if err != nil {
return "", fmt.Errorf("invalid bearer realm: %w", err)
}
query := tokenURL.Query()
if challenge.Service != "" {
query.Set("service", challenge.Service)
}
if challenge.Scope != "" {
query.Set("scope", challenge.Scope)
}
tokenURL.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL.String(), nil)
if err != nil {
return "", err
}
if route.Config.Username != "" && route.Config.Password != "" {
req.SetBasicAuth(route.Config.Username, route.Config.Password)
}
resp, err := h.client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return "", fmt.Errorf("token request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var tokenResp struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("decode token response: %w", err)
}
token := tokenResp.Token
if token == "" {
token = tokenResp.AccessToken
}
if token == "" {
return "", errors.New("token response missing token value")
}
return token, nil
}
func buildCredentialHeader(username, password string) string {
if username == "" || password == "" {
return ""
}
token := username + ":" + password
return "Basic " + base64.StdEncoding.EncodeToString([]byte(token))
}
func shouldRetryAuth(route *server.HubRoute, status int) bool {
return route != nil && route.Config.HasCredentials() && isAuthFailure(status)
}
func isAuthFailure(status int) bool {
return status == http.StatusUnauthorized || status == http.StatusTooManyRequests
}
func (h *Handler) logAuthRetry(route *server.HubRoute, upstream string, requestID string, status int) {
fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), false)
fields["action"] = "proxy_retry"
fields["upstream"] = upstream
fields["upstream_status"] = status
fields["reason"] = "auth_retry"
if requestID != "" {
fields["request_id"] = requestID
}
h.logger.WithFields(fields).Warn("proxy_auth_retry")
}
func (h *Handler) logAuthFailure(route *server.HubRoute, upstream string, requestID string, status int) {
fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), false)
fields["action"] = "proxy"
fields["upstream"] = upstream
fields["upstream_status"] = status
fields["error"] = "upstream_auth_failed"
if requestID != "" {
fields["request_id"] = requestID
}
h.logger.WithFields(fields).Error("proxy_auth_failed")
}
func ensureProxyHubType(route *server.HubRoute) error {
switch route.Config.Type {
case "docker":
return nil
case "npm":
return nil
case "go":
return nil
default:
return fmt.Errorf("unsupported hub type: %s", route.Config.Type)
}
}

8
internal/server/doc.go Normal file
View File

@@ -0,0 +1,8 @@
// Package server hosts the Fiber HTTP service, request middleware chain, and
// hub registry glue that wires Host/port resolution into proxy handlers.
// Phase 1 focuses on a single binary that bootstraps Fiber, attaches logging
// and error middlewares, injects the HubRegistry built from config, and exposes
// router constructors that other packages (cmd/any-hub, proxy) can reuse.
// Future phases may extend this package with TLS, metrics endpoints, or admin
// surfaces, so keep exports narrow and accept explicit dependencies.
package server

View File

@@ -0,0 +1,77 @@
package server
import (
"net"
"net/http"
"net/textproto"
"time"
"github.com/any-hub/any-hub/internal/config"
)
// Shared HTTP transport tunings复用长连接并集中配置超时。
var defaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: true,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
}
// NewUpstreamClient 返回共享 http.Client用于所有上游请求。
func NewUpstreamClient(cfg *config.Config) *http.Client {
timeout := 30 * time.Second
if cfg != nil && cfg.Global.UpstreamTimeout.DurationValue() > 0 {
timeout = cfg.Global.UpstreamTimeout.DurationValue()
}
return &http.Client{
Timeout: timeout,
Transport: defaultTransport.Clone(),
}
}
// hopByHopHeaders 定义 RFC 7230 中禁止代理转发的头部。
var hopByHopHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
"Proxy-Connection": {}, // 非标准字段,但部分代理仍使用
}
// CopyHeaders 将 src 中允许透传的头复制到 dst自动忽略 hop-by-hop 字段。
func CopyHeaders(dst, src http.Header) {
for key, values := range src {
if isHopByHopHeader(key) {
continue
}
for _, value := range values {
dst.Add(key, value)
}
}
}
func isHopByHopHeader(key string) bool {
canonical := textproto.CanonicalMIMEHeaderKey(key)
if _, ok := hopByHopHeaders[canonical]; ok {
return true
}
return false
}
// IsHopByHopHeader reports whether the header should be stripped by proxies.
func IsHopByHopHeader(key string) bool {
return isHopByHopHeader(key)
}

View File

@@ -0,0 +1,45 @@
package server
import (
"net/http"
"testing"
"time"
"github.com/any-hub/any-hub/internal/config"
)
func TestNewUpstreamClientUsesConfigTimeout(t *testing.T) {
cfg := &config.Config{
Global: config.GlobalConfig{
UpstreamTimeout: config.Duration(45 * time.Second),
},
}
client := NewUpstreamClient(cfg)
if client.Timeout != 45*time.Second {
t.Fatalf("expected timeout 45s, got %s", client.Timeout)
}
}
func TestCopyHeadersSkipsHopByHop(t *testing.T) {
src := http.Header{}
src.Add("Connection", "keep-alive")
src.Add("Keep-Alive", "timeout=5")
src.Add("X-Test-Header", "1")
src.Add("x-test-header", "2")
dst := http.Header{}
CopyHeaders(dst, src)
if _, exists := dst["Connection"]; exists {
t.Fatalf("connection header should not be copied")
}
if _, exists := dst["Keep-Alive"]; exists {
t.Fatalf("keep-alive header should not be copied")
}
got := dst.Values("X-Test-Header")
if len(got) != 2 {
t.Fatalf("expected 2 values, got %v", got)
}
}

View File

@@ -0,0 +1,152 @@
package server
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/any-hub/any-hub/internal/config"
)
// HubRoute 将 Hub 配置与派生属性(如缓存 TTL、解析后的 Upstream/Proxy URL
// 聚合在一起,供路由/代理层直接复用,避免重复解析配置。
type HubRoute struct {
// Config 是用户在 config.toml 中声明的 Hub 字段副本,避免外部修改。
Config config.HubConfig
// ListenPort 记录当前 CLI 监听端口,方便日志/转发头输出。
ListenPort int
// CacheTTL 是对当前 Hub 生效的 TTL若 Hub 未覆盖则等于全局值。
CacheTTL time.Duration
// UpstreamURL/ProxyURL 在构造 Registry 时提前解析完成,便于后续请求快速复用。
UpstreamURL *url.URL
ProxyURL *url.URL
}
// HubRegistry 提供 Host/Host:port 到 HubRoute 的查询能力,所有 Hub 共享同一个监听端口。
type HubRegistry struct {
routes map[string]*HubRoute
ordered []*HubRoute
}
// NewHubRegistry 根据配置构建 Host/端口映射。调用方应在启动阶段创建一次并复用。
func NewHubRegistry(cfg *config.Config) (*HubRegistry, error) {
if cfg == nil {
return nil, errors.New("config is nil")
}
registry := &HubRegistry{
routes: make(map[string]*HubRoute, len(cfg.Hubs)),
}
if len(cfg.Hubs) == 0 {
return registry, nil
}
for _, hub := range cfg.Hubs {
normalizedHost := normalizeDomain(hub.Domain)
if normalizedHost == "" {
return nil, fmt.Errorf("invalid domain for hub %s", hub.Name)
}
if _, exists := registry.routes[normalizedHost]; exists {
return nil, fmt.Errorf("duplicate domain mapping detected for %s", normalizedHost)
}
route, err := buildHubRoute(cfg, hub)
if err != nil {
return nil, err
}
registry.routes[normalizedHost] = route
registry.ordered = append(registry.ordered, route)
}
return registry, nil
}
// Lookup 根据 Host 或 Host:port 查找 HubRoute。
func (r *HubRegistry) Lookup(host string) (*HubRoute, bool) {
if r == nil {
return nil, false
}
normalizedHost, _ := normalizeHost(host)
if normalizedHost == "" {
return nil, false
}
route, ok := r.routes[normalizedHost]
return route, ok
}
// List 返回当前注册的 HubRoute 列表(按配置定义的顺序),用于调试或 /status 输出。
func (r *HubRegistry) List() []HubRoute {
if r == nil || len(r.ordered) == 0 {
return nil
}
result := make([]HubRoute, len(r.ordered))
for i, route := range r.ordered {
result[i] = *route
}
return result
}
func buildHubRoute(cfg *config.Config, hub config.HubConfig) (*HubRoute, error) {
upstreamURL, err := url.Parse(hub.Upstream)
if err != nil {
return nil, fmt.Errorf("invalid upstream for hub %s: %w", hub.Name, err)
}
var proxyURL *url.URL
if hub.Proxy != "" {
proxyURL, err = url.Parse(hub.Proxy)
if err != nil {
return nil, fmt.Errorf("invalid proxy for hub %s: %w", hub.Name, err)
}
}
return &HubRoute{
Config: hub,
ListenPort: cfg.Global.ListenPort,
CacheTTL: cfg.EffectiveCacheTTL(hub),
UpstreamURL: upstreamURL,
ProxyURL: proxyURL,
}, nil
}
func normalizeDomain(domain string) string {
host, _ := normalizeHost(domain)
return host
}
func normalizeHost(raw string) (string, int) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", 0
}
host := raw
port := 0
if strings.Contains(raw, ":") {
if h, p, err := net.SplitHostPort(raw); err == nil {
host = h
if parsedPort, err := strconv.Atoi(p); err == nil {
port = parsedPort
}
} else if idx := strings.LastIndex(raw, ":"); idx > -1 && strings.Count(raw[idx+1:], ":") == 0 {
if parsedPort, err := strconv.Atoi(raw[idx+1:]); err == nil {
host = raw[:idx]
port = parsedPort
}
}
}
host = strings.TrimSuffix(host, ".")
host = strings.ToLower(host)
return host, port
}

View File

@@ -0,0 +1,120 @@
package server
import (
"testing"
"time"
"github.com/any-hub/any-hub/internal/config"
)
func TestHubRegistryLookupByHost(t *testing.T) {
cfg := &config.Config{
Global: config.GlobalConfig{
ListenPort: 5000,
CacheTTL: config.Duration(2 * time.Hour),
},
Hubs: []config.HubConfig{
{
Name: "docker",
Domain: "docker.hub.local",
Type: "docker",
Upstream: "https://registry-1.docker.io",
EnableHeadCheck: true,
},
{
Name: "npm",
Domain: "npm.hub.local",
Type: "npm",
Upstream: "https://registry.npmjs.org",
CacheTTL: config.Duration(30 * time.Minute),
},
},
}
registry, err := NewHubRegistry(cfg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
route, ok := registry.Lookup("docker.hub.local")
if !ok {
t.Fatalf("expected docker route")
}
if route.Config.Name != "docker" {
t.Errorf("wrong hub returned: %s", route.Config.Name)
}
if route.CacheTTL != cfg.EffectiveCacheTTL(route.Config) {
t.Errorf("cache ttl mismatch: got %s", route.CacheTTL)
}
if route.UpstreamURL.String() != "https://registry-1.docker.io" {
t.Errorf("unexpected upstream URL: %s", route.UpstreamURL)
}
if route.ProxyURL != nil {
t.Errorf("expected nil proxy")
}
if route.ListenPort != cfg.Global.ListenPort {
t.Fatalf("route listen port mismatch: %d", route.ListenPort)
}
if got := len(registry.List()); got != 2 {
t.Fatalf("expected 2 routes in list, got %d", got)
}
}
func TestHubRegistryParsesHostHeaderPort(t *testing.T) {
cfg := &config.Config{
Global: config.GlobalConfig{
ListenPort: 5000,
CacheTTL: config.Duration(time.Hour),
},
Hubs: []config.HubConfig{
{
Name: "docker",
Domain: "docker.hub.local",
Type: "docker",
Upstream: "https://registry-1.docker.io",
},
},
}
registry, err := NewHubRegistry(cfg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, ok := registry.Lookup("docker.hub.local:6000"); !ok {
t.Fatalf("expected lookup to ignore host header port")
}
}
func TestHubRegistryRejectsDuplicateDomains(t *testing.T) {
cfg := &config.Config{
Global: config.GlobalConfig{
ListenPort: 5000,
CacheTTL: config.Duration(time.Hour),
},
Hubs: []config.HubConfig{
{
Name: "docker",
Domain: "docker.hub.local",
Type: "docker",
Upstream: "https://registry-1.docker.io",
},
{
Name: "docker-alt",
Domain: "docker.hub.local",
Type: "docker",
Upstream: "https://mirror.registry-1.docker.io",
},
},
}
if _, err := NewHubRegistry(cfg); err == nil {
t.Fatalf("expected duplicate domain error")
}
}

163
internal/server/router.go Normal file
View File

@@ -0,0 +1,163 @@
package server
import (
"errors"
"fmt"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/recover"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
// ProxyHandler describes the component responsible for proxying requests to
// the upstream Hub. It allows injecting fake handlers during tests.
type ProxyHandler interface {
Handle(fiber.Ctx, *HubRoute) error
}
// ProxyHandlerFunc adapts a function to the ProxyHandler interface.
type ProxyHandlerFunc func(fiber.Ctx, *HubRoute) error
// Handle makes ProxyHandlerFunc satisfy ProxyHandler.
func (f ProxyHandlerFunc) Handle(c fiber.Ctx, route *HubRoute) error {
return f(c, route)
}
// AppOptions controls how the Fiber application should behave on a specific port.
type AppOptions struct {
Logger *logrus.Logger
Registry *HubRegistry
Proxy ProxyHandler
ListenPort int
}
const (
contextKeyRoute = "_anyhub_route"
contextKeyRequestID = "_anyhub_request_id"
)
// NewApp builds a Fiber application with Host/port routing middleware and
// structured error handling.
func NewApp(opts AppOptions) (*fiber.App, error) {
if opts.Logger == nil {
return nil, errors.New("logger is required")
}
if opts.Registry == nil {
return nil, errors.New("hub registry is required")
}
if opts.Proxy == nil {
return nil, errors.New("proxy handler is required")
}
if opts.ListenPort <= 0 {
return nil, fmt.Errorf("invalid listen port: %d", opts.ListenPort)
}
app := fiber.New(fiber.Config{
CaseSensitive: true,
})
app.Use(recover.New())
app.Use(requestContextMiddleware(opts))
app.All("/*", func(c fiber.Ctx) error {
route, _ := getRouteFromContext(c)
if route == nil {
return renderHostUnmapped(c, opts.Logger, "", opts.ListenPort)
}
return opts.Proxy.Handle(c, route)
})
return app, nil
}
// requestContextMiddleware 负责生成请求 ID并基于 Host/Host:port 查找 HubRoute。
func requestContextMiddleware(opts AppOptions) fiber.Handler {
return func(c fiber.Ctx) error {
reqID := uuid.NewString()
c.Locals(contextKeyRequestID, reqID)
c.Set("X-Request-ID", reqID)
rawHost := strings.TrimSpace(getHostHeader(c))
route, ok := opts.Registry.Lookup(rawHost)
if !ok {
return renderHostUnmapped(c, opts.Logger, rawHost, opts.ListenPort)
}
if err := ensureRouterHubType(route); err != nil {
return renderTypeUnsupported(c, opts.Logger, route, err)
}
c.Locals(contextKeyRoute, route)
return c.Next()
}
}
func renderHostUnmapped(c fiber.Ctx, logger *logrus.Logger, host string, port int) error {
fields := logrus.Fields{
"action": "host_lookup",
"host": host,
"port": port,
}
logger.WithFields(fields).Warn("host unmapped")
if host != "" {
c.Set("X-Any-Hub-Host", host)
}
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
"error": "host_unmapped",
})
}
func getHostHeader(c fiber.Ctx) string {
if raw := c.Request().Header.Peek(fiber.HeaderHost); len(raw) > 0 {
return string(raw)
}
return c.Hostname()
}
func getRouteFromContext(c fiber.Ctx) (*HubRoute, bool) {
if value := c.Locals(contextKeyRoute); value != nil {
if route, ok := value.(*HubRoute); ok {
return route, true
}
}
return nil, false
}
// RequestID returns the request identifier stored by the router middleware.
func RequestID(c fiber.Ctx) string {
if value := c.Locals(contextKeyRequestID); value != nil {
if reqID, ok := value.(string); ok {
return reqID
}
}
return ""
}
func ensureRouterHubType(route *HubRoute) error {
switch route.Config.Type {
case "docker":
return nil
case "npm":
return nil
case "go":
return nil
default:
return fmt.Errorf("unsupported hub type: %s", route.Config.Type)
}
}
func renderTypeUnsupported(c fiber.Ctx, logger *logrus.Logger, route *HubRoute, err error) error {
fields := logrus.Fields{
"action": "hub_type_check",
"hub": route.Config.Name,
"hub_type": route.Config.Type,
"error": "hub_type_unsupported",
}
logger.WithFields(fields).Error(err.Error())
return c.Status(fiber.StatusNotImplemented).JSON(fiber.Map{
"error": "hub_type_unsupported",
})
}

View File

@@ -0,0 +1,118 @@
package server
import (
"bytes"
"io"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/sirupsen/logrus"
"github.com/any-hub/any-hub/internal/config"
)
func TestRouterRoutesRequestWhenHostMatches(t *testing.T) {
app := newTestApp(t, 5000)
req := httptest.NewRequest("GET", "http://docker.hub.local/v2/", nil)
req.Host = "docker.hub.local"
req.Header.Set("Host", "docker.hub.local")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
}
if resp.StatusCode != fiber.StatusNoContent {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 204 status, got %d (body=%s, hostHeader=%s)", resp.StatusCode, string(body), resp.Header.Get("X-Any-Hub-Host"))
}
if app.storage.routeName != "docker" {
t.Fatalf("expected docker route, got %s", app.storage.routeName)
}
if reqID := resp.Header.Get("X-Request-ID"); reqID == "" {
t.Fatalf("expected X-Request-ID header to be set")
}
}
func TestRouterReturns404WhenHostUnknown(t *testing.T) {
app := newTestApp(t, 5000)
req := httptest.NewRequest("GET", "http://unknown.local/v2/", nil)
req.Host = "unknown.local"
req.Header.Set("Host", "unknown.local")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("app.Test failed: %v", err)
}
if resp.StatusCode != fiber.StatusNotFound {
t.Fatalf("expected 404 status, got %d", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
if !bytes.Contains(body, []byte(`"host_unmapped"`)) {
t.Fatalf("expected host_unmapped error, got %s", string(body))
}
}
type testApp struct {
*fiber.App
storage *proxyRecorder
}
func newTestApp(t *testing.T, port int) *testApp {
t.Helper()
cfg := &config.Config{
Global: config.GlobalConfig{
ListenPort: port,
CacheTTL: config.Duration(3600),
},
Hubs: []config.HubConfig{
{
Name: "docker",
Domain: "docker.hub.local",
Type: "docker",
Upstream: "https://registry-1.docker.io",
},
},
}
registry, err := NewHubRegistry(cfg)
if err != nil {
t.Fatalf("failed to create registry: %v", err)
}
if _, ok := registry.Lookup("docker.hub.local"); !ok {
t.Fatalf("registry lookup failed for docker")
}
logger := logrus.New()
logger.SetOutput(io.Discard)
recorder := &proxyRecorder{}
app, err := NewApp(AppOptions{
Logger: logger,
Registry: registry,
Proxy: recorder,
ListenPort: port,
})
if err != nil {
t.Fatalf("failed to create app: %v", err)
}
return &testApp{App: app, storage: recorder}
}
type proxyRecorder struct {
lastRoute *HubRoute
routeName string
}
func (p *proxyRecorder) Handle(c fiber.Ctx, route *HubRoute) error {
p.lastRoute = route
p.routeName = route.Config.Name
return c.SendStatus(fiber.StatusNoContent)
}

View File

@@ -0,0 +1,14 @@
package version
import "fmt"
// Version/Commit 可在构建时通过 -ldflags 注入,默认使用开发占位符。
var (
Version = "0.1.0"
Commit = "dev"
)
// Full 返回便于 CLI 打印的完整版本信息。
func Full() string {
return fmt.Sprintf("any-hub %s (%s)", Version, Commit)
}