init
This commit is contained in:
7
internal/cache/doc.go
vendored
Normal file
7
internal/cache/doc.go
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
// Package cache defines the disk-backed store responsible for translating hub
|
||||
// requests into StoragePath/<hub>/<path> files. The store exposes read/write
|
||||
// primitives with safe semantics (temp file + rename) and surfaces file info
|
||||
// (size, modtime) for higher layers to implement conditional revalidation.
|
||||
// Proxy handlers depend on this package to stream cached responses or trigger
|
||||
// upstream fetches without duplicating filesystem logic.
|
||||
package cache
|
||||
240
internal/cache/fs_store.go
vendored
Normal file
240
internal/cache/fs_store.go
vendored
Normal file
@@ -0,0 +1,240 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewStore 以 basePath 为根目录构建磁盘缓存,整站复用一份实例。
|
||||
func NewStore(basePath string) (Store, error) {
|
||||
if basePath == "" {
|
||||
return nil, errors.New("storage path required")
|
||||
}
|
||||
|
||||
abs, err := filepath.Abs(basePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve storage path: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(abs, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create storage path: %w", err)
|
||||
}
|
||||
|
||||
return &fileStore{
|
||||
basePath: abs,
|
||||
locks: make(map[string]*entryLock),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fileStore 通过 entryLock 避免同一 Locator 并发写入,同时复用 basePath。
|
||||
type fileStore struct {
|
||||
basePath string
|
||||
|
||||
mu sync.Mutex
|
||||
locks map[string]*entryLock
|
||||
}
|
||||
|
||||
type entryLock struct {
|
||||
mu sync.Mutex
|
||||
refs int
|
||||
}
|
||||
|
||||
func (s *fileStore) Get(ctx context.Context, locator Locator) (*ReadResult, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
filePath, err := s.path(locator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entry := Entry{
|
||||
Locator: locator,
|
||||
FilePath: filePath,
|
||||
SizeBytes: info.Size(),
|
||||
ModTime: info.ModTime(),
|
||||
}
|
||||
|
||||
return &ReadResult{
|
||||
Entry: entry,
|
||||
Reader: f,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *fileStore) Put(ctx context.Context, locator Locator, body io.Reader, opts PutOptions) (*Entry, error) {
|
||||
unlock, err := s.lockEntry(locator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer unlock()
|
||||
|
||||
filePath, err := s.path(locator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(filepath.Dir(filePath), ".cache-*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tempName := tempFile.Name()
|
||||
|
||||
written, err := copyWithContext(ctx, tempFile, body)
|
||||
closeErr := tempFile.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
if err != nil {
|
||||
os.Remove(tempName)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.Rename(tempName, filePath); err != nil {
|
||||
os.Remove(tempName)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modTime := opts.ModTime
|
||||
if modTime.IsZero() {
|
||||
modTime = time.Now().UTC()
|
||||
}
|
||||
if err := os.Chtimes(filePath, modTime, modTime); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entry := Entry{
|
||||
Locator: locator,
|
||||
FilePath: filePath,
|
||||
SizeBytes: written,
|
||||
ModTime: modTime,
|
||||
}
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (s *fileStore) Remove(ctx context.Context, locator Locator) error {
|
||||
unlock, err := s.lockEntry(locator)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unlock()
|
||||
|
||||
filePath, err := s.path(locator)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(filePath); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fileStore) lockEntry(locator Locator) (func(), error) {
|
||||
key := locatorKey(locator)
|
||||
s.mu.Lock()
|
||||
lock := s.locks[key]
|
||||
if lock == nil {
|
||||
lock = &entryLock{}
|
||||
s.locks[key] = lock
|
||||
}
|
||||
lock.refs++
|
||||
s.mu.Unlock()
|
||||
|
||||
lock.mu.Lock()
|
||||
return func() {
|
||||
lock.mu.Unlock()
|
||||
s.mu.Lock()
|
||||
lock.refs--
|
||||
if lock.refs == 0 {
|
||||
delete(s.locks, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *fileStore) path(locator Locator) (string, error) {
|
||||
if locator.HubName == "" {
|
||||
return "", errors.New("hub name required")
|
||||
}
|
||||
|
||||
rel := locator.Path
|
||||
if rel == "" || rel == "/" {
|
||||
rel = "root"
|
||||
}
|
||||
rel = path.Clean("/" + rel)
|
||||
rel = strings.TrimPrefix(rel, "/")
|
||||
if rel == "" {
|
||||
rel = "root"
|
||||
}
|
||||
|
||||
filePath := filepath.Join(s.basePath, locator.HubName, filepath.FromSlash(rel))
|
||||
if !strings.HasPrefix(filePath, filepath.Join(s.basePath, locator.HubName)) {
|
||||
return "", errors.New("invalid cache path")
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||
var copied int64
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return copied, err
|
||||
}
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
w, wErr := dst.Write(buf[:n])
|
||||
copied += int64(w)
|
||||
if wErr != nil {
|
||||
return copied, wErr
|
||||
}
|
||||
if w < n {
|
||||
return copied, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return copied, nil
|
||||
}
|
||||
return copied, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func locatorKey(locator Locator) string {
|
||||
return locator.HubName + "::" + locator.Path
|
||||
}
|
||||
53
internal/cache/store.go
vendored
Normal file
53
internal/cache/store.go
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Store 负责管理磁盘缓存的读写。磁盘布局遵循:
|
||||
//
|
||||
// <StoragePath>/<HubName>/<path> # 实际正文
|
||||
//
|
||||
// 每个条目仅由正文文件组成,文件的 ModTime/Size 由文件系统提供。
|
||||
type Store interface {
|
||||
// Get 返回一个可流式读取的缓存条目。若不存在则返回 ErrNotFound。
|
||||
Get(ctx context.Context, locator Locator) (*ReadResult, error)
|
||||
|
||||
// Put 将上游响应写入缓存,并产出新的 Entry 描述。实现需通过临时文件 + rename
|
||||
// 保证写入原子性,并在失败时清理临时文件。可选地根据 opts.ModTime 设置文件时间戳。
|
||||
Put(ctx context.Context, locator Locator, body io.Reader, opts PutOptions) (*Entry, error)
|
||||
|
||||
// Remove 删除正文文件,通常用于上游错误或复合策略清理。
|
||||
Remove(ctx context.Context, locator Locator) error
|
||||
}
|
||||
|
||||
// PutOptions 控制写入过程中的可选属性。
|
||||
type PutOptions struct {
|
||||
ModTime time.Time
|
||||
}
|
||||
|
||||
// Locator 唯一定位一个缓存条目(Hub + 相对路径),所有路径均为 URL 路径风格。
|
||||
type Locator struct {
|
||||
HubName string
|
||||
Path string
|
||||
}
|
||||
|
||||
// Entry 表示一次缓存命中结果,包含绝对文件路径及文件信息。
|
||||
type Entry struct {
|
||||
Locator Locator `json:"locator"`
|
||||
FilePath string `json:"file_path"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
ModTime time.Time
|
||||
}
|
||||
|
||||
// ReadResult 组合 Entry 与正文 Reader,便于代理层直接将 Body 流式返回。
|
||||
type ReadResult struct {
|
||||
Entry Entry
|
||||
Reader io.ReadSeekCloser
|
||||
}
|
||||
|
||||
// ErrNotFound 表示缓存不存在。
|
||||
var ErrNotFound = errors.New("cache entry not found")
|
||||
95
internal/cache/store_test.go
vendored
Normal file
95
internal/cache/store_test.go
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStorePutAndGet(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
locator := Locator{HubName: "docker", Path: "/v2/library/sample/manifests/latest"}
|
||||
|
||||
modTime := time.Now().Add(-time.Hour).UTC()
|
||||
payload := []byte("payload")
|
||||
if _, err := store.Put(context.Background(), locator, bytes.NewReader(payload), PutOptions{ModTime: modTime}); err != nil {
|
||||
t.Fatalf("put error: %v", err)
|
||||
}
|
||||
|
||||
result, err := store.Get(context.Background(), locator)
|
||||
if err != nil {
|
||||
t.Fatalf("get error: %v", err)
|
||||
}
|
||||
defer result.Reader.Close()
|
||||
|
||||
body, err := io.ReadAll(result.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("read cached body error: %v", err)
|
||||
}
|
||||
if string(body) != string(payload) {
|
||||
t.Fatalf("cached payload mismatch: %s", string(body))
|
||||
}
|
||||
if result.Entry.SizeBytes != int64(len(payload)) {
|
||||
t.Fatalf("size mismatch: %d", result.Entry.SizeBytes)
|
||||
}
|
||||
if !result.Entry.ModTime.Equal(modTime) {
|
||||
t.Fatalf("modtime mismatch: expected %v got %v", modTime, result.Entry.ModTime)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreGetMissing(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
_, err := store.Get(context.Background(), Locator{HubName: "docker", Path: "/missing"})
|
||||
if err == nil || err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRemove(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
locator := Locator{HubName: "docker", Path: "/cache/remove"}
|
||||
if _, err := store.Put(context.Background(), locator, bytes.NewReader([]byte("data")), PutOptions{}); err != nil {
|
||||
t.Fatalf("put error: %v", err)
|
||||
}
|
||||
if err := store.Remove(context.Background(), locator); err != nil {
|
||||
t.Fatalf("remove error: %v", err)
|
||||
}
|
||||
if _, err := store.Get(context.Background(), locator); err == nil || err != ErrNotFound {
|
||||
t.Fatalf("expected not found after remove, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreIgnoresDirectories(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
locator := Locator{HubName: "ghcr", Path: "/v2"}
|
||||
|
||||
fs, ok := store.(*fileStore)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected store type %T", store)
|
||||
}
|
||||
|
||||
filePath, err := fs.path(locator)
|
||||
if err != nil {
|
||||
t.Fatalf("path error: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(filePath, 0o755); err != nil {
|
||||
t.Fatalf("mkdir error: %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.Get(context.Background(), locator); err == nil || err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound for directory, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// newTestStore returns a Store backed by a temporary directory.
|
||||
func newTestStore(t *testing.T) Store {
|
||||
t.Helper()
|
||||
store, err := NewStore(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
109
internal/config/config_test.go
Normal file
109
internal/config/config_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadWithDefaults(t *testing.T) {
|
||||
cfgPath := testConfigPath(t, "valid.toml")
|
||||
|
||||
cfg, err := Load(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load 返回错误: %v", err)
|
||||
}
|
||||
if cfg.Global.CacheTTL.DurationValue() == 0 {
|
||||
t.Fatalf("CacheTTL 应该自动填充默认值")
|
||||
}
|
||||
if cfg.Global.StoragePath == "" {
|
||||
t.Fatalf("StoragePath 应该被保留")
|
||||
}
|
||||
if cfg.Global.ListenPort == 0 {
|
||||
t.Fatalf("ListenPort 应当被解析")
|
||||
}
|
||||
if cfg.EffectiveCacheTTL(cfg.Hubs[0]) != cfg.Global.CacheTTL.DurationValue() {
|
||||
t.Fatalf("Hub 未设置 TTL 时应退回全局 TTL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsBadHub(t *testing.T) {
|
||||
cfgPath := testConfigPath(t, "missing.toml")
|
||||
|
||||
if _, err := Load(cfgPath); err == nil {
|
||||
t.Fatalf("不合法的配置应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveCacheTTLOverrides(t *testing.T) {
|
||||
cfg := &Config{Global: GlobalConfig{CacheTTL: Duration(time.Hour)}}
|
||||
hub := HubConfig{CacheTTL: Duration(2 * time.Hour)}
|
||||
if ttl := cfg.EffectiveCacheTTL(hub); ttl != 2*time.Hour {
|
||||
t.Fatalf("覆盖 TTL 应该优先生效")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEnforcesListenPortRange(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.Global.ListenPort = 70000
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("ListenPort 超出范围应当报错")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubTypeValidation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
hubType string
|
||||
shouldErr bool
|
||||
}{
|
||||
{"docker ok", "docker", false},
|
||||
{"npm ok", "npm", false},
|
||||
{"go ok", "go", false},
|
||||
{"missing type", "", true},
|
||||
{"unsupported type", "rubygems", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.Hubs[0].Type = tc.hubType
|
||||
err := cfg.Validate()
|
||||
if tc.shouldErr && err == nil {
|
||||
t.Fatalf("expected error for type %q", tc.hubType)
|
||||
}
|
||||
if !tc.shouldErr && err != nil {
|
||||
t.Fatalf("unexpected error for type %q: %v", tc.hubType, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRequiresCredentialPairs(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.Hubs[0].Username = "foo"
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("仅提供 Username 时应报错")
|
||||
}
|
||||
}
|
||||
|
||||
func validConfig() *Config {
|
||||
return &Config{
|
||||
Global: GlobalConfig{
|
||||
ListenPort: 5000,
|
||||
StoragePath: "./data",
|
||||
CacheTTL: Duration(time.Hour),
|
||||
MaxMemoryCache: 1,
|
||||
MaxRetries: 1,
|
||||
InitialBackoff: Duration(time.Second),
|
||||
UpstreamTimeout: Duration(time.Second),
|
||||
},
|
||||
Hubs: []HubConfig{
|
||||
{
|
||||
Name: "npm",
|
||||
Domain: "npm.local",
|
||||
Type: "npm",
|
||||
Upstream: "https://registry.npmjs.org",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
26
internal/config/errors.go
Normal file
26
internal/config/errors.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package config
|
||||
|
||||
import "fmt"
|
||||
|
||||
// FieldError 提供字段路径与错误原因,便于 CLI 向用户反馈。
|
||||
type FieldError struct {
|
||||
Field string
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e FieldError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.Field, e.Reason)
|
||||
}
|
||||
|
||||
// newFieldError 创建包含字段路径与原因的 error,便于 CLI 定位。
|
||||
func newFieldError(field, reason string) error {
|
||||
return FieldError{Field: field, Reason: reason}
|
||||
}
|
||||
|
||||
// hubField 用于拼接 Hub 级字段路径,方便输出 Hub[xxx].Field 形式。
|
||||
func hubField(name, field string) string {
|
||||
if name == "" {
|
||||
return fmt.Sprintf("Hub[].%s", field)
|
||||
}
|
||||
return fmt.Sprintf("Hub[%s].%s", name, field)
|
||||
}
|
||||
149
internal/config/loader.go
Normal file
149
internal/config/loader.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Load 读取并解析 TOML 配置文件,同时注入默认值与校验逻辑。
|
||||
func Load(path string) (*Config, error) {
|
||||
if path == "" {
|
||||
path = "config.toml"
|
||||
}
|
||||
|
||||
v := viper.New()
|
||||
v.SetConfigFile(path)
|
||||
setDefaults(v)
|
||||
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("读取配置失败: %w", err)
|
||||
}
|
||||
|
||||
if err := rejectHubLevelPorts(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := v.Unmarshal(&cfg, viper.DecodeHook(durationDecodeHook())); err != nil {
|
||||
return nil, fmt.Errorf("解析配置失败: %w", err)
|
||||
}
|
||||
|
||||
applyGlobalDefaults(&cfg.Global)
|
||||
for i := range cfg.Hubs {
|
||||
applyHubDefaults(&cfg.Hubs[i])
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
absStorage, err := filepath.Abs(cfg.Global.StoragePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无法解析缓存目录: %w", err)
|
||||
}
|
||||
cfg.Global.StoragePath = absStorage
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func setDefaults(v *viper.Viper) {
|
||||
v.SetDefault("ListenPort", 5000)
|
||||
v.SetDefault("LogLevel", "info")
|
||||
v.SetDefault("LogFilePath", "")
|
||||
v.SetDefault("LogMaxSize", 100)
|
||||
v.SetDefault("LogMaxBackups", 10)
|
||||
v.SetDefault("LogCompress", true)
|
||||
v.SetDefault("StoragePath", "./storage")
|
||||
v.SetDefault("CacheTTL", 86400)
|
||||
v.SetDefault("MaxMemoryCacheSize", 256*1024*1024)
|
||||
v.SetDefault("MaxRetries", 3)
|
||||
v.SetDefault("InitialBackoff", "1s")
|
||||
v.SetDefault("UpstreamTimeout", "30s")
|
||||
}
|
||||
|
||||
func applyGlobalDefaults(g *GlobalConfig) {
|
||||
if g.ListenPort == 0 {
|
||||
g.ListenPort = 5000
|
||||
}
|
||||
if g.CacheTTL.DurationValue() == 0 {
|
||||
g.CacheTTL = Duration(24 * time.Hour)
|
||||
}
|
||||
if g.InitialBackoff.DurationValue() == 0 {
|
||||
g.InitialBackoff = Duration(time.Second)
|
||||
}
|
||||
if g.UpstreamTimeout.DurationValue() == 0 {
|
||||
g.UpstreamTimeout = Duration(30 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func applyHubDefaults(h *HubConfig) {
|
||||
if h.CacheTTL.DurationValue() < 0 {
|
||||
h.CacheTTL = Duration(0)
|
||||
}
|
||||
}
|
||||
|
||||
func durationDecodeHook() mapstructure.DecodeHookFunc {
|
||||
targetType := reflect.TypeOf(Duration(0))
|
||||
|
||||
return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) {
|
||||
if to != targetType {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return Duration(0), nil
|
||||
}
|
||||
if parsed, err := time.ParseDuration(v); err == nil {
|
||||
return Duration(parsed), nil
|
||||
}
|
||||
if seconds, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return Duration(time.Duration(seconds * float64(time.Second))), nil
|
||||
}
|
||||
return nil, fmt.Errorf("无法解析 Duration 字段: %s", v)
|
||||
case int:
|
||||
return Duration(time.Duration(v) * time.Second), nil
|
||||
case int64:
|
||||
return Duration(time.Duration(v) * time.Second), nil
|
||||
case float64:
|
||||
return Duration(time.Duration(v * float64(time.Second))), nil
|
||||
case time.Duration:
|
||||
return Duration(v), nil
|
||||
case Duration:
|
||||
return v, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的 Duration 类型: %T", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func rejectHubLevelPorts(v *viper.Viper) error {
|
||||
raw := v.Get("Hub")
|
||||
hubs, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for idx, entry := range hubs {
|
||||
m, ok := entry.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := m["Port"]; exists {
|
||||
name := fmt.Sprintf("#%d", idx)
|
||||
if rawName, ok := m["Name"].(string); ok && rawName != "" {
|
||||
name = rawName
|
||||
}
|
||||
return newFieldError(hubField(name, "Port"), "字段已弃用,请移除并使用全局 ListenPort")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
27
internal/config/loader_test.go
Normal file
27
internal/config/loader_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLoadFailsWithMissingFields(t *testing.T) {
|
||||
if _, err := Load(testConfigPath(t, "missing.toml")); err == nil {
|
||||
t.Fatalf("缺失字段的配置应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRejectsInvalidDuration(t *testing.T) {
|
||||
cfg := `
|
||||
LogLevel = "info"
|
||||
StoragePath = "./data"
|
||||
CacheTTL = "boom"
|
||||
|
||||
[[Hub]]
|
||||
Name = "docker"
|
||||
Domain = "docker.local"
|
||||
Type = "docker"
|
||||
Upstream = "https://registry-1.docker.io"
|
||||
`
|
||||
path := writeTempConfig(t, cfg)
|
||||
if _, err := Load(path); err == nil {
|
||||
t.Fatalf("无效 Duration 应失败")
|
||||
}
|
||||
}
|
||||
22
internal/config/test_helpers_test.go
Normal file
22
internal/config/test_helpers_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testConfigPath(t *testing.T, name string) string {
|
||||
t.Helper()
|
||||
return filepath.Join("testdata", name)
|
||||
}
|
||||
|
||||
func writeTempConfig(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.toml")
|
||||
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("写入临时配置失败: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
9
internal/config/testdata/missing.toml
vendored
Normal file
9
internal/config/testdata/missing.toml
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
ListenPort = 0
|
||||
LogLevel = "info"
|
||||
StoragePath = "./data"
|
||||
|
||||
[[Hub]]
|
||||
Name = "docker"
|
||||
Domain = "docker.local"
|
||||
Upstream = ""
|
||||
Type = ""
|
||||
9
internal/config/testdata/valid.toml
vendored
Normal file
9
internal/config/testdata/valid.toml
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
ListenPort = 5050
|
||||
LogLevel = "debug"
|
||||
StoragePath = "./data"
|
||||
|
||||
[[Hub]]
|
||||
Name = "docker"
|
||||
Domain = "docker.local"
|
||||
Upstream = "https://registry-1.docker.io"
|
||||
Type = "docker"
|
||||
105
internal/config/types.go
Normal file
105
internal/config/types.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Duration 提供更灵活的反序列化能力,同时兼容纯秒整数与 Go Duration 字符串。
|
||||
type Duration time.Duration
|
||||
|
||||
// UnmarshalText 使 Viper 可以识别诸如 "30s"、"5m" 或纯数字秒值等配置写法。
|
||||
func (d *Duration) UnmarshalText(text []byte) error {
|
||||
raw := strings.TrimSpace(string(text))
|
||||
if raw == "" {
|
||||
*d = Duration(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
if seconds, err := time.ParseDuration(raw); err == nil {
|
||||
*d = Duration(seconds)
|
||||
return nil
|
||||
}
|
||||
|
||||
if intVal, err := parseInt(raw); err == nil {
|
||||
*d = Duration(time.Duration(intVal) * time.Second)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid duration value: %s", raw)
|
||||
}
|
||||
|
||||
// DurationValue 返回真实的 time.Duration,便于调用方计算。
|
||||
func (d Duration) DurationValue() time.Duration {
|
||||
return time.Duration(d)
|
||||
}
|
||||
|
||||
// parseInt 支持十进制或 0x 前缀的十六进制字符串解析。
|
||||
func parseInt(value string) (int64, error) {
|
||||
if strings.HasPrefix(value, "0x") || strings.HasPrefix(value, "0X") {
|
||||
return strconv.ParseInt(value, 0, 64)
|
||||
}
|
||||
return strconv.ParseInt(value, 10, 64)
|
||||
}
|
||||
|
||||
// GlobalConfig 描述全局运行时行为,所有 Hub 共享同一份参数。
|
||||
type GlobalConfig struct {
|
||||
ListenPort int `mapstructure:"ListenPort"`
|
||||
LogLevel string `mapstructure:"LogLevel"`
|
||||
LogFilePath string `mapstructure:"LogFilePath"`
|
||||
LogMaxSize int `mapstructure:"LogMaxSize"`
|
||||
LogMaxBackups int `mapstructure:"LogMaxBackups"`
|
||||
LogCompress bool `mapstructure:"LogCompress"`
|
||||
StoragePath string `mapstructure:"StoragePath"`
|
||||
CacheTTL Duration `mapstructure:"CacheTTL"`
|
||||
MaxMemoryCache int64 `mapstructure:"MaxMemoryCacheSize"`
|
||||
MaxRetries int `mapstructure:"MaxRetries"`
|
||||
InitialBackoff Duration `mapstructure:"InitialBackoff"`
|
||||
UpstreamTimeout Duration `mapstructure:"UpstreamTimeout"`
|
||||
}
|
||||
|
||||
// HubConfig 决定单个代理实例如何与下游/上游交互。
|
||||
type HubConfig struct {
|
||||
Name string `mapstructure:"Name"`
|
||||
Domain string `mapstructure:"Domain"`
|
||||
Upstream string `mapstructure:"Upstream"`
|
||||
Proxy string `mapstructure:"Proxy"`
|
||||
Type string `mapstructure:"Type"`
|
||||
Username string `mapstructure:"Username"`
|
||||
Password string `mapstructure:"Password"`
|
||||
CacheTTL Duration `mapstructure:"CacheTTL"`
|
||||
EnableHeadCheck bool `mapstructure:"EnableHeadCheck"`
|
||||
}
|
||||
|
||||
// Config 是 TOML 文件映射的整体结构。
|
||||
type Config struct {
|
||||
Global GlobalConfig `mapstructure:",squash"`
|
||||
Hubs []HubConfig `mapstructure:"Hub"`
|
||||
}
|
||||
|
||||
// HasCredentials 表示当前 Hub 是否配置了完整的上游凭证。
|
||||
func (h HubConfig) HasCredentials() bool {
|
||||
return h.Username != "" && h.Password != ""
|
||||
}
|
||||
|
||||
// AuthMode 输出 `credentialed` 或 `anonymous`,供日志字段使用。
|
||||
func (h HubConfig) AuthMode() string {
|
||||
if h.HasCredentials() {
|
||||
return "credentialed"
|
||||
}
|
||||
return "anonymous"
|
||||
}
|
||||
|
||||
// CredentialModes 返回所有 Hub 的鉴权模式摘要,例如 secure:credentialed。
|
||||
func CredentialModes(hubs []HubConfig) []string {
|
||||
if len(hubs) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, len(hubs))
|
||||
for i, hub := range hubs {
|
||||
result[i] = fmt.Sprintf("%s:%s", hub.Name, hub.AuthMode())
|
||||
}
|
||||
return result
|
||||
}
|
||||
131
internal/config/validation.go
Normal file
131
internal/config/validation.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var supportedHubTypes = map[string]struct{}{
|
||||
"docker": {},
|
||||
"npm": {},
|
||||
"go": {},
|
||||
}
|
||||
|
||||
const supportedHubTypeList = "docker|npm|go"
|
||||
|
||||
// Validate 针对语义级别做进一步校验,防止非法配置启动服务。
|
||||
func (c *Config) Validate() error {
|
||||
if c == nil {
|
||||
return errors.New("配置为空")
|
||||
}
|
||||
|
||||
g := c.Global
|
||||
if g.ListenPort <= 0 || g.ListenPort > 65535 {
|
||||
return newFieldError("Global.ListenPort", "必须在 1-65535")
|
||||
}
|
||||
if g.StoragePath == "" {
|
||||
return newFieldError("Global.StoragePath", "不能为空")
|
||||
}
|
||||
if g.CacheTTL.DurationValue() <= 0 {
|
||||
return newFieldError("Global.CacheTTL", "必须大于 0")
|
||||
}
|
||||
if g.MaxMemoryCache <= 0 {
|
||||
return newFieldError("Global.MaxMemoryCacheSize", "必须大于 0")
|
||||
}
|
||||
if g.MaxRetries < 0 {
|
||||
return newFieldError("Global.MaxRetries", "不能为负数")
|
||||
}
|
||||
if g.InitialBackoff.DurationValue() <= 0 {
|
||||
return newFieldError("Global.InitialBackoff", "必须大于 0")
|
||||
}
|
||||
if g.UpstreamTimeout.DurationValue() <= 0 {
|
||||
return newFieldError("Global.UpstreamTimeout", "必须大于 0")
|
||||
}
|
||||
|
||||
if len(c.Hubs) == 0 {
|
||||
return errors.New("至少需要配置一个 Hub")
|
||||
}
|
||||
|
||||
seenNames := map[string]struct{}{}
|
||||
for i := range c.Hubs {
|
||||
hub := &c.Hubs[i]
|
||||
if hub.Name == "" {
|
||||
return newFieldError("Hub[].Name", "不能为空")
|
||||
}
|
||||
if _, exists := seenNames[hub.Name]; exists {
|
||||
return newFieldError(hubField(hub.Name, "Name"), "重复")
|
||||
}
|
||||
seenNames[hub.Name] = struct{}{}
|
||||
|
||||
if err := validateDomain(hub.Domain); err != nil {
|
||||
return fmt.Errorf("%s: %w", hubField(hub.Name, "Domain"), err)
|
||||
}
|
||||
|
||||
normalizedType := strings.ToLower(strings.TrimSpace(hub.Type))
|
||||
if normalizedType == "" {
|
||||
return newFieldError(hubField(hub.Name, "Type"), "不能为空")
|
||||
}
|
||||
if _, ok := supportedHubTypes[normalizedType]; !ok {
|
||||
return newFieldError(hubField(hub.Name, "Type"), "仅支持 "+supportedHubTypeList)
|
||||
}
|
||||
hub.Type = normalizedType
|
||||
|
||||
if (hub.Username == "") != (hub.Password == "") {
|
||||
return newFieldError(hubField(hub.Name, "Username/Password"), "必须同时提供或同时留空")
|
||||
}
|
||||
if err := validateUpstream(hub.Upstream); err != nil {
|
||||
return fmt.Errorf("%s: %w", hubField(hub.Name, "Upstream"), err)
|
||||
}
|
||||
if hub.Proxy != "" {
|
||||
if err := validateUpstream(hub.Proxy); err != nil {
|
||||
return fmt.Errorf("%s: %w", hubField(hub.Name, "Proxy"), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDomain(domain string) error {
|
||||
if domain == "" {
|
||||
return errors.New("Domain 不能为空")
|
||||
}
|
||||
if strings.Contains(domain, "/") {
|
||||
return errors.New("Domain 不允许包含路径")
|
||||
}
|
||||
if strings.Contains(domain, " ") {
|
||||
return errors.New("Domain 不允许包含空格")
|
||||
}
|
||||
if strings.HasPrefix(domain, "http") {
|
||||
return errors.New("Domain 不应包含协议头")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateUpstream(raw string) error {
|
||||
if raw == "" {
|
||||
return errors.New("缺少上游地址")
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return fmt.Errorf("仅支持 http/https,上游: %s", raw)
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
return fmt.Errorf("上游缺少 Host: %s", raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EffectiveCacheTTL 返回特定 Hub 生效的 TTL,未覆盖时回退至全局值。
|
||||
func (c *Config) EffectiveCacheTTL(h HubConfig) time.Duration {
|
||||
if h.CacheTTL.DurationValue() > 0 {
|
||||
return h.CacheTTL.DurationValue()
|
||||
}
|
||||
return c.Global.CacheTTL.DurationValue()
|
||||
}
|
||||
22
internal/logging/fields.go
Normal file
22
internal/logging/fields.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package logging
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
|
||||
// BaseFields 构建 action + 配置路径等基础字段,便于不同入口复用。
|
||||
func BaseFields(action, configPath string) logrus.Fields {
|
||||
return logrus.Fields{
|
||||
"action": action,
|
||||
"configPath": configPath,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestFields 提供 hub/domain/命中状态字段,供代理请求日志复用。
|
||||
func RequestFields(hub, domain, hubType, authMode string, cacheHit bool) logrus.Fields {
|
||||
return logrus.Fields{
|
||||
"hub": hub,
|
||||
"domain": domain,
|
||||
"hub_type": hubType,
|
||||
"auth_mode": authMode,
|
||||
"cache_hit": cacheHit,
|
||||
}
|
||||
}
|
||||
66
internal/logging/logger.go
Normal file
66
internal/logging/logger.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
)
|
||||
|
||||
// InitLogger 根据全局配置初始化 JSON 结构化日志,确保文件/控制台输出一致。
|
||||
func InitLogger(cfg config.GlobalConfig) (*logrus.Logger, error) {
|
||||
level, err := logrus.ParseLevel(cfg.LogLevel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无法解析日志级别: %w", err)
|
||||
}
|
||||
|
||||
output, outErr := buildOutput(cfg)
|
||||
if outErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "logger_fallback: %v\n", outErr)
|
||||
}
|
||||
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(level)
|
||||
logger.SetOutput(output)
|
||||
logger.SetFormatter(&logrus.JSONFormatter{TimestampFormat: time.RFC3339Nano})
|
||||
|
||||
logrus.SetFormatter(logger.Formatter)
|
||||
logrus.SetOutput(logger.Out)
|
||||
logrus.SetLevel(logger.GetLevel())
|
||||
|
||||
if outErr != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"action": "logger_fallback",
|
||||
"path": cfg.LogFilePath,
|
||||
}).Warn(outErr.Error())
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// buildOutput 根据配置创建日志输出 Writer;失败时降级到 stdout 并返回错误。
|
||||
func buildOutput(cfg config.GlobalConfig) (io.Writer, error) {
|
||||
if cfg.LogFilePath == "" {
|
||||
return os.Stdout, nil
|
||||
}
|
||||
|
||||
dir := filepath.Dir(cfg.LogFilePath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return os.Stdout, fmt.Errorf("创建日志目录失败: %w", err)
|
||||
}
|
||||
|
||||
rotator := &lumberjack.Logger{
|
||||
Filename: cfg.LogFilePath,
|
||||
MaxSize: cfg.LogMaxSize,
|
||||
MaxBackups: cfg.LogMaxBackups,
|
||||
Compress: cfg.LogCompress,
|
||||
LocalTime: true,
|
||||
}
|
||||
return rotator, nil
|
||||
}
|
||||
57
internal/logging/logger_test.go
Normal file
57
internal/logging/logger_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
)
|
||||
|
||||
func TestConfigureDefaultsToStdout(t *testing.T) {
|
||||
logger, err := InitLogger(config.GlobalConfig{LogLevel: "info"})
|
||||
if err != nil {
|
||||
t.Fatalf("配置失败: %v", err)
|
||||
}
|
||||
if logger.Out != os.Stdout {
|
||||
t.Fatalf("未指定文件时应输出到 stdout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitLoggerFallbackOnPermissionDenied(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
blocked := filepath.Join(dir, "blocked")
|
||||
if err := os.Mkdir(blocked, 0o755); err != nil {
|
||||
t.Fatalf("创建目录失败: %v", err)
|
||||
}
|
||||
if err := os.Chmod(blocked, 0o000); err != nil {
|
||||
t.Fatalf("设置目录权限失败: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Chmod(blocked, 0o755) })
|
||||
|
||||
cfg := config.GlobalConfig{
|
||||
LogLevel: "info",
|
||||
LogFilePath: filepath.Join(blocked, "sub", "any-hub.log"),
|
||||
}
|
||||
logger, err := InitLogger(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化不应失败: %v", err)
|
||||
}
|
||||
if logger.Out != os.Stdout {
|
||||
t.Fatalf("fallback 时应退回 stdout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureCreatesRotatingFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "any-hub.log")
|
||||
cfg := config.GlobalConfig{LogLevel: "debug", LogFilePath: path}
|
||||
logger, err := InitLogger(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("配置失败: %v", err)
|
||||
}
|
||||
logger.Info("test")
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatalf("预期创建日志文件: %v", err)
|
||||
}
|
||||
}
|
||||
68
internal/proxy/docker_path_test.go
Normal file
68
internal/proxy/docker_path_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
func TestApplyDockerHubNamespaceFallback(t *testing.T) {
|
||||
route := dockerHubRoute(t, "https://registry-1.docker.io")
|
||||
|
||||
path, changed := applyDockerHubNamespaceFallback(route, "/v2/nginx/manifests/latest")
|
||||
if !changed {
|
||||
t.Fatalf("expected fallback to apply")
|
||||
}
|
||||
if path != "/v2/library/nginx/manifests/latest" {
|
||||
t.Fatalf("unexpected normalized path: %s", path)
|
||||
}
|
||||
|
||||
path, changed = applyDockerHubNamespaceFallback(route, "/v2/library/nginx/manifests/latest")
|
||||
if changed {
|
||||
t.Fatalf("expected no changes for already-namespaced repo")
|
||||
}
|
||||
|
||||
path, changed = applyDockerHubNamespaceFallback(route, "/v2/rogee/nginx/manifests/latest")
|
||||
if changed {
|
||||
t.Fatalf("expected no changes for custom namespace")
|
||||
}
|
||||
|
||||
path, changed = applyDockerHubNamespaceFallback(route, "/v2/_catalog")
|
||||
if changed {
|
||||
t.Fatalf("expected no changes for _catalog endpoint")
|
||||
}
|
||||
|
||||
otherRoute := dockerHubRoute(t, "https://registry.example.com")
|
||||
path, changed = applyDockerHubNamespaceFallback(otherRoute, "/v2/nginx/manifests/latest")
|
||||
if changed || path != "/v2/nginx/manifests/latest" {
|
||||
t.Fatalf("expected no changes for non-docker-hub upstream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitDockerRepoPath(t *testing.T) {
|
||||
repo, rest, ok := splitDockerRepoPath("/v2/library/nginx/manifests/latest")
|
||||
if !ok || repo != "library/nginx" || rest != "/manifests/latest" {
|
||||
t.Fatalf("unexpected split result repo=%s rest=%s ok=%v", repo, rest, ok)
|
||||
}
|
||||
|
||||
if _, _, ok := splitDockerRepoPath("/v2/_catalog"); ok {
|
||||
t.Fatalf("expected catalog path to be ignored")
|
||||
}
|
||||
}
|
||||
|
||||
func dockerHubRoute(t *testing.T, upstream string) *server.HubRoute {
|
||||
t.Helper()
|
||||
parsed, err := url.Parse(upstream)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid upstream: %v", err)
|
||||
}
|
||||
return &server.HubRoute{
|
||||
Config: config.HubConfig{
|
||||
Name: "docker",
|
||||
Type: "docker",
|
||||
},
|
||||
UpstreamURL: parsed,
|
||||
}
|
||||
}
|
||||
772
internal/proxy/handler.go
Normal file
772
internal/proxy/handler.go
Normal file
@@ -0,0 +1,772 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/cache"
|
||||
"github.com/any-hub/any-hub/internal/logging"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
// Handler 负责 orchestrate “缓存命中 → revalidate → 回源写缓存” 的全流程,
|
||||
// 对外暴露 Fiber handler,内部复用共享 http.Client 与磁盘缓存。
|
||||
type Handler struct {
|
||||
client *http.Client
|
||||
logger *logrus.Logger
|
||||
store cache.Store
|
||||
}
|
||||
|
||||
// NewHandler constructs a proxy handler with shared HTTP client/logger/store.
|
||||
func NewHandler(client *http.Client, logger *logrus.Logger, store cache.Store) *Handler {
|
||||
return &Handler{
|
||||
client: client,
|
||||
logger: logger,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 执行缓存查找、条件回源和最终 streaming 逻辑,任何阶段出错都会输出结构化日志。
|
||||
func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
started := time.Now()
|
||||
requestID := server.RequestID(c)
|
||||
locator := buildLocator(route, c)
|
||||
policy := determineCachePolicy(route, locator, c.Method())
|
||||
|
||||
if err := ensureProxyHubType(route); err != nil {
|
||||
h.logger.WithField("hub", route.Config.Name).WithError(err).Error("hub_type_unsupported")
|
||||
return h.writeError(c, fiber.StatusNotImplemented, "hub_type_unsupported")
|
||||
}
|
||||
|
||||
ctx := c.Context()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
var cached *cache.ReadResult
|
||||
if h.store != nil && policy.allowCache {
|
||||
result, err := h.store.Get(ctx, locator)
|
||||
switch {
|
||||
case err == nil:
|
||||
cached = result
|
||||
case errors.Is(err, cache.ErrNotFound):
|
||||
// miss, continue
|
||||
default:
|
||||
h.logger.WithError(err).WithField("hub", route.Config.Name).Warn("cache_get_failed")
|
||||
}
|
||||
}
|
||||
|
||||
if cached != nil {
|
||||
serve := true
|
||||
if policy.requireRevalidate {
|
||||
fresh, err := h.isCacheFresh(c, route, locator, cached.Entry)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).WithField("hub", route.Config.Name).Warn("cache_revalidate_failed")
|
||||
serve = false
|
||||
} else if !fresh {
|
||||
serve = false
|
||||
}
|
||||
}
|
||||
if serve {
|
||||
defer cached.Reader.Close()
|
||||
return h.serveCache(c, route, cached, requestID, started)
|
||||
}
|
||||
cached.Reader.Close()
|
||||
}
|
||||
|
||||
return h.fetchAndStream(c, route, locator, policy, requestID, started, ctx)
|
||||
}
|
||||
|
||||
func (h *Handler) serveCache(c fiber.Ctx, route *server.HubRoute, result *cache.ReadResult, requestID string, started time.Time) error {
|
||||
if seeker, ok := result.Reader.(io.Seeker); ok {
|
||||
_, _ = seeker.Seek(0, io.SeekStart)
|
||||
}
|
||||
|
||||
method := c.Method()
|
||||
|
||||
contentType := inferCachedContentType(route, result.Entry.Locator)
|
||||
if contentType != "" {
|
||||
c.Set("Content-Type", contentType)
|
||||
} else {
|
||||
c.Response().Header.Del("Content-Type")
|
||||
}
|
||||
|
||||
length := result.Entry.SizeBytes
|
||||
if length > 0 {
|
||||
c.Response().Header.SetContentLength(int(length))
|
||||
} else {
|
||||
c.Response().Header.Del("Content-Length")
|
||||
}
|
||||
|
||||
c.Set("X-Any-Hub-Upstream", route.UpstreamURL.String())
|
||||
c.Set("X-Any-Hub-Cache-Hit", "true")
|
||||
if requestID != "" {
|
||||
c.Set("X-Request-ID", requestID)
|
||||
}
|
||||
|
||||
status := fiber.StatusOK
|
||||
c.Status(status)
|
||||
|
||||
if method == http.MethodHead {
|
||||
result.Reader.Close()
|
||||
h.logResult(route, route.UpstreamURL.String(), requestID, status, true, started, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := io.Copy(c.Response().BodyWriter(), result.Reader)
|
||||
result.Reader.Close()
|
||||
h.logResult(route, route.UpstreamURL.String(), requestID, status, true, started, err)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("read cache failed: %v", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) fetchAndStream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, policy cachePolicy, requestID string, started time.Time, ctx context.Context) error {
|
||||
resp, upstreamURL, err := h.executeRequest(c, route)
|
||||
if err != nil {
|
||||
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
|
||||
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
|
||||
}
|
||||
|
||||
resp, upstreamURL, err = h.retryOnAuthFailure(c, route, requestID, started, resp, upstreamURL)
|
||||
if err != nil {
|
||||
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
|
||||
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
shouldStore := policy.allowStore && h.store != nil && isCacheableStatus(resp.StatusCode) && c.Method() == http.MethodGet
|
||||
return h.consumeUpstream(c, route, locator, resp, shouldStore, requestID, started, ctx)
|
||||
}
|
||||
|
||||
func (h *Handler) consumeUpstream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, resp *http.Response, shouldStore bool, requestID string, started time.Time, ctx context.Context) error {
|
||||
upstreamURL := resp.Request.URL.String()
|
||||
method := c.Method()
|
||||
authFailure := isAuthFailure(resp.StatusCode) && route.Config.HasCredentials()
|
||||
|
||||
if shouldStore {
|
||||
return h.cacheAndStream(c, route, locator, resp, requestID, started, ctx, upstreamURL)
|
||||
}
|
||||
|
||||
copyResponseHeaders(c, resp.Header)
|
||||
c.Set("X-Any-Hub-Upstream", upstreamURL)
|
||||
c.Set("X-Any-Hub-Cache-Hit", "false")
|
||||
if requestID != "" {
|
||||
c.Set("X-Request-ID", requestID)
|
||||
}
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
if authFailure {
|
||||
h.logAuthFailure(route, upstreamURL, requestID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if method == http.MethodHead {
|
||||
h.logResult(route, upstreamURL, requestID, resp.StatusCode, false, started, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := io.Copy(c.Response().BodyWriter(), resp.Body)
|
||||
h.logResult(route, upstreamURL, requestID, resp.StatusCode, false, started, err)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("proxy stream failed: %v", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) cacheAndStream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, resp *http.Response, requestID string, started time.Time, ctx context.Context, upstreamURL string) error {
|
||||
copyResponseHeaders(c, resp.Header)
|
||||
c.Set("X-Any-Hub-Upstream", upstreamURL)
|
||||
c.Set("X-Any-Hub-Cache-Hit", "false")
|
||||
if requestID != "" {
|
||||
c.Set("X-Request-ID", requestID)
|
||||
}
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
reader := io.TeeReader(resp.Body, c.Response().BodyWriter())
|
||||
|
||||
opts := cache.PutOptions{ModTime: extractModTime(resp.Header)}
|
||||
entry, err := h.store.Put(ctx, locator, reader, opts)
|
||||
h.logResult(route, upstreamURL, requestID, resp.StatusCode, false, started, err)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("cache_write_failed: %v", err))
|
||||
}
|
||||
_ = entry
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) retryOnAuthFailure(c fiber.Ctx, route *server.HubRoute, requestID string, started time.Time, resp *http.Response, upstreamURL *url.URL) (*http.Response, *url.URL, error) {
|
||||
if !shouldRetryAuth(route, resp.StatusCode) {
|
||||
return resp, upstreamURL, nil
|
||||
}
|
||||
|
||||
challenge, ok := parseBearerChallenge(resp.Header.Values("Www-Authenticate"))
|
||||
h.logAuthRetry(route, upstreamURL.String(), requestID, resp.StatusCode)
|
||||
resp.Body.Close()
|
||||
|
||||
if ok {
|
||||
ctx := c.Context()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
token, err := h.fetchBearerToken(ctx, challenge, route)
|
||||
if err != nil {
|
||||
return nil, upstreamURL, err
|
||||
}
|
||||
authHeader := "Bearer " + token
|
||||
retryResp, retryURL, err := h.executeRequestWithAuth(c, route, authHeader)
|
||||
if err != nil {
|
||||
return nil, upstreamURL, err
|
||||
}
|
||||
return retryResp, retryURL, nil
|
||||
}
|
||||
|
||||
retryResp, retryURL, err := h.executeRequest(c, route)
|
||||
if err != nil {
|
||||
return nil, upstreamURL, err
|
||||
}
|
||||
return retryResp, retryURL, nil
|
||||
}
|
||||
|
||||
func (h *Handler) executeRequest(c fiber.Ctx, route *server.HubRoute) (*http.Response, *url.URL, error) {
|
||||
return h.executeRequestWithAuth(c, route, "")
|
||||
}
|
||||
|
||||
func (h *Handler) executeRequestWithAuth(c fiber.Ctx, route *server.HubRoute, authHeader string) (*http.Response, *url.URL, error) {
|
||||
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
|
||||
body := bytesReader(c.Body())
|
||||
req, err := h.buildUpstreamRequest(c, upstreamURL, route, c.Method(), body, authHeader)
|
||||
if err != nil {
|
||||
return nil, upstreamURL, err
|
||||
}
|
||||
|
||||
resp, err := h.doRequest(req, route)
|
||||
return resp, upstreamURL, err
|
||||
}
|
||||
|
||||
func (h *Handler) buildUpstreamRequest(c fiber.Ctx, upstream *url.URL, route *server.HubRoute, method string, body io.Reader, overrideAuth string) (*http.Request, error) {
|
||||
ctx := c.Context()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if body == nil {
|
||||
body = http.NoBody
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, upstream.String(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestHeaders := fiberHeadersAsHTTP(c)
|
||||
server.CopyHeaders(req.Header, requestHeaders)
|
||||
req.Host = upstream.Host
|
||||
req.Header.Set("Host", upstream.Host)
|
||||
req.Header.Set("X-Forwarded-Host", c.Hostname())
|
||||
if ip := c.IP(); ip != "" {
|
||||
if prior := req.Header.Get("X-Forwarded-For"); prior != "" {
|
||||
req.Header.Set("X-Forwarded-For", prior+", "+ip)
|
||||
} else {
|
||||
req.Header.Set("X-Forwarded-For", ip)
|
||||
}
|
||||
}
|
||||
req.Header.Set("X-Forwarded-Proto", c.Protocol())
|
||||
req.Header.Set("X-Forwarded-Port", routePort(route))
|
||||
|
||||
if overrideAuth != "" {
|
||||
req.Header.Set("Authorization", overrideAuth)
|
||||
} else if authHeader := buildCredentialHeader(route.Config.Username, route.Config.Password); authHeader != "" {
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (h *Handler) doRequest(req *http.Request, route *server.HubRoute) (*http.Response, error) {
|
||||
if route.ProxyURL == nil {
|
||||
return h.client.Do(req)
|
||||
}
|
||||
transport := http.Transport{}
|
||||
if base, ok := h.client.Transport.(*http.Transport); ok && base != nil {
|
||||
transport = *base.Clone()
|
||||
}
|
||||
transport.Proxy = http.ProxyURL(route.ProxyURL)
|
||||
client := *h.client
|
||||
client.Transport = &transport
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (h *Handler) writeError(c fiber.Ctx, status int, code string) error {
|
||||
return c.Status(status).JSON(fiber.Map{"error": code})
|
||||
}
|
||||
|
||||
func (h *Handler) logResult(route *server.HubRoute, upstream string, requestID string, status int, cacheHit bool, started time.Time, err error) {
|
||||
fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), cacheHit)
|
||||
fields["action"] = "proxy"
|
||||
fields["upstream"] = upstream
|
||||
fields["upstream_status"] = status
|
||||
fields["elapsed_ms"] = time.Since(started).Milliseconds()
|
||||
if requestID != "" {
|
||||
fields["request_id"] = requestID
|
||||
}
|
||||
if err != nil {
|
||||
fields["error"] = err.Error()
|
||||
h.logger.WithFields(fields).Error("proxy_failed")
|
||||
return
|
||||
}
|
||||
h.logger.WithFields(fields).Info("proxy_complete")
|
||||
}
|
||||
|
||||
func inferCachedContentType(route *server.HubRoute, locator cache.Locator) string {
|
||||
clean := stripQueryMarker(locator.Path)
|
||||
switch {
|
||||
case strings.HasSuffix(clean, ".zip"):
|
||||
return "application/zip"
|
||||
case strings.HasSuffix(clean, ".mod"):
|
||||
return "text/plain"
|
||||
case strings.HasSuffix(clean, ".info"):
|
||||
return "application/json"
|
||||
case strings.HasSuffix(clean, ".tgz"):
|
||||
return "application/octet-stream"
|
||||
case strings.HasSuffix(clean, "/@v/list"):
|
||||
return "text/plain"
|
||||
}
|
||||
|
||||
if route != nil {
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
if strings.Contains(clean, "/manifests/") {
|
||||
return "application/vnd.docker.distribution.manifest.v2+json"
|
||||
}
|
||||
if strings.Contains(clean, "/tags/list") {
|
||||
return "application/json"
|
||||
}
|
||||
if strings.Contains(clean, "/blobs/") {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
case "npm":
|
||||
if strings.HasSuffix(clean, ".json") {
|
||||
return "application/json"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildLocator(route *server.HubRoute, c fiber.Ctx) cache.Locator {
|
||||
uri := c.Request().URI()
|
||||
pathVal := string(uri.Path())
|
||||
if pathVal == "" {
|
||||
pathVal = "/"
|
||||
}
|
||||
clean := path.Clean("/" + pathVal)
|
||||
if newPath, ok := applyDockerHubNamespaceFallback(route, clean); ok {
|
||||
clean = newPath
|
||||
}
|
||||
query := uri.QueryString()
|
||||
if len(query) > 0 {
|
||||
sum := sha1.Sum(query)
|
||||
clean = fmt.Sprintf("%s/__qs/%s", clean, hex.EncodeToString(sum[:]))
|
||||
}
|
||||
return cache.Locator{
|
||||
HubName: route.Config.Name,
|
||||
Path: clean,
|
||||
}
|
||||
}
|
||||
|
||||
func stripQueryMarker(p string) string {
|
||||
if idx := strings.Index(p, "/__qs/"); idx >= 0 {
|
||||
return p[:idx]
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func requestPath(c fiber.Ctx) string {
|
||||
if c == nil {
|
||||
return "/"
|
||||
}
|
||||
uri := c.Request().URI()
|
||||
if uri == nil {
|
||||
return "/"
|
||||
}
|
||||
pathVal := string(uri.Path())
|
||||
if pathVal == "" {
|
||||
return "/"
|
||||
}
|
||||
return pathVal
|
||||
}
|
||||
|
||||
func bytesReader(b []byte) io.Reader {
|
||||
if len(b) == 0 {
|
||||
return http.NoBody
|
||||
}
|
||||
return bytes.NewReader(b)
|
||||
}
|
||||
|
||||
func resolveUpstreamURL(route *server.HubRoute, base *url.URL, c fiber.Ctx) *url.URL {
|
||||
uri := c.Request().URI()
|
||||
pathVal := string(uri.Path())
|
||||
relative := &url.URL{Path: pathVal, RawPath: pathVal}
|
||||
if newPath, ok := applyDockerHubNamespaceFallback(route, relative.Path); ok {
|
||||
relative.Path = newPath
|
||||
relative.RawPath = newPath
|
||||
}
|
||||
if query := string(uri.QueryString()); query != "" {
|
||||
relative.RawQuery = query
|
||||
}
|
||||
return base.ResolveReference(relative)
|
||||
}
|
||||
|
||||
func fiberHeadersAsHTTP(c fiber.Ctx) http.Header {
|
||||
header := http.Header{}
|
||||
c.Request().Header.VisitAll(func(key, value []byte) {
|
||||
header.Add(string(key), string(value))
|
||||
})
|
||||
return header
|
||||
}
|
||||
|
||||
func copyResponseHeaders(c fiber.Ctx, headers http.Header) {
|
||||
for key, values := range headers {
|
||||
if server.IsHopByHopHeader(key) {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
c.Set(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routePort(route *server.HubRoute) string {
|
||||
if route == nil || route.ListenPort <= 0 {
|
||||
return "0"
|
||||
}
|
||||
return fmt.Sprintf("%d", route.ListenPort)
|
||||
}
|
||||
|
||||
type cachePolicy struct {
|
||||
allowCache bool
|
||||
allowStore bool
|
||||
requireRevalidate bool
|
||||
}
|
||||
|
||||
func determineCachePolicy(route *server.HubRoute, locator cache.Locator, method string) cachePolicy {
|
||||
if route == nil || method != http.MethodGet {
|
||||
return cachePolicy{}
|
||||
}
|
||||
policy := cachePolicy{allowCache: true, allowStore: true}
|
||||
path := stripQueryMarker(locator.Path)
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
if path == "/v2" || path == "v2" || path == "/v2/" {
|
||||
return cachePolicy{}
|
||||
}
|
||||
if strings.Contains(path, "/_catalog") {
|
||||
return cachePolicy{}
|
||||
}
|
||||
if isDockerImmutablePath(path) {
|
||||
return policy
|
||||
}
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
case "go":
|
||||
if strings.Contains(path, "/@v/") && (strings.HasSuffix(path, ".zip") || strings.HasSuffix(path, ".mod") || strings.HasSuffix(path, ".info")) {
|
||||
return policy
|
||||
}
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
case "npm":
|
||||
if strings.Contains(path, "/-/") && strings.HasSuffix(path, ".tgz") {
|
||||
return policy
|
||||
}
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
default:
|
||||
return policy
|
||||
}
|
||||
}
|
||||
|
||||
func isDockerImmutablePath(path string) bool {
|
||||
if strings.Contains(path, "/blobs/sha256:") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(path, "/manifests/sha256:") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isCacheableStatus(status int) bool {
|
||||
return status == http.StatusOK
|
||||
}
|
||||
|
||||
func (h *Handler) isCacheFresh(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, entry cache.Entry) (bool, error) {
|
||||
ctx := c.Context()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
|
||||
req, err := h.buildUpstreamRequest(c, upstreamURL, route, http.MethodHead, http.NoBody, "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
resp, err := h.doRequest(req, route)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
remote := extractModTime(resp.Header)
|
||||
if !remote.After(entry.ModTime.Add(time.Second)) {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
case http.StatusNotFound:
|
||||
if h.store != nil {
|
||||
_ = h.store.Remove(ctx, locator)
|
||||
}
|
||||
return false, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func extractModTime(header http.Header) time.Time {
|
||||
if last := header.Get("Last-Modified"); last != "" {
|
||||
if parsed, err := http.ParseTime(last); err == nil {
|
||||
return parsed.UTC()
|
||||
}
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func applyDockerHubNamespaceFallback(route *server.HubRoute, path string) (string, bool) {
|
||||
if !isDockerHubRoute(route) {
|
||||
return path, false
|
||||
}
|
||||
repo, rest, ok := splitDockerRepoPath(path)
|
||||
if !ok || repo == "" {
|
||||
return path, false
|
||||
}
|
||||
if repo == "library" || strings.Contains(repo, "/") {
|
||||
return path, false
|
||||
}
|
||||
normalized := "/v2/library/" + repo + rest
|
||||
return normalized, true
|
||||
}
|
||||
|
||||
func isDockerHubRoute(route *server.HubRoute) bool {
|
||||
if route == nil || route.Config.Type != "docker" || route.UpstreamURL == nil {
|
||||
return false
|
||||
}
|
||||
host := strings.ToLower(route.UpstreamURL.Hostname())
|
||||
switch host {
|
||||
case "registry-1.docker.io", "docker.io", "index.docker.io":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func splitDockerRepoPath(path string) (string, string, bool) {
|
||||
if !strings.HasPrefix(path, "/v2/") {
|
||||
return "", "", false
|
||||
}
|
||||
suffix := strings.TrimPrefix(path, "/v2/")
|
||||
if suffix == "" || suffix == "/" {
|
||||
return "", "", false
|
||||
}
|
||||
segments := strings.Split(suffix, "/")
|
||||
var repoSegments []string
|
||||
for i, seg := range segments {
|
||||
if seg == "" {
|
||||
return "", "", false
|
||||
}
|
||||
switch seg {
|
||||
case "manifests", "blobs", "tags", "referrers":
|
||||
if len(repoSegments) == 0 {
|
||||
return "", "", false
|
||||
}
|
||||
rest := "/" + strings.Join(segments[i:], "/")
|
||||
return strings.Join(repoSegments, "/"), rest, true
|
||||
case "_catalog":
|
||||
return "", "", false
|
||||
}
|
||||
repoSegments = append(repoSegments, seg)
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
type bearerChallenge struct {
|
||||
Realm string
|
||||
Service string
|
||||
Scope string
|
||||
}
|
||||
|
||||
func parseBearerChallenge(values []string) (bearerChallenge, bool) {
|
||||
for _, raw := range values {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(raw), "bearer ") {
|
||||
continue
|
||||
}
|
||||
params := parseAuthParams(raw[len("Bearer "):])
|
||||
challenge := bearerChallenge{
|
||||
Realm: params["realm"],
|
||||
Service: params["service"],
|
||||
Scope: params["scope"],
|
||||
}
|
||||
if challenge.Realm == "" {
|
||||
continue
|
||||
}
|
||||
return challenge, true
|
||||
}
|
||||
return bearerChallenge{}, false
|
||||
}
|
||||
|
||||
func parseAuthParams(input string) map[string]string {
|
||||
params := make(map[string]string)
|
||||
parts := strings.Split(input, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
kv := strings.SplitN(part, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(kv[0]))
|
||||
value := strings.Trim(strings.TrimSpace(kv[1]), `"`)
|
||||
params[key] = value
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (h *Handler) fetchBearerToken(ctx context.Context, challenge bearerChallenge, route *server.HubRoute) (string, error) {
|
||||
if challenge.Realm == "" {
|
||||
return "", errors.New("bearer realm missing")
|
||||
}
|
||||
tokenURL, err := url.Parse(challenge.Realm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid bearer realm: %w", err)
|
||||
}
|
||||
query := tokenURL.Query()
|
||||
if challenge.Service != "" {
|
||||
query.Set("service", challenge.Service)
|
||||
}
|
||||
if challenge.Scope != "" {
|
||||
query.Set("scope", challenge.Scope)
|
||||
}
|
||||
tokenURL.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if route.Config.Username != "" && route.Config.Password != "" {
|
||||
req.SetBasicAuth(route.Config.Username, route.Config.Password)
|
||||
}
|
||||
|
||||
resp, err := h.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return "", fmt.Errorf("token request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
Token string `json:"token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("decode token response: %w", err)
|
||||
}
|
||||
|
||||
token := tokenResp.Token
|
||||
if token == "" {
|
||||
token = tokenResp.AccessToken
|
||||
}
|
||||
if token == "" {
|
||||
return "", errors.New("token response missing token value")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func buildCredentialHeader(username, password string) string {
|
||||
if username == "" || password == "" {
|
||||
return ""
|
||||
}
|
||||
token := username + ":" + password
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(token))
|
||||
}
|
||||
|
||||
func shouldRetryAuth(route *server.HubRoute, status int) bool {
|
||||
return route != nil && route.Config.HasCredentials() && isAuthFailure(status)
|
||||
}
|
||||
|
||||
func isAuthFailure(status int) bool {
|
||||
return status == http.StatusUnauthorized || status == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
func (h *Handler) logAuthRetry(route *server.HubRoute, upstream string, requestID string, status int) {
|
||||
fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), false)
|
||||
fields["action"] = "proxy_retry"
|
||||
fields["upstream"] = upstream
|
||||
fields["upstream_status"] = status
|
||||
fields["reason"] = "auth_retry"
|
||||
if requestID != "" {
|
||||
fields["request_id"] = requestID
|
||||
}
|
||||
h.logger.WithFields(fields).Warn("proxy_auth_retry")
|
||||
}
|
||||
|
||||
func (h *Handler) logAuthFailure(route *server.HubRoute, upstream string, requestID string, status int) {
|
||||
fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), false)
|
||||
fields["action"] = "proxy"
|
||||
fields["upstream"] = upstream
|
||||
fields["upstream_status"] = status
|
||||
fields["error"] = "upstream_auth_failed"
|
||||
if requestID != "" {
|
||||
fields["request_id"] = requestID
|
||||
}
|
||||
h.logger.WithFields(fields).Error("proxy_auth_failed")
|
||||
}
|
||||
|
||||
func ensureProxyHubType(route *server.HubRoute) error {
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
return nil
|
||||
case "npm":
|
||||
return nil
|
||||
case "go":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported hub type: %s", route.Config.Type)
|
||||
}
|
||||
}
|
||||
8
internal/server/doc.go
Normal file
8
internal/server/doc.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Package server hosts the Fiber HTTP service, request middleware chain, and
|
||||
// hub registry glue that wires Host/port resolution into proxy handlers.
|
||||
// Phase 1 focuses on a single binary that bootstraps Fiber, attaches logging
|
||||
// and error middlewares, injects the HubRegistry built from config, and exposes
|
||||
// router constructors that other packages (cmd/any-hub, proxy) can reuse.
|
||||
// Future phases may extend this package with TLS, metrics endpoints, or admin
|
||||
// surfaces, so keep exports narrow and accept explicit dependencies.
|
||||
package server
|
||||
77
internal/server/http_client.go
Normal file
77
internal/server/http_client.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"time"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
)
|
||||
|
||||
// Shared HTTP transport tunings,复用长连接并集中配置超时。
|
||||
var defaultTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
}
|
||||
|
||||
// NewUpstreamClient 返回共享 http.Client,用于所有上游请求。
|
||||
func NewUpstreamClient(cfg *config.Config) *http.Client {
|
||||
timeout := 30 * time.Second
|
||||
if cfg != nil && cfg.Global.UpstreamTimeout.DurationValue() > 0 {
|
||||
timeout = cfg.Global.UpstreamTimeout.DurationValue()
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: defaultTransport.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// hopByHopHeaders 定义 RFC 7230 中禁止代理转发的头部。
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"Connection": {},
|
||||
"Keep-Alive": {},
|
||||
"Proxy-Authenticate": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Te": {},
|
||||
"Trailer": {},
|
||||
"Transfer-Encoding": {},
|
||||
"Upgrade": {},
|
||||
"Proxy-Connection": {}, // 非标准字段,但部分代理仍使用
|
||||
}
|
||||
|
||||
// CopyHeaders 将 src 中允许透传的头复制到 dst,自动忽略 hop-by-hop 字段。
|
||||
func CopyHeaders(dst, src http.Header) {
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isHopByHopHeader(key string) bool {
|
||||
canonical := textproto.CanonicalMIMEHeaderKey(key)
|
||||
if _, ok := hopByHopHeaders[canonical]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsHopByHopHeader reports whether the header should be stripped by proxies.
|
||||
func IsHopByHopHeader(key string) bool {
|
||||
return isHopByHopHeader(key)
|
||||
}
|
||||
45
internal/server/http_client_test.go
Normal file
45
internal/server/http_client_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
)
|
||||
|
||||
func TestNewUpstreamClientUsesConfigTimeout(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Global: config.GlobalConfig{
|
||||
UpstreamTimeout: config.Duration(45 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
client := NewUpstreamClient(cfg)
|
||||
if client.Timeout != 45*time.Second {
|
||||
t.Fatalf("expected timeout 45s, got %s", client.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyHeadersSkipsHopByHop(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive")
|
||||
src.Add("Keep-Alive", "timeout=5")
|
||||
src.Add("X-Test-Header", "1")
|
||||
src.Add("x-test-header", "2")
|
||||
|
||||
dst := http.Header{}
|
||||
CopyHeaders(dst, src)
|
||||
|
||||
if _, exists := dst["Connection"]; exists {
|
||||
t.Fatalf("connection header should not be copied")
|
||||
}
|
||||
if _, exists := dst["Keep-Alive"]; exists {
|
||||
t.Fatalf("keep-alive header should not be copied")
|
||||
}
|
||||
|
||||
got := dst.Values("X-Test-Header")
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 values, got %v", got)
|
||||
}
|
||||
}
|
||||
152
internal/server/hub_registry.go
Normal file
152
internal/server/hub_registry.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
)
|
||||
|
||||
// HubRoute 将 Hub 配置与派生属性(如缓存 TTL、解析后的 Upstream/Proxy URL)
|
||||
// 聚合在一起,供路由/代理层直接复用,避免重复解析配置。
|
||||
type HubRoute struct {
|
||||
// Config 是用户在 config.toml 中声明的 Hub 字段副本,避免外部修改。
|
||||
Config config.HubConfig
|
||||
// ListenPort 记录当前 CLI 监听端口,方便日志/转发头输出。
|
||||
ListenPort int
|
||||
// CacheTTL 是对当前 Hub 生效的 TTL,若 Hub 未覆盖则等于全局值。
|
||||
CacheTTL time.Duration
|
||||
// UpstreamURL/ProxyURL 在构造 Registry 时提前解析完成,便于后续请求快速复用。
|
||||
UpstreamURL *url.URL
|
||||
ProxyURL *url.URL
|
||||
}
|
||||
|
||||
// HubRegistry 提供 Host/Host:port 到 HubRoute 的查询能力,所有 Hub 共享同一个监听端口。
|
||||
type HubRegistry struct {
|
||||
routes map[string]*HubRoute
|
||||
ordered []*HubRoute
|
||||
}
|
||||
|
||||
// NewHubRegistry 根据配置构建 Host/端口映射。调用方应在启动阶段创建一次并复用。
|
||||
func NewHubRegistry(cfg *config.Config) (*HubRegistry, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
registry := &HubRegistry{
|
||||
routes: make(map[string]*HubRoute, len(cfg.Hubs)),
|
||||
}
|
||||
|
||||
if len(cfg.Hubs) == 0 {
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
for _, hub := range cfg.Hubs {
|
||||
normalizedHost := normalizeDomain(hub.Domain)
|
||||
if normalizedHost == "" {
|
||||
return nil, fmt.Errorf("invalid domain for hub %s", hub.Name)
|
||||
}
|
||||
if _, exists := registry.routes[normalizedHost]; exists {
|
||||
return nil, fmt.Errorf("duplicate domain mapping detected for %s", normalizedHost)
|
||||
}
|
||||
|
||||
route, err := buildHubRoute(cfg, hub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
registry.routes[normalizedHost] = route
|
||||
registry.ordered = append(registry.ordered, route)
|
||||
}
|
||||
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
// Lookup 根据 Host 或 Host:port 查找 HubRoute。
|
||||
func (r *HubRegistry) Lookup(host string) (*HubRoute, bool) {
|
||||
if r == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
normalizedHost, _ := normalizeHost(host)
|
||||
if normalizedHost == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
route, ok := r.routes[normalizedHost]
|
||||
return route, ok
|
||||
}
|
||||
|
||||
// List 返回当前注册的 HubRoute 列表(按配置定义的顺序),用于调试或 /status 输出。
|
||||
func (r *HubRegistry) List() []HubRoute {
|
||||
if r == nil || len(r.ordered) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]HubRoute, len(r.ordered))
|
||||
for i, route := range r.ordered {
|
||||
result[i] = *route
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildHubRoute(cfg *config.Config, hub config.HubConfig) (*HubRoute, error) {
|
||||
upstreamURL, err := url.Parse(hub.Upstream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid upstream for hub %s: %w", hub.Name, err)
|
||||
}
|
||||
|
||||
var proxyURL *url.URL
|
||||
if hub.Proxy != "" {
|
||||
proxyURL, err = url.Parse(hub.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy for hub %s: %w", hub.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &HubRoute{
|
||||
Config: hub,
|
||||
ListenPort: cfg.Global.ListenPort,
|
||||
CacheTTL: cfg.EffectiveCacheTTL(hub),
|
||||
UpstreamURL: upstreamURL,
|
||||
ProxyURL: proxyURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeDomain(domain string) string {
|
||||
host, _ := normalizeHost(domain)
|
||||
return host
|
||||
}
|
||||
|
||||
func normalizeHost(raw string) (string, int) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
host := raw
|
||||
port := 0
|
||||
|
||||
if strings.Contains(raw, ":") {
|
||||
if h, p, err := net.SplitHostPort(raw); err == nil {
|
||||
host = h
|
||||
if parsedPort, err := strconv.Atoi(p); err == nil {
|
||||
port = parsedPort
|
||||
}
|
||||
} else if idx := strings.LastIndex(raw, ":"); idx > -1 && strings.Count(raw[idx+1:], ":") == 0 {
|
||||
if parsedPort, err := strconv.Atoi(raw[idx+1:]); err == nil {
|
||||
host = raw[:idx]
|
||||
port = parsedPort
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host = strings.TrimSuffix(host, ".")
|
||||
host = strings.ToLower(host)
|
||||
return host, port
|
||||
}
|
||||
120
internal/server/hub_registry_test.go
Normal file
120
internal/server/hub_registry_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
)
|
||||
|
||||
func TestHubRegistryLookupByHost(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Global: config.GlobalConfig{
|
||||
ListenPort: 5000,
|
||||
CacheTTL: config.Duration(2 * time.Hour),
|
||||
},
|
||||
Hubs: []config.HubConfig{
|
||||
{
|
||||
Name: "docker",
|
||||
Domain: "docker.hub.local",
|
||||
Type: "docker",
|
||||
Upstream: "https://registry-1.docker.io",
|
||||
EnableHeadCheck: true,
|
||||
},
|
||||
{
|
||||
Name: "npm",
|
||||
Domain: "npm.hub.local",
|
||||
Type: "npm",
|
||||
Upstream: "https://registry.npmjs.org",
|
||||
CacheTTL: config.Duration(30 * time.Minute),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
registry, err := NewHubRegistry(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
route, ok := registry.Lookup("docker.hub.local")
|
||||
if !ok {
|
||||
t.Fatalf("expected docker route")
|
||||
}
|
||||
|
||||
if route.Config.Name != "docker" {
|
||||
t.Errorf("wrong hub returned: %s", route.Config.Name)
|
||||
}
|
||||
|
||||
if route.CacheTTL != cfg.EffectiveCacheTTL(route.Config) {
|
||||
t.Errorf("cache ttl mismatch: got %s", route.CacheTTL)
|
||||
}
|
||||
|
||||
if route.UpstreamURL.String() != "https://registry-1.docker.io" {
|
||||
t.Errorf("unexpected upstream URL: %s", route.UpstreamURL)
|
||||
}
|
||||
|
||||
if route.ProxyURL != nil {
|
||||
t.Errorf("expected nil proxy")
|
||||
}
|
||||
|
||||
if route.ListenPort != cfg.Global.ListenPort {
|
||||
t.Fatalf("route listen port mismatch: %d", route.ListenPort)
|
||||
}
|
||||
|
||||
if got := len(registry.List()); got != 2 {
|
||||
t.Fatalf("expected 2 routes in list, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubRegistryParsesHostHeaderPort(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Global: config.GlobalConfig{
|
||||
ListenPort: 5000,
|
||||
CacheTTL: config.Duration(time.Hour),
|
||||
},
|
||||
Hubs: []config.HubConfig{
|
||||
{
|
||||
Name: "docker",
|
||||
Domain: "docker.hub.local",
|
||||
Type: "docker",
|
||||
Upstream: "https://registry-1.docker.io",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
registry, err := NewHubRegistry(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := registry.Lookup("docker.hub.local:6000"); !ok {
|
||||
t.Fatalf("expected lookup to ignore host header port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubRegistryRejectsDuplicateDomains(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Global: config.GlobalConfig{
|
||||
ListenPort: 5000,
|
||||
CacheTTL: config.Duration(time.Hour),
|
||||
},
|
||||
Hubs: []config.HubConfig{
|
||||
{
|
||||
Name: "docker",
|
||||
Domain: "docker.hub.local",
|
||||
Type: "docker",
|
||||
Upstream: "https://registry-1.docker.io",
|
||||
},
|
||||
{
|
||||
Name: "docker-alt",
|
||||
Domain: "docker.hub.local",
|
||||
Type: "docker",
|
||||
Upstream: "https://mirror.registry-1.docker.io",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := NewHubRegistry(cfg); err == nil {
|
||||
t.Fatalf("expected duplicate domain error")
|
||||
}
|
||||
}
|
||||
163
internal/server/router.go
Normal file
163
internal/server/router.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/middleware/recover"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyHandler describes the component responsible for proxying requests to
|
||||
// the upstream Hub. It allows injecting fake handlers during tests.
|
||||
type ProxyHandler interface {
|
||||
Handle(fiber.Ctx, *HubRoute) error
|
||||
}
|
||||
|
||||
// ProxyHandlerFunc adapts a function to the ProxyHandler interface.
|
||||
type ProxyHandlerFunc func(fiber.Ctx, *HubRoute) error
|
||||
|
||||
// Handle makes ProxyHandlerFunc satisfy ProxyHandler.
|
||||
func (f ProxyHandlerFunc) Handle(c fiber.Ctx, route *HubRoute) error {
|
||||
return f(c, route)
|
||||
}
|
||||
|
||||
// AppOptions controls how the Fiber application should behave on a specific port.
|
||||
type AppOptions struct {
|
||||
Logger *logrus.Logger
|
||||
Registry *HubRegistry
|
||||
Proxy ProxyHandler
|
||||
ListenPort int
|
||||
}
|
||||
|
||||
const (
|
||||
contextKeyRoute = "_anyhub_route"
|
||||
contextKeyRequestID = "_anyhub_request_id"
|
||||
)
|
||||
|
||||
// NewApp builds a Fiber application with Host/port routing middleware and
|
||||
// structured error handling.
|
||||
func NewApp(opts AppOptions) (*fiber.App, error) {
|
||||
if opts.Logger == nil {
|
||||
return nil, errors.New("logger is required")
|
||||
}
|
||||
if opts.Registry == nil {
|
||||
return nil, errors.New("hub registry is required")
|
||||
}
|
||||
if opts.Proxy == nil {
|
||||
return nil, errors.New("proxy handler is required")
|
||||
}
|
||||
if opts.ListenPort <= 0 {
|
||||
return nil, fmt.Errorf("invalid listen port: %d", opts.ListenPort)
|
||||
}
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
CaseSensitive: true,
|
||||
})
|
||||
|
||||
app.Use(recover.New())
|
||||
app.Use(requestContextMiddleware(opts))
|
||||
|
||||
app.All("/*", func(c fiber.Ctx) error {
|
||||
route, _ := getRouteFromContext(c)
|
||||
if route == nil {
|
||||
return renderHostUnmapped(c, opts.Logger, "", opts.ListenPort)
|
||||
}
|
||||
return opts.Proxy.Handle(c, route)
|
||||
})
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// requestContextMiddleware 负责生成请求 ID,并基于 Host/Host:port 查找 HubRoute。
|
||||
func requestContextMiddleware(opts AppOptions) fiber.Handler {
|
||||
return func(c fiber.Ctx) error {
|
||||
reqID := uuid.NewString()
|
||||
c.Locals(contextKeyRequestID, reqID)
|
||||
c.Set("X-Request-ID", reqID)
|
||||
|
||||
rawHost := strings.TrimSpace(getHostHeader(c))
|
||||
route, ok := opts.Registry.Lookup(rawHost)
|
||||
if !ok {
|
||||
return renderHostUnmapped(c, opts.Logger, rawHost, opts.ListenPort)
|
||||
}
|
||||
if err := ensureRouterHubType(route); err != nil {
|
||||
return renderTypeUnsupported(c, opts.Logger, route, err)
|
||||
}
|
||||
|
||||
c.Locals(contextKeyRoute, route)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func renderHostUnmapped(c fiber.Ctx, logger *logrus.Logger, host string, port int) error {
|
||||
fields := logrus.Fields{
|
||||
"action": "host_lookup",
|
||||
"host": host,
|
||||
"port": port,
|
||||
}
|
||||
logger.WithFields(fields).Warn("host unmapped")
|
||||
|
||||
if host != "" {
|
||||
c.Set("X-Any-Hub-Host", host)
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
|
||||
"error": "host_unmapped",
|
||||
})
|
||||
}
|
||||
|
||||
func getHostHeader(c fiber.Ctx) string {
|
||||
if raw := c.Request().Header.Peek(fiber.HeaderHost); len(raw) > 0 {
|
||||
return string(raw)
|
||||
}
|
||||
return c.Hostname()
|
||||
}
|
||||
|
||||
func getRouteFromContext(c fiber.Ctx) (*HubRoute, bool) {
|
||||
if value := c.Locals(contextKeyRoute); value != nil {
|
||||
if route, ok := value.(*HubRoute); ok {
|
||||
return route, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RequestID returns the request identifier stored by the router middleware.
|
||||
func RequestID(c fiber.Ctx) string {
|
||||
if value := c.Locals(contextKeyRequestID); value != nil {
|
||||
if reqID, ok := value.(string); ok {
|
||||
return reqID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func ensureRouterHubType(route *HubRoute) error {
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
return nil
|
||||
case "npm":
|
||||
return nil
|
||||
case "go":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported hub type: %s", route.Config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func renderTypeUnsupported(c fiber.Ctx, logger *logrus.Logger, route *HubRoute, err error) error {
|
||||
fields := logrus.Fields{
|
||||
"action": "hub_type_check",
|
||||
"hub": route.Config.Name,
|
||||
"hub_type": route.Config.Type,
|
||||
"error": "hub_type_unsupported",
|
||||
}
|
||||
logger.WithFields(fields).Error(err.Error())
|
||||
return c.Status(fiber.StatusNotImplemented).JSON(fiber.Map{
|
||||
"error": "hub_type_unsupported",
|
||||
})
|
||||
}
|
||||
118
internal/server/router_test.go
Normal file
118
internal/server/router_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
)
|
||||
|
||||
func TestRouterRoutesRequestWhenHostMatches(t *testing.T) {
|
||||
app := newTestApp(t, 5000)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://docker.hub.local/v2/", nil)
|
||||
req.Host = "docker.hub.local"
|
||||
req.Header.Set("Host", "docker.hub.local")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("app.Test failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != fiber.StatusNoContent {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected 204 status, got %d (body=%s, hostHeader=%s)", resp.StatusCode, string(body), resp.Header.Get("X-Any-Hub-Host"))
|
||||
}
|
||||
|
||||
if app.storage.routeName != "docker" {
|
||||
t.Fatalf("expected docker route, got %s", app.storage.routeName)
|
||||
}
|
||||
|
||||
if reqID := resp.Header.Get("X-Request-ID"); reqID == "" {
|
||||
t.Fatalf("expected X-Request-ID header to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterReturns404WhenHostUnknown(t *testing.T) {
|
||||
app := newTestApp(t, 5000)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://unknown.local/v2/", nil)
|
||||
req.Host = "unknown.local"
|
||||
req.Header.Set("Host", "unknown.local")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("app.Test failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != fiber.StatusNotFound {
|
||||
t.Fatalf("expected 404 status, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if !bytes.Contains(body, []byte(`"host_unmapped"`)) {
|
||||
t.Fatalf("expected host_unmapped error, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
type testApp struct {
|
||||
*fiber.App
|
||||
storage *proxyRecorder
|
||||
}
|
||||
|
||||
func newTestApp(t *testing.T, port int) *testApp {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Global: config.GlobalConfig{
|
||||
ListenPort: port,
|
||||
CacheTTL: config.Duration(3600),
|
||||
},
|
||||
Hubs: []config.HubConfig{
|
||||
{
|
||||
Name: "docker",
|
||||
Domain: "docker.hub.local",
|
||||
Type: "docker",
|
||||
Upstream: "https://registry-1.docker.io",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
registry, err := NewHubRegistry(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create registry: %v", err)
|
||||
}
|
||||
if _, ok := registry.Lookup("docker.hub.local"); !ok {
|
||||
t.Fatalf("registry lookup failed for docker")
|
||||
}
|
||||
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(io.Discard)
|
||||
|
||||
recorder := &proxyRecorder{}
|
||||
app, err := NewApp(AppOptions{
|
||||
Logger: logger,
|
||||
Registry: registry,
|
||||
Proxy: recorder,
|
||||
ListenPort: port,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create app: %v", err)
|
||||
}
|
||||
|
||||
return &testApp{App: app, storage: recorder}
|
||||
}
|
||||
|
||||
type proxyRecorder struct {
|
||||
lastRoute *HubRoute
|
||||
routeName string
|
||||
}
|
||||
|
||||
func (p *proxyRecorder) Handle(c fiber.Ctx, route *HubRoute) error {
|
||||
p.lastRoute = route
|
||||
p.routeName = route.Config.Name
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
14
internal/version/version.go
Normal file
14
internal/version/version.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package version
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Version/Commit 可在构建时通过 -ldflags 注入,默认使用开发占位符。
|
||||
var (
|
||||
Version = "0.1.0"
|
||||
Commit = "dev"
|
||||
)
|
||||
|
||||
// Full 返回便于 CLI 打印的完整版本信息。
|
||||
func Full() string {
|
||||
return fmt.Sprintf("any-hub %s (%s)", Version, Commit)
|
||||
}
|
||||
Reference in New Issue
Block a user