This commit is contained in:
2025-11-17 15:39:44 +08:00
parent abfa51f12e
commit 1ddda89499
46 changed files with 2185 additions and 751 deletions

View File

@@ -3,6 +3,7 @@ package config
import (
_ "github.com/any-hub/any-hub/internal/hubmodule/composer"
_ "github.com/any-hub/any-hub/internal/hubmodule/docker"
_ "github.com/any-hub/any-hub/internal/hubmodule/golang"
_ "github.com/any-hub/any-hub/internal/hubmodule/legacy"
_ "github.com/any-hub/any-hub/internal/hubmodule/npm"
_ "github.com/any-hub/any-hub/internal/hubmodule/pypi"

View File

@@ -79,7 +79,11 @@ func (c *Config) Validate() error {
moduleKey := strings.ToLower(strings.TrimSpace(hub.Module))
if moduleKey == "" {
moduleKey = hubmodule.DefaultModuleKey()
if _, ok := hubmodule.Resolve(normalizedType); ok && normalizedType != "" {
moduleKey = normalizedType
} else {
moduleKey = hubmodule.DefaultModuleKey()
}
}
if _, ok := hubmodule.Resolve(moduleKey); !ok {
return newFieldError(hubField(hub.Name, "Module"), fmt.Sprintf("未注册模块: %s", moduleKey))

View File

@@ -24,7 +24,7 @@ internal/hubmodule/
2. 填写模块特有逻辑与缓存策略,并确保包含中文注释解释设计。
3. 在模块目录添加 `module_test.go`,使用 `httptest.Server` 与 `t.TempDir()` 复现真实流量。
4. 运行 `make modules-test` 验证模块单元测试。
5. 更新 `config.toml` 中对应 `[[Hub]].Module` 字段,验证集成测试后再提交
5. `[[Hub]].Module` 留空时会优先选择与 `Type` 同名的模块,实际迁移时仍建议显式填写,便于 diagnostics 标记 rollout
## 术语
- **Module Key**:模块唯一标识(如 `legacy`、`npm-tarball`)。

View File

@@ -1,77 +1,136 @@
package proxy
package composer
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/any-hub/any-hub/internal/server"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
func (h *Handler) rewriteComposerResponse(route *server.HubRoute, resp *http.Response, path string) (*http.Response, error) {
if resp == nil || route == nil || route.Config.Type != "composer" {
return resp, nil
}
if path == "/packages.json" {
return rewriteComposerRoot(resp, route.Config.Domain)
}
if !isComposerMetadataPath(path) {
return resp, nil
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return resp, err
}
resp.Body.Close()
rewritten, changed, err := rewriteComposerMetadata(body, route.Config.Domain)
if err != nil {
resp.Body = io.NopCloser(bytes.NewReader(body))
return resp, err
}
if !changed {
resp.Body = io.NopCloser(bytes.NewReader(body))
return resp, nil
}
resp.Body = io.NopCloser(bytes.NewReader(rewritten))
resp.ContentLength = int64(len(rewritten))
resp.Header.Set("Content-Length", strconv.Itoa(len(rewritten)))
resp.Header.Set("Content-Type", "application/json")
resp.Header.Del("Content-Encoding")
resp.Header.Del("Etag")
return resp, nil
// init registers the composer hook set under the "composer" key.
// hooks.MustRegister is called once when the package is loaded; the five
// callbacks below are the functions defined in this file.
func init() {
hooks.MustRegister("composer", hooks.Hooks{
NormalizePath: normalizePath,
ResolveUpstream: resolveDistUpstream,
RewriteResponse: rewriteResponse,
CachePolicy: cachePolicy,
ContentType: contentType,
})
}
func rewriteComposerRoot(resp *http.Response, domain string) (*http.Response, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return resp, err
func normalizePath(_ *hooks.RequestContext, clean string, rawQuery []byte) (string, []byte) {
if isComposerDistPath(clean) {
return clean, nil
}
resp.Body.Close()
return clean, rawQuery
}
data, changed, err := rewriteComposerRootBody(body, domain)
if err != nil {
resp.Body = io.NopCloser(bytes.NewReader(body))
return resp, err
// resolveDistUpstream turns a proxied "/dist/..." locator back into the
// upstream URL encoded in it, carrying the raw query string along.
//
// It returns the empty string when clean is not a dist path or when the path
// cannot be decoded by parseComposerDistURL.
func resolveDistUpstream(_ *hooks.RequestContext, _ string, clean string, rawQuery []byte) string {
	if isComposerDistPath(clean) {
		if upstream, ok := parseComposerDistURL(clean, string(rawQuery)); ok {
			return upstream.String()
		}
	}
	return ""
}
// rewriteResponse rewrites Composer JSON payloads so that the URLs they
// contain point at ctx.Domain.
//
// "/packages.json" is handled by rewriteComposerRootBody, every other
// metadata path by rewriteComposerMetadata; any other path is passed through
// untouched. When the rewriter fails, or reports that nothing changed, the
// original status, headers and body are returned (together with the error, if
// any). Only a changed payload gets its headers adjusted by ensureJSONHeaders.
func rewriteResponse(
	ctx *hooks.RequestContext,
	status int,
	headers map[string]string,
	body []byte,
	path string,
) (int, map[string]string, []byte, error) {
	// Select the rewriter first; ctx is only touched once one applies.
	var rewrite func([]byte, string) ([]byte, bool, error)
	if path == "/packages.json" {
		rewrite = rewriteComposerRootBody
	} else if isComposerMetadataPath(path) {
		rewrite = rewriteComposerMetadata
	} else {
		return status, headers, body, nil
	}

	data, changed, err := rewrite(body, ctx.Domain)
	if err != nil {
		return status, headers, body, err
	}
	if !changed {
		return status, headers, body, nil
	}
	return status, ensureJSONHeaders(headers), data, nil
}
// ensureJSONHeaders prepares a header map for a rewritten JSON body: it forces
// Content-Type to "application/json" and removes the Content-Encoding and Etag
// entries. A nil map is replaced by a fresh one; a non-nil map is modified in
// place and returned.
func ensureJSONHeaders(headers map[string]string) map[string]string {
	out := headers
	if out == nil {
		out = make(map[string]string, 1)
	}
	for _, stale := range [...]string{"Content-Encoding", "Etag"} {
		delete(out, stale)
	}
	out["Content-Type"] = "application/json"
	return out
}
// cachePolicy decides how Composer responses may be cached.
//
//   - dist paths: cache and store, no revalidation required
//   - metadata paths: cache and store, revalidation required
//   - anything else: neither cached nor stored
//
// Only the three flags above are overwritten; all other fields of current are
// returned as received.
func cachePolicy(_ *hooks.RequestContext, locatorPath string, current hooks.CachePolicy) hooks.CachePolicy {
	isDist := isComposerDistPath(locatorPath)
	// Dist takes precedence; the metadata check only runs for non-dist paths.
	isMetadata := !isDist && isComposerMetadataPath(locatorPath)
	cacheable := isDist || isMetadata

	current.AllowCache = cacheable
	current.AllowStore = cacheable
	current.RequireRevalidate = isMetadata
	return current
}
// contentType reports "application/json" for Composer metadata paths and the
// empty string for every other path.
func contentType(_ *hooks.RequestContext, locatorPath string) string {
	mediaType := ""
	if isComposerMetadataPath(locatorPath) {
		mediaType = "application/json"
	}
	return mediaType
}
func rewriteComposerRootBody(body []byte, domain string) ([]byte, bool, error) {
type root struct {
Packages map[string]string `json:"packages"`
}
var payload root
if err := json.Unmarshal(body, &payload); err != nil {
return nil, false, err
}
if len(payload.Packages) == 0 {
return body, false, nil
}
changed := false
for key, value := range payload.Packages {
rewritten := rewriteComposerAbsolute(domain, value)
if rewritten != value {
payload.Packages[key] = rewritten
changed = true
}
}
if !changed {
resp.Body = io.NopCloser(bytes.NewReader(body))
return resp, nil
return body, false, nil
}
resp.Body = io.NopCloser(bytes.NewReader(data))
resp.ContentLength = int64(len(data))
resp.Header.Set("Content-Length", strconv.Itoa(len(data)))
resp.Header.Set("Content-Type", "application/json")
resp.Header.Del("Content-Encoding")
resp.Header.Del("Etag")
return resp, nil
data, err := json.Marshal(payload)
if err != nil {
return nil, false, err
}
return data, true, nil
}
func rewriteComposerMetadata(body []byte, domain string) ([]byte, bool, error) {
@@ -183,7 +242,7 @@ func rewriteComposerDistURL(domain, original string) string {
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return original
}
prefix := fmt.Sprintf("/dist/%s/%s", parsed.Scheme, parsed.Host)
prefix := "/dist/" + parsed.Scheme + "/" + parsed.Host
newURL := url.URL{
Scheme: "https",
Host: domain,
@@ -197,6 +256,29 @@ func rewriteComposerDistURL(domain, original string) string {
return newURL.String()
}
// rewriteComposerAbsolute converts a URL or path found in Composer metadata
// into an absolute https URL served from domain.
//
//   - ""                      is returned unchanged
//   - "//host/path"           (protocol-relative) -> "https://<domain>/path"
//   - "http(s)://host/path?q" -> "https://<domain>/path?q"
//   - "/path" or "path"       -> "https://<domain>/path"
//
// Absolute and protocol-relative inputs that net/url cannot parse (for example
// because they contain a placeholder such as "%package%") are returned
// unchanged.
func rewriteComposerAbsolute(domain, raw string) string {
	if raw == "" {
		return raw
	}

	protocolRelative := strings.HasPrefix(raw, "//")
	if protocolRelative || strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
		candidate := raw
		if protocolRelative {
			// Give the URL a scheme so that url.Parse identifies the host part;
			// appending the text after "//" to the domain would glue the
			// upstream host onto the proxy host.
			candidate = "https:" + raw
		}
		parsed, err := url.Parse(candidate)
		if err != nil {
			return raw
		}
		parsed.Host = domain
		parsed.Scheme = "https"
		return parsed.String()
	}

	// Relative reference: anchor it at the root of the proxy domain.
	pathVal := raw
	if !strings.HasPrefix(pathVal, "/") {
		pathVal = "/" + pathVal
	}
	return "https://" + domain + pathVal
}
func isComposerMetadataPath(path string) bool {
switch {
case path == "/packages.json":
@@ -218,74 +300,34 @@ func isComposerDistPath(path string) bool {
return strings.HasPrefix(path, "/dist/")
}
func rewriteComposerAbsolute(domain, raw string) string {
if raw == "" {
return raw
func parseComposerDistURL(path string, rawQuery string) (*url.URL, bool) {
if !strings.HasPrefix(path, "/dist/") {
return nil, false
}
if strings.HasPrefix(raw, "//") {
return "https://" + domain + strings.TrimPrefix(raw, "//")
trimmed := strings.TrimPrefix(path, "/dist/")
parts := strings.SplitN(trimmed, "/", 3)
if len(parts) < 3 {
return nil, false
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
parsed, err := url.Parse(raw)
if err != nil {
return raw
}
parsed.Host = domain
parsed.Scheme = "https"
return parsed.String()
scheme := parts[0]
host := parts[1]
rest := parts[2]
if scheme == "" || host == "" {
return nil, false
}
pathVal := raw
if !strings.HasPrefix(pathVal, "/") {
pathVal = "/" + pathVal
if rest == "" {
rest = "/"
} else {
rest = "/" + rest
}
return fmt.Sprintf("https://%s%s", domain, pathVal)
}
func rewriteComposerRootBody(body []byte, domain string) ([]byte, bool, error) {
var root map[string]any
if err := json.Unmarshal(body, &root); err != nil {
return nil, false, err
}
changed := false
for _, key := range []string{"metadata-url", "providers-api", "providers-url", "notify-batch"} {
if raw, ok := root[key].(string); ok && raw != "" {
newVal := rewriteComposerAbsolute(domain, raw)
if newVal != raw {
root[key] = newVal
changed = true
}
}
}
if includes, ok := root["provider-includes"].(map[string]any); ok {
for file, hashVal := range includes {
pathVal := file
if rawPath, ok := hashVal.(map[string]any); ok {
if urlValue, ok := rawPath["url"].(string); ok {
pathVal = urlValue
}
}
newPath := rewriteComposerAbsolute(domain, pathVal)
if newPath != pathVal {
changed = true
}
if rawPath, ok := hashVal.(map[string]any); ok {
rawPath["url"] = newPath
includes[file] = rawPath
} else {
includes[file] = newPath
}
}
}
if !changed {
return body, false, nil
}
data, err := json.Marshal(root)
if err != nil {
return nil, false, err
}
return data, true, nil
target := &url.URL{
Scheme: scheme,
Host: host,
Path: rest,
RawPath: rest,
}
if rawQuery != "" {
target.RawQuery = rawQuery
}
return target, true
}

View File

@@ -0,0 +1,43 @@
package composer
import (
"strings"
"testing"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestNormalizePathDropsDistQuery checks that normalizePath keeps a "/dist/"
// path as is but returns a nil query for it.
func TestNormalizePathDropsDistQuery(t *testing.T) {
path, raw := normalizePath(nil, "/dist/https/example.com/file.zip", []byte("token=1"))
if raw != nil {
t.Fatalf("expected query to be dropped")
}
if path != "/dist/https/example.com/file.zip" {
t.Fatalf("unexpected path %s", path)
}
}
// TestResolveDistUpstream checks that a "/dist/<scheme>/<host>/<rest>" path is
// expanded back to the upstream URL and that the query string is preserved.
func TestResolveDistUpstream(t *testing.T) {
url := resolveDistUpstream(nil, "", "/dist/https/example.com/file.zip", []byte("token=1"))
if url != "https://example.com/file.zip?token=1" {
t.Fatalf("unexpected upstream %s", url)
}
}
// TestRewriteResponseUpdatesURLs feeds a minimal package metadata document
// through rewriteResponse and expects the dist URL to be redirected to the
// proxy domain under "/dist/<scheme>/<host>/...", with a JSON content type.
func TestRewriteResponseUpdatesURLs(t *testing.T) {
ctx := &hooks.RequestContext{Domain: "cache.example"}
body := []byte(`{"packages":{"a/b":{"1.0.0":{"dist":{"url":"https://pkg.example/dist.zip"}}}}}`)
_, headers, rewritten, err := rewriteResponse(ctx, 200, map[string]string{}, body, "/p2/a/b.json")
if err != nil {
t.Fatalf("rewrite failed: %v", err)
}
// The payload must differ from the input once the URL has been rewritten.
if string(rewritten) == string(body) {
t.Fatalf("expected rewrite to modify payload")
}
if headers["Content-Type"] != "application/json" {
t.Fatalf("expected json content type")
}
if !strings.Contains(string(rewritten), "https://cache.example/dist/https/pkg.example/dist.zip") {
t.Fatalf("expected rewritten URL, got %s", string(rewritten))
}
}

View File

@@ -0,0 +1,105 @@
package docker
import (
"strings"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// init registers the docker hook set under the "docker" key. Only path
// normalisation, cache policy and content type hooks are provided here.
func init() {
hooks.MustRegister("docker", hooks.Hooks{
NormalizePath: normalizePath,
CachePolicy: cachePolicy,
ContentType: contentType,
})
}
// normalizePath inserts the implicit "library/" namespace for single-segment
// repositories when the upstream is Docker Hub, e.g.
// "/v2/nginx/manifests/latest" -> "/v2/library/nginx/manifests/latest".
//
// The path is returned unchanged when ctx is nil, when the upstream host is
// not a Docker Hub host, when the path is not a repository path recognised by
// splitDockerRepoPath, or when the repository already has a namespace (it
// contains "/") or is literally "library". rawQuery is always passed through.
func normalizePath(ctx *hooks.RequestContext, clean string, rawQuery []byte) (string, []byte) {
	// A missing context carries no upstream host, so there is nothing to decide on.
	if ctx == nil || !isDockerHubHost(ctx.UpstreamHost) {
		return clean, rawQuery
	}
	repo, rest, ok := splitDockerRepoPath(clean)
	if !ok || repo == "" || strings.Contains(repo, "/") || repo == "library" {
		return clean, rawQuery
	}
	return "/v2/library/" + repo + rest, rawQuery
}
// cachePolicy decides how Docker registry responses may be cached.
//
// The API root ("/v2", "v2", "/v2/") and any path containing "/_catalog" get
// the zero-value policy. Every other path may be cached and stored;
// revalidation is required unless isDockerImmutablePath reports the path as
// immutable.
func cachePolicy(_ *hooks.RequestContext, locatorPath string, current hooks.CachePolicy) hooks.CachePolicy {
	switch locatorPath {
	case "/v2", "v2", "/v2/":
		return hooks.CachePolicy{}
	}
	if strings.Contains(locatorPath, "/_catalog") {
		return hooks.CachePolicy{}
	}

	current.AllowCache = true
	current.AllowStore = true
	current.RequireRevalidate = !isDockerImmutablePath(locatorPath)
	return current
}
// contentType returns a content type derived from the path: JSON for paths
// containing "/tags/list", octet-stream for paths containing "/blobs/", and
// the empty string otherwise. The tag-list check is evaluated first.
func contentType(_ *hooks.RequestContext, locatorPath string) string {
	if strings.Contains(locatorPath, "/tags/list") {
		return "application/json"
	}
	if strings.Contains(locatorPath, "/blobs/") {
		return "application/octet-stream"
	}
	return ""
}
// isDockerHubHost reports whether host, compared case-insensitively, is one of
// the host names treated as Docker Hub: "registry-1.docker.io", "docker.io" or
// "index.docker.io".
func isDockerHubHost(host string) bool {
	normalized := strings.ToLower(host)
	return normalized == "registry-1.docker.io" ||
		normalized == "docker.io" ||
		normalized == "index.docker.io"
}
// splitDockerRepoPath splits a registry path of the form
// "/v2/<repo...>/<action>/..." into the repository name and the remainder
// starting at the action segment (with a leading "/").
//
// The action segment is the first of "manifests", "blobs", "tags" or
// "referrers". ok is false when the path does not start with "/v2/", when an
// empty segment or "_catalog" is met before an action segment, when the action
// segment comes first (empty repository), or when no action segment exists.
func splitDockerRepoPath(path string) (string, string, bool) {
	const apiPrefix = "/v2/"
	if !strings.HasPrefix(path, apiPrefix) {
		return "", "", false
	}

	parts := strings.Split(path[len(apiPrefix):], "/")
	for idx := 0; idx < len(parts); idx++ {
		part := parts[idx]
		// An empty segment (including an empty remainder) or the catalog
		// endpoint means this is not a repository path.
		if part == "" || part == "_catalog" {
			return "", "", false
		}
		isAction := part == "manifests" || part == "blobs" || part == "tags" || part == "referrers"
		if !isAction {
			continue
		}
		if idx == 0 {
			return "", "", false
		}
		repo := strings.Join(parts[:idx], "/")
		rest := "/" + strings.Join(parts[idx:], "/")
		return repo, rest, true
	}
	return "", "", false
}
// isDockerImmutablePath reports whether the path addresses a blob or manifest
// by sha256 digest, i.e. contains "/blobs/sha256:" or "/manifests/sha256:".
func isDockerImmutablePath(path string) bool {
	return strings.Contains(path, "/blobs/sha256:") ||
		strings.Contains(path, "/manifests/sha256:")
}

View File

@@ -0,0 +1,31 @@
package docker
import (
"testing"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestNormalizePathAddsLibraryForDockerHub checks that, for a Docker Hub
// upstream, a single-segment repository gets the "library/" namespace and that
// a path already under "library/" is left alone.
func TestNormalizePathAddsLibraryForDockerHub(t *testing.T) {
ctx := &hooks.RequestContext{UpstreamHost: "registry-1.docker.io"}
path, _ := normalizePath(ctx, "/v2/nginx/manifests/latest", nil)
if path != "/v2/library/nginx/manifests/latest" {
t.Fatalf("expected library namespace, got %s", path)
}
// Already namespaced: must not be rewritten a second time.
path, _ = normalizePath(ctx, "/v2/library/nginx/manifests/latest", nil)
if path != "/v2/library/nginx/manifests/latest" {
t.Fatalf("unexpected rewrite for existing namespace")
}
}
// TestSplitDockerRepoPath checks the repository/remainder split for a manifest
// path and that the "_catalog" endpoint is rejected.
func TestSplitDockerRepoPath(t *testing.T) {
repo, rest, ok := splitDockerRepoPath("/v2/library/nginx/manifests/latest")
if !ok || repo != "library/nginx" || rest != "/manifests/latest" {
t.Fatalf("unexpected split result repo=%s rest=%s ok=%v", repo, rest, ok)
}
if _, _, ok := splitDockerRepoPath("/v2/_catalog"); ok {
t.Fatalf("expected catalog path to be ignored")
}
}

View File

@@ -0,0 +1,27 @@
package golang
import "github.com/any-hub/any-hub/internal/proxy/hooks"
import "strings"
// init registers the Go module proxy hook set under the "go" key. Only a cache
// policy hook is provided.
func init() {
hooks.MustRegister("go", hooks.Hooks{
CachePolicy: cachePolicy,
})
}
// cachePolicy allows caching and storing for every path. Revalidation is
// skipped only for paths that contain "/@v/" and end in ".zip", ".mod" or
// ".info"; every other path must be revalidated.
func cachePolicy(_ *hooks.RequestContext, locatorPath string, current hooks.CachePolicy) hooks.CachePolicy {
	versioned := false
	if strings.Contains(locatorPath, "/@v/") {
		for _, ext := range [...]string{".zip", ".mod", ".info"} {
			if strings.HasSuffix(locatorPath, ext) {
				versioned = true
				break
			}
		}
	}

	current.AllowCache = true
	current.AllowStore = true
	current.RequireRevalidate = !versioned
	return current
}

View File

@@ -0,0 +1,19 @@
package golang
import (
"testing"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestCachePolicyForModuleFiles checks that a versioned ".zip" under "/@v/" is
// cacheable without revalidation, while "/@latest" requires revalidation.
func TestCachePolicyForModuleFiles(t *testing.T) {
policy := cachePolicy(nil, "/example/@v/v1.0.0.zip", hooks.CachePolicy{})
if !policy.AllowCache || policy.RequireRevalidate {
t.Fatalf("expected immutable go artifacts to be cacheable without revalidate")
}
policy = cachePolicy(nil, "/example/@latest", hooks.CachePolicy{})
if !policy.RequireRevalidate {
t.Fatalf("expected non-artifacts to require revalidate")
}
}

View File

@@ -0,0 +1,28 @@
package golang
import (
"time"
"github.com/any-hub/any-hub/internal/hubmodule"
)
const goDefaultTTL = 30 * time.Minute
// init registers the metadata of the "go" hub module with the hubmodule
// registry: beta migration state, the "go" protocol, and a cache strategy that
// uses goDefaultTTL as TTL hint, Last-Modified validation and the "raw_path"
// disk layout, without a metadata file and with streaming writes enabled.
func init() {
hubmodule.MustRegister(hubmodule.ModuleMetadata{
Key: "go",
Description: "Go module proxy with sumdb/cache defaults",
MigrationState: hubmodule.MigrationStateBeta,
SupportedProtocols: []string{
"go",
},
CacheStrategy: hubmodule.CacheStrategyProfile{
TTLHint: goDefaultTTL,
ValidationMode: hubmodule.ValidationModeLastModified,
DiskLayout: "raw_path",
RequiresMetadataFile: false,
SupportsStreamingWrite: true,
},
LocatorRewrite: hubmodule.DefaultLocatorRewrite("go"),
})
}

View File

@@ -3,7 +3,7 @@ package legacy
import "github.com/any-hub/any-hub/internal/hubmodule"
// 模块描述:包装当前共享的代理 + 缓存实现,供未迁移的 Hub 使用。
// 模块描述:包装当前共享的代理 + 缓存实现,供未迁移的 Hub 使用,并在 diagnostics 中标记为 legacy-only
func init() {
hubmodule.MustRegister(hubmodule.ModuleMetadata{
Key: hubmodule.DefaultModuleKey(),

View File

@@ -0,0 +1,26 @@
package npm
import (
"strings"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// init registers the npm hook set under the "npm" key. Only a cache policy
// hook is provided.
func init() {
hooks.MustRegister("npm", hooks.Hooks{
CachePolicy: cachePolicy,
})
}
// cachePolicy allows caching and storing for every path. Tarball paths (those
// containing "/-/" and ending in ".tgz") skip revalidation; every other path
// must be revalidated.
func cachePolicy(_ *hooks.RequestContext, locatorPath string, current hooks.CachePolicy) hooks.CachePolicy {
	isTarball := strings.Contains(locatorPath, "/-/") && strings.HasSuffix(locatorPath, ".tgz")

	current.AllowCache = true
	current.AllowStore = true
	current.RequireRevalidate = !isTarball
	return current
}

View File

@@ -0,0 +1,22 @@
package npm
import (
"testing"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestCachePolicyForTarball checks that a tarball path is cacheable without
// revalidation and that a package metadata path requires revalidation.
func TestCachePolicyForTarball(t *testing.T) {
policy := cachePolicy(nil, "/pkg/-/pkg-1.0.0.tgz", hooks.CachePolicy{})
if policy.RequireRevalidate {
t.Fatalf("tarball should not require revalidate")
}
if !policy.AllowCache {
t.Fatalf("tarball should allow cache")
}
policy = cachePolicy(nil, "/pkg", hooks.CachePolicy{})
if !policy.RequireRevalidate {
t.Fatalf("metadata should require revalidate")
}
}

View File

@@ -0,0 +1,215 @@
package pypi
import (
"bytes"
"encoding/json"
"net/url"
"strings"
"golang.org/x/net/html"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// init registers the PyPI hook set under the "pypi" key. All five callbacks
// are the functions defined in this file.
func init() {
hooks.MustRegister("pypi", hooks.Hooks{
NormalizePath: normalizePath,
ResolveUpstream: resolveFilesUpstream,
RewriteResponse: rewriteResponse,
CachePolicy: cachePolicy,
ContentType: contentType,
})
}
// normalizePath maps a request path onto the PyPI layout. rawQuery is always
// passed through unchanged.
//
//   - "/files/..." is kept as is; "/simple/..." gets a trailing slash
//   - distribution assets (see isDistributionAsset) are kept as is
//   - the root path and names starting with "_" are kept as is
//   - any other "/<name>" becomes "/simple/<name>/"
func normalizePath(_ *hooks.RequestContext, clean string, rawQuery []byte) (string, []byte) {
	switch {
	case strings.HasPrefix(clean, "/files/"), strings.HasPrefix(clean, "/simple/"):
		return ensureSimpleTrailingSlash(clean), rawQuery
	case isDistributionAsset(clean):
		return clean, rawQuery
	}

	project := strings.Trim(clean, "/")
	if project == "" || strings.HasPrefix(project, "_") {
		return clean, rawQuery
	}
	// Trim removed every trailing slash, so exactly one is appended again.
	return "/simple/" + project + "/", rawQuery
}
// ensureSimpleTrailingSlash appends "/" to paths under "/simple/" that do not
// already end with one. Every other path is returned unchanged.
func ensureSimpleTrailingSlash(path string) string {
	if strings.HasPrefix(path, "/simple/") && !strings.HasSuffix(path, "/") {
		return path + "/"
	}
	return path
}
// resolveFilesUpstream expands a proxied "/files/<scheme>/<host>/<rest>" path
// into the upstream URL "<scheme>://<host>/<rest>", appending rawQuery when it
// is non-empty.
//
// It returns the empty string when clean is not under "/files/", when fewer
// than three segments follow the prefix, or when scheme or host is empty.
// baseURL is not used.
func resolveFilesUpstream(_ *hooks.RequestContext, baseURL string, clean string, rawQuery []byte) string {
	const filesPrefix = "/files/"
	if !strings.HasPrefix(clean, filesPrefix) {
		return ""
	}

	segments := strings.SplitN(clean[len(filesPrefix):], "/", 3)
	if len(segments) < 3 || segments[0] == "" || segments[1] == "" {
		return ""
	}

	upstream := url.URL{
		Scheme:   segments[0],
		Host:     segments[1],
		Path:     "/" + strings.TrimPrefix(segments[2], "/"),
		RawQuery: string(rawQuery),
	}
	return upstream.String()
}
// cachePolicy marks distribution assets (see isDistributionAsset) as cacheable
// and storable without revalidation. For every other path it only sets
// RequireRevalidate and leaves the remaining fields of current as received.
func cachePolicy(_ *hooks.RequestContext, locatorPath string, current hooks.CachePolicy) hooks.CachePolicy {
	if !isDistributionAsset(locatorPath) {
		current.RequireRevalidate = true
		return current
	}

	current.AllowCache = true
	current.AllowStore = true
	current.RequireRevalidate = false
	return current
}
// contentType reports "text/html" for paths containing "/simple/" and the
// empty string for every other path.
func contentType(_ *hooks.RequestContext, locatorPath string) string {
	mediaType := ""
	if strings.Contains(locatorPath, "/simple/") {
		mediaType = "text/html"
	}
	return mediaType
}
// rewriteResponse rewrites the links of PyPI index pages so that they point at
// ctx.Domain.
//
// Only "/" and paths starting with "/simple" are processed; everything else is
// returned untouched. On a rewrite error the original status, headers and body
// are returned together with the error. On success the Content-Encoding entry
// is removed and Content-Type is replaced with the type reported by
// rewritePyPIBody when that is non-empty. A nil header map is replaced by a
// new one; a non-nil map is modified in place.
func rewriteResponse(
	ctx *hooks.RequestContext,
	status int,
	headers map[string]string,
	body []byte,
	path string,
) (int, map[string]string, []byte, error) {
	isIndexPage := path == "/" || strings.HasPrefix(path, "/simple")
	if !isIndexPage {
		return status, headers, body, nil
	}

	// Indexing a nil map is safe and simply yields "".
	newBody, newType, err := rewritePyPIBody(body, headers["Content-Type"], ctx.Domain)
	if err != nil {
		return status, headers, body, err
	}

	out := headers
	if out == nil {
		out = map[string]string{}
	}
	delete(out, "Content-Encoding")
	if newType != "" {
		out["Content-Type"] = newType
	}
	return status, out, newBody, nil
}
func rewritePyPIBody(body []byte, contentType string, domain string) ([]byte, string, error) {
lowerCT := strings.ToLower(contentType)
if strings.Contains(lowerCT, "application/vnd.pypi.simple.v1+json") || strings.HasPrefix(strings.TrimSpace(string(body)), "{") {
data := map[string]interface{}{}
if err := json.Unmarshal(body, &data); err != nil {
return body, contentType, err
}
if files, ok := data["files"].([]interface{}); ok {
for _, entry := range files {
if fileMap, ok := entry.(map[string]interface{}); ok {
if urlValue, ok := fileMap["url"].(string); ok {
fileMap["url"] = rewritePyPIFileURL(domain, urlValue)
}
}
}
}
rewriteBytes, err := json.Marshal(data)
if err != nil {
return body, contentType, err
}
return rewriteBytes, "application/vnd.pypi.simple.v1+json", nil
}
rewrittenHTML, err := rewritePyPIHTML(body, domain)
if err != nil {
return body, contentType, err
}
return rewrittenHTML, "text/html; charset=utf-8", nil
}
// rewritePyPIHTML parses body as HTML, rewrites the link attributes of every
// element via rewriteHTMLNode and returns the re-rendered document. A parse or
// render failure yields a nil slice and the error.
func rewritePyPIHTML(body []byte, domain string) ([]byte, error) {
	doc, err := html.Parse(bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	rewriteHTMLNode(doc, domain)

	out := new(bytes.Buffer)
	if err = html.Render(out, doc); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// rewriteHTMLNode walks the tree rooted at n depth-first. Element nodes have
// their attributes rewritten before their children are visited.
func rewriteHTMLNode(n *html.Node, domain string) {
	if n.Type == html.ElementNode {
		rewriteHTMLAttributes(n, domain)
	}
	child := n.FirstChild
	for child != nil {
		rewriteHTMLNode(child, domain)
		child = child.NextSibling
	}
}
// rewriteHTMLAttributes rewrites, in place, the "href",
// "data-dist-info-metadata" and "data-core-metadata" attributes of n whose
// value starts with "http://" or "https://". All other attributes and values
// are left alone.
func rewriteHTMLAttributes(n *html.Node, domain string) {
	for idx := range n.Attr {
		attr := &n.Attr[idx]
		if attr.Key != "href" && attr.Key != "data-dist-info-metadata" && attr.Key != "data-core-metadata" {
			continue
		}
		if strings.HasPrefix(attr.Val, "http://") || strings.HasPrefix(attr.Val, "https://") {
			attr.Val = rewritePyPIFileURL(domain, attr.Val)
		}
	}
}
// rewritePyPIFileURL rewrites an absolute file URL so that it is served by
// domain under "/files/<scheme>/<host>/<original path>". Query and fragment
// are carried over, and an explicitly escaped path (RawPath) keeps its
// escaping. Inputs that fail to parse, or that lack a scheme or a host, are
// returned unchanged.
func rewritePyPIFileURL(domain, original string) string {
	src, err := url.Parse(original)
	if err != nil {
		return original
	}
	if src.Scheme == "" || src.Host == "" {
		return original
	}

	base := "/files/" + src.Scheme + "/" + src.Host
	rewritten := &url.URL{
		Scheme:   "https",
		Host:     domain,
		Path:     base + src.Path,
		RawQuery: src.RawQuery,
		Fragment: src.Fragment,
	}
	if src.RawPath != "" {
		rewritten.RawPath = base + src.RawPath
	}
	return rewritten.String()
}
// isDistributionAsset reports whether path ends with one of the distribution
// file extensions: ".whl", ".tar.gz", ".tar.bz2", ".tgz", ".zip" or ".egg".
// The comparison is case-sensitive.
func isDistributionAsset(path string) bool {
	for _, ext := range [...]string{".whl", ".tar.gz", ".tar.bz2", ".tgz", ".zip", ".egg"} {
		if strings.HasSuffix(path, ext) {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,42 @@
package pypi
import (
"strings"
"testing"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestNormalizePathAddsSimplePrefix checks that a bare project path such as
// "/requests" is mapped to "/simple/requests/".
func TestNormalizePathAddsSimplePrefix(t *testing.T) {
ctx := &hooks.RequestContext{HubType: "pypi"}
path, _ := normalizePath(ctx, "/requests", nil)
if path != "/simple/requests/" {
t.Fatalf("expected /simple prefix, got %s", path)
}
}
// TestResolveFilesUpstream checks that a "/files/<scheme>/<host>/<rest>" path
// is expanded back to the upstream URL when no query is given.
func TestResolveFilesUpstream(t *testing.T) {
ctx := &hooks.RequestContext{}
target := resolveFilesUpstream(ctx, "", "/files/https/example.com/pkg.tgz", nil)
if target != "https://example.com/pkg.tgz" {
t.Fatalf("unexpected upstream target: %s", target)
}
}
// TestRewriteResponseAdjustsLinks feeds a small HTML index page through
// rewriteResponse and expects the absolute download link to be redirected
// under "/files/<scheme>/<host>/...", with a content type set on the result.
func TestRewriteResponseAdjustsLinks(t *testing.T) {
ctx := &hooks.RequestContext{Domain: "cache.example"}
body := []byte(`<html><body><a href="https://files.pythonhosted.org/package.whl">link</a></body></html>`)
_, headers, rewritten, err := rewriteResponse(ctx, 200, map[string]string{"Content-Type": "text/html"}, body, "/simple/requests/")
if err != nil {
t.Fatalf("rewrite failed: %v", err)
}
// The rendered page must differ from the input once the link is rewritten.
if string(rewritten) == string(body) {
t.Fatalf("expected rewrite to modify HTML")
}
if headers["Content-Type"] == "" {
t.Fatalf("expected content type to be set")
}
if !strings.Contains(string(rewritten), "/files/https/files.pythonhosted.org/package.whl") {
t.Fatalf("expected rewritten link, got %s", string(rewritten))
}
}

View File

@@ -0,0 +1,22 @@
package template
import (
"testing"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestExampleHookDefinition is a usage example for module authors: it builds a
// hooks.Hooks value with a pass-through NormalizePath and a CachePolicy that
// enables caching and storing. The value is only constructed, never invoked,
// so the test makes no assertions.
func TestExampleHookDefinition(t *testing.T) {
h := hooks.Hooks{
NormalizePath: func(ctx *hooks.RequestContext, clean string, rawQuery []byte) (string, []byte) {
return clean, rawQuery
},
CachePolicy: func(ctx *hooks.RequestContext, path string, current hooks.CachePolicy) hooks.CachePolicy {
current.AllowCache = true
current.AllowStore = true
return current
},
}
// Blank assignment keeps the otherwise unused example value compiling.
_ = h
}

View File

@@ -2,11 +2,11 @@
package template
import "github.com/any-hub/any-hub/internal/hubmodule"
//
// 使用方式:复制整个目录到 internal/hubmodule/<module-key>/ 并替换字段。
// - 将 TemplateModule 重命名为实际模块类型。
// - 在 init() 中调用 hubmodule.MustRegister注册新的 ModuleMetadata。
// - 在模块目录中实现自定义代理/缓存逻辑,然后在 main 中调用 proxy.RegisterModuleHandler
// - 在 init() 中调用 hubmodule.MustRegister 注册新的 ModuleMetadata。
// - 在模块目录中实现自定义 Hook(见 hook_example_test.go 中的示例),然后在 main/init 中调用 hooks.MustRegister + proxy.RegisterModule。
//
// 注意:本文件仅示例 metadata 注册写法,不会参与编译。
var _ = hubmodule.ModuleMetadata{}

View File

@@ -0,0 +1,73 @@
package template
import (
"net/http"
"testing"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestTemplateHookFlow shows a full hook lifecycle that module authors can
// copy when creating a new hook. It defines all five callbacks inline and then
// calls them in turn: NormalizePath, ResolveUpstream, CachePolicy,
// RewriteResponse and ContentType, asserting the result of each.
func TestTemplateHookFlow(t *testing.T) {
baseURL := "https://example.com"
ctx := &hooks.RequestContext{
HubName: "demo",
ModuleKey: "template",
}
h := hooks.Hooks{
// Prefix every path with "/normalized"; the query is passed through.
NormalizePath: func(_ *hooks.RequestContext, clean string, rawQuery []byte) (string, []byte) {
return "/normalized" + clean, rawQuery
},
// Concatenate upstream base and path, appending the query when present.
ResolveUpstream: func(_ *hooks.RequestContext, upstream string, clean string, rawQuery []byte) string {
if len(rawQuery) > 0 {
return upstream + clean + "?" + string(rawQuery)
}
return upstream + clean
},
// Allow caching for any non-empty path and always allow storing.
CachePolicy: func(_ *hooks.RequestContext, path string, current hooks.CachePolicy) hooks.CachePolicy {
current.AllowCache = path != ""
current.AllowStore = true
return current
},
// Only the normalized index document gets an explicit content type.
ContentType: func(_ *hooks.RequestContext, path string) string {
if path == "/normalized/index.json" {
return "application/json"
}
return ""
},
// Leave status and body alone and tag the response with a demo header.
RewriteResponse: func(_ *hooks.RequestContext, status int, headers map[string]string, body []byte, _ string) (int, map[string]string, []byte, error) {
if headers == nil {
headers = map[string]string{}
}
headers["X-Demo"] = "ok"
return status, headers, body, nil
},
}
normalized, _ := h.NormalizePath(ctx, "/index.json", nil)
if normalized != "/normalized/index.json" {
t.Fatalf("expected normalized path, got %s", normalized)
}
u := h.ResolveUpstream(ctx, baseURL, normalized, nil)
if u != baseURL+normalized {
t.Fatalf("expected upstream %s, got %s", baseURL+normalized, u)
}
policy := h.CachePolicy(ctx, normalized, hooks.CachePolicy{})
if !policy.AllowCache || !policy.AllowStore {
t.Fatalf("expected policy to allow cache/store, got %#v", policy)
}
status, headers, body, err := h.RewriteResponse(ctx, http.StatusOK, map[string]string{}, []byte("ok"), normalized)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if headers["X-Demo"] != "ok" {
t.Fatalf("expected rewrite to set header, got %s", headers["X-Demo"])
}
if status != http.StatusOK || string(body) != "ok" {
t.Fatalf("expected unchanged status/body, got %d/%s", status, string(body))
}
if ct := h.ContentType(ctx, normalized); ct != "application/json" {
t.Fatalf("expected content type application/json, got %s", ct)
}
}

View File

@@ -11,13 +11,14 @@ func BaseFields(action, configPath string) logrus.Fields {
}
// RequestFields 提供 hub/domain/命中状态字段,供代理请求日志复用。
func RequestFields(hub, domain, hubType, authMode, moduleKey, rolloutFlag string, cacheHit bool) logrus.Fields {
func RequestFields(hub, domain, hubType, authMode, moduleKey, rolloutFlag string, cacheHit bool, legacyOnly bool) logrus.Fields {
return logrus.Fields{
"hub": hub,
"domain": domain,
"hub_type": hubType,
"auth_mode": authMode,
"cache_hit": cacheHit,
"legacy_only": legacyOnly,
"module_key": moduleKey,
"rollout_flag": rolloutFlag,
}

View File

@@ -1,68 +0,0 @@
package proxy
import (
"net/url"
"testing"
"github.com/any-hub/any-hub/internal/config"
"github.com/any-hub/any-hub/internal/server"
)
func TestApplyDockerHubNamespaceFallback(t *testing.T) {
route := dockerHubRoute(t, "https://registry-1.docker.io")
path, changed := applyDockerHubNamespaceFallback(route, "/v2/nginx/manifests/latest")
if !changed {
t.Fatalf("expected fallback to apply")
}
if path != "/v2/library/nginx/manifests/latest" {
t.Fatalf("unexpected normalized path: %s", path)
}
path, changed = applyDockerHubNamespaceFallback(route, "/v2/library/nginx/manifests/latest")
if changed {
t.Fatalf("expected no changes for already-namespaced repo")
}
path, changed = applyDockerHubNamespaceFallback(route, "/v2/rogee/nginx/manifests/latest")
if changed {
t.Fatalf("expected no changes for custom namespace")
}
path, changed = applyDockerHubNamespaceFallback(route, "/v2/_catalog")
if changed {
t.Fatalf("expected no changes for _catalog endpoint")
}
otherRoute := dockerHubRoute(t, "https://registry.example.com")
path, changed = applyDockerHubNamespaceFallback(otherRoute, "/v2/nginx/manifests/latest")
if changed || path != "/v2/nginx/manifests/latest" {
t.Fatalf("expected no changes for non-docker-hub upstream")
}
}
func TestSplitDockerRepoPath(t *testing.T) {
repo, rest, ok := splitDockerRepoPath("/v2/library/nginx/manifests/latest")
if !ok || repo != "library/nginx" || rest != "/manifests/latest" {
t.Fatalf("unexpected split result repo=%s rest=%s ok=%v", repo, rest, ok)
}
if _, _, ok := splitDockerRepoPath("/v2/_catalog"); ok {
t.Fatalf("expected catalog path to be ignored")
}
}
func dockerHubRoute(t *testing.T, upstream string) *server.HubRoute {
t.Helper()
parsed, err := url.Parse(upstream)
if err != nil {
t.Fatalf("invalid upstream: %v", err)
}
return &server.HubRoute{
Config: config.HubConfig{
Name: "docker",
Type: "docker",
},
UpstreamURL: parsed,
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/gofiber/fiber/v3"
"github.com/sirupsen/logrus"
"github.com/any-hub/any-hub/internal/hubmodule"
"github.com/any-hub/any-hub/internal/logging"
"github.com/any-hub/any-hub/internal/server"
)
@@ -37,39 +38,48 @@ func RegisterModuleHandler(key string, handler server.ProxyHandler) {
// Handle 实现 server.ProxyHandler根据 route.ModuleKey 选择 handler。
func (f *Forwarder) Handle(c fiber.Ctx, route *server.HubRoute) error {
requestID := server.RequestID(c)
handler := f.lookup(route)
if handler == nil {
return f.respondMissingHandler(c, route)
return f.respondMissingHandler(c, route, requestID)
}
return f.invokeHandler(c, route, handler)
return f.invokeHandler(c, route, handler, requestID)
}
func (f *Forwarder) respondMissingHandler(c fiber.Ctx, route *server.HubRoute) error {
f.logModuleError(route, "module_handler_missing", nil)
func (f *Forwarder) respondMissingHandler(c fiber.Ctx, route *server.HubRoute, requestID string) error {
f.logModuleError(route, "module_handler_missing", nil, requestID)
setRequestIDHeader(c, requestID)
return c.Status(fiber.StatusInternalServerError).
JSON(fiber.Map{"error": "module_handler_missing"})
}
func (f *Forwarder) invokeHandler(c fiber.Ctx, route *server.HubRoute, handler server.ProxyHandler) (err error) {
func (f *Forwarder) invokeHandler(c fiber.Ctx, route *server.HubRoute, handler server.ProxyHandler, requestID string) (err error) {
defer func() {
if r := recover(); r != nil {
err = f.respondHandlerPanic(c, route, r)
err = f.respondHandlerPanic(c, route, r, requestID)
}
}()
return handler.Handle(c, route)
}
func (f *Forwarder) respondHandlerPanic(c fiber.Ctx, route *server.HubRoute, recovered interface{}) error {
f.logModuleError(route, "module_handler_panic", fmt.Errorf("panic: %v", recovered))
func (f *Forwarder) respondHandlerPanic(c fiber.Ctx, route *server.HubRoute, recovered interface{}, requestID string) error {
f.logModuleError(route, "module_handler_panic", fmt.Errorf("panic: %v", recovered), requestID)
setRequestIDHeader(c, requestID)
return c.Status(fiber.StatusInternalServerError).
JSON(fiber.Map{"error": "module_handler_panic"})
}
func (f *Forwarder) logModuleError(route *server.HubRoute, code string, err error) {
// setRequestIDHeader echoes the request ID back to the client in the
// X-Request-ID response header. An empty ID leaves the response untouched.
func setRequestIDHeader(c fiber.Ctx, requestID string) {
	if requestID == "" {
		return
	}
	c.Set("X-Request-ID", requestID)
}
func (f *Forwarder) logModuleError(route *server.HubRoute, code string, err error, requestID string) {
if f.logger == nil {
return
}
fields := f.routeFields(route)
fields := f.routeFields(route, requestID)
fields["action"] = "proxy"
fields["error"] = code
if err != nil {
@@ -105,7 +115,7 @@ func normalizeModuleKey(key string) string {
return strings.ToLower(strings.TrimSpace(key))
}
func (f *Forwarder) routeFields(route *server.HubRoute) logrus.Fields {
func (f *Forwarder) routeFields(route *server.HubRoute, requestID string) logrus.Fields {
if route == nil {
return logrus.Fields{
"hub": "",
@@ -117,7 +127,7 @@ func (f *Forwarder) routeFields(route *server.HubRoute) logrus.Fields {
}
}
return logging.RequestFields(
fields := logging.RequestFields(
route.Config.Name,
route.Config.Domain,
route.Config.Type,
@@ -125,5 +135,10 @@ func (f *Forwarder) routeFields(route *server.HubRoute) logrus.Fields {
route.ModuleKey,
string(route.RolloutFlag),
false,
route.ModuleKey == hubmodule.DefaultModuleKey(),
)
if requestID != "" {
fields["request_id"] = requestID
}
return fields
}

View File

@@ -14,12 +14,15 @@ import (
"github.com/any-hub/any-hub/internal/server"
)
const requestIDKey = "_anyhub_request_id"
func TestForwarderMissingHandler(t *testing.T) {
app := fiber.New()
defer app.Shutdown()
ctx := app.AcquireCtx(new(fasthttp.RequestCtx))
defer app.ReleaseCtx(ctx)
ctx.Locals(requestIDKey, "missing-req")
logger := logrus.New()
logBuf := &bytes.Buffer{}
@@ -40,6 +43,12 @@ func TestForwarderMissingHandler(t *testing.T) {
if !strings.Contains(logBuf.String(), "module_handler_missing") {
t.Fatalf("expected log to mention module_handler_missing, got %s", logBuf.String())
}
if got := string(ctx.Response().Header.Peek("X-Request-ID")); got != "missing-req" {
t.Fatalf("expected request id header missing-req, got %s", got)
}
if !strings.Contains(logBuf.String(), "missing-req") {
t.Fatalf("expected log to include request id, got %s", logBuf.String())
}
}
func TestForwarderHandlerPanic(t *testing.T) {
@@ -58,6 +67,7 @@ func TestForwarderHandlerPanic(t *testing.T) {
defer app.Shutdown()
ctx := app.AcquireCtx(new(fasthttp.RequestCtx))
defer app.ReleaseCtx(ctx)
ctx.Locals(requestIDKey, "panic-req")
logger := logrus.New()
logBuf := &bytes.Buffer{}
@@ -78,6 +88,12 @@ func TestForwarderHandlerPanic(t *testing.T) {
if !strings.Contains(logBuf.String(), "module_handler_panic") {
t.Fatalf("expected log to mention module_handler_panic, got %s", logBuf.String())
}
if got := string(ctx.Response().Header.Peek("X-Request-ID")); got != "panic-req" {
t.Fatalf("expected request id header panic-req, got %s", got)
}
if !strings.Contains(logBuf.String(), "panic-req") {
t.Fatalf("expected log to include panic request id, got %s", logBuf.String())
}
}
func testRouteWithModule(moduleKey string) *server.HubRoute {

View File

@@ -36,6 +36,14 @@ type Handler struct {
etags sync.Map // key: hub+path, value: etag/digest string
}
// hookState bundles the per-request hook inputs that Handle computes once and
// then passes by pointer to the cache, content-type and upstream helpers.
type hookState struct {
// ctx is the request context handed to every hook callback.
ctx *hooks.RequestContext
// def holds the hook callbacks fetched for the route's module key.
def hooks.Hooks
// hasHooks is true when the lookup succeeded and at least one callback is set.
hasHooks bool
// clean is the request path after normalization (including NormalizePath, if any).
clean string
// rawQuery is the raw query string after NormalizePath had a chance to replace it.
rawQuery []byte
}
// NewHandler constructs a proxy handler with shared HTTP client/logger/store.
func NewHandler(client *http.Client, logger *logrus.Logger, store cache.Store) *Handler {
return &Handler{
@@ -45,23 +53,58 @@ func NewHandler(client *http.Client, logger *logrus.Logger, store cache.Store) *
}
}
// buildHookContext assembles the RequestContext passed to module hooks.
// The HTTP method is always filled in; the hub fields are copied from route
// when it is non-nil, and UpstreamHost is set only when the route has an
// upstream URL.
func buildHookContext(route *server.HubRoute, c fiber.Ctx) *hooks.RequestContext {
	reqCtx := &hooks.RequestContext{Method: c.Method()}
	if route == nil {
		return reqCtx
	}
	reqCtx.HubName = route.Config.Name
	reqCtx.Domain = route.Config.Domain
	reqCtx.HubType = route.Config.Type
	reqCtx.ModuleKey = route.ModuleKey
	reqCtx.RolloutFlag = string(route.RolloutFlag)
	if upstream := route.UpstreamURL; upstream != nil {
		reqCtx.UpstreamHost = upstream.Host
	}
	return reqCtx
}
// hasHook reports whether def carries at least one non-nil hook callback.
func hasHook(def hooks.Hooks) bool {
	switch {
	case def.NormalizePath != nil,
		def.ResolveUpstream != nil,
		def.RewriteResponse != nil,
		def.CachePolicy != nil,
		def.ContentType != nil:
		return true
	}
	return false
}
// Handle 执行缓存查找、条件回源和最终 streaming 逻辑,任何阶段出错都会输出结构化日志。
func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
started := time.Now()
requestID := server.RequestID(c)
reqCtx := newRequestContext(route, c.Method())
moduleHooks, _ := hooks.For(route.ModuleKey)
locator := buildLocator(route, c, reqCtx, moduleHooks)
policy := determineCachePolicy(route, locator, c.Method(), reqCtx, moduleHooks)
strategyWriter := cache.NewStrategyWriter(h.store, route.CacheStrategy)
if err := ensureProxyHubType(route); err != nil {
h.logger.WithFields(logrus.Fields{
"hub": route.Config.Name,
"module_key": route.ModuleKey,
}).WithError(err).Error("hub_type_unsupported")
return h.writeError(c, fiber.StatusNotImplemented, "hub_type_unsupported")
hooksDef, ok := hooks.Fetch(route.ModuleKey)
hookCtx := buildHookContext(route, c)
rawQuery := append([]byte(nil), c.Request().URI().QueryString()...)
cleanPath := normalizeRequestPath(route, string(c.Request().URI().Path()))
if hasHook(hooksDef) && hooksDef.NormalizePath != nil {
newPath, newQuery := hooksDef.NormalizePath(hookCtx, cleanPath, rawQuery)
if newPath != "" {
cleanPath = newPath
}
rawQuery = newQuery
}
locator := buildLocator(route, c, cleanPath, rawQuery)
policy := determineCachePolicyWithHook(route, locator, c.Method(), hooksDef, ok, hookCtx)
hookState := hookState{
ctx: hookCtx,
def: hooksDef,
hasHooks: ok && hasHook(hooksDef),
clean: cleanPath,
rawQuery: rawQuery,
}
strategyWriter := cache.NewStrategyWriter(h.store, route.CacheStrategy)
ctx := c.Context()
if ctx == nil {
@@ -89,7 +132,7 @@ func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
if strategyWriter.ShouldBypassValidation(cached.Entry) {
serve = true
} else if strategyWriter.SupportsValidation() {
fresh, err := h.isCacheFresh(c, route, locator, cached.Entry)
fresh, err := h.isCacheFresh(c, route, locator, cached.Entry, &hookState)
if err != nil {
h.logger.WithError(err).
WithFields(logrus.Fields{"hub": route.Config.Name, "module_key": route.ModuleKey}).
@@ -104,12 +147,12 @@ func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
}
if serve {
defer cached.Reader.Close()
return h.serveCache(c, route, cached, requestID, started)
return h.serveCache(c, route, cached, requestID, started, &hookState)
}
cached.Reader.Close()
}
return h.fetchAndStream(c, route, locator, policy, strategyWriter, requestID, started, ctx, reqCtx, moduleHooks)
return h.fetchAndStream(c, route, locator, policy, strategyWriter, requestID, started, ctx, &hookState)
}
func (h *Handler) serveCache(
@@ -118,6 +161,7 @@ func (h *Handler) serveCache(
result *cache.ReadResult,
requestID string,
started time.Time,
hook *hookState,
) error {
var readSeeker io.ReadSeeker
switch reader := result.Reader.(type) {
@@ -130,44 +174,12 @@ func (h *Handler) serveCache(
method := c.Method()
contentType := inferCachedContentType(route, result.Entry.Locator)
if contentType == "" && shouldSniffDockerManifest(route, result.Entry.Locator) {
contentType := resolveContentType(route, result.Entry.Locator, hook)
if contentType == "" && shouldSniffDockerManifest(result.Entry.Locator) {
if sniffed := sniffDockerManifestContentType(readSeeker); sniffed != "" {
contentType = sniffed
}
}
if route != nil && route.Config.Type == "composer" && isComposerMetadataPath(stripQueryMarker(result.Entry.Locator.Path)) {
body, err := io.ReadAll(result.Reader)
result.Reader.Close()
if err != nil {
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("read cache failed: %v", err))
}
rewritten := body
if stripQueryMarker(result.Entry.Locator.Path) == "/packages.json" {
if data, changed, err := rewriteComposerRootBody(body, route.Config.Domain); err == nil && changed {
rewritten = data
}
} else {
if data, changed, err := rewriteComposerMetadata(body, route.Config.Domain); err == nil && changed {
rewritten = data
}
}
c.Set("Content-Type", "application/json")
c.Set("X-Any-Hub-Upstream", route.UpstreamURL.String())
c.Set("X-Any-Hub-Cache-Hit", "true")
if requestID != "" {
c.Set("X-Request-ID", requestID)
}
c.Status(fiber.StatusOK)
c.Response().Header.SetContentLength(len(rewritten))
_, err = c.Response().BodyWriter().Write(rewritten)
h.logResult(route, route.UpstreamURL.String(), requestID, fiber.StatusOK, true, started, err)
if err != nil {
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("read cache failed: %v", err))
}
return nil
}
if contentType != "" {
c.Set("Content-Type", contentType)
} else {
@@ -214,35 +226,27 @@ func (h *Handler) fetchAndStream(
requestID string,
started time.Time,
ctx context.Context,
hook *hookState,
) error {
resp, upstreamURL, err := h.executeRequest(c, route)
resp, upstreamURL, err := h.executeRequest(c, route, hook)
if err != nil {
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
}
resp, upstreamURL, err = h.retryOnAuthFailure(c, route, requestID, started, resp, upstreamURL)
resp, upstreamURL, err = h.retryOnAuthFailure(c, route, requestID, started, resp, upstreamURL, hook)
if err != nil {
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
}
if route.Config.Type == "pypi" {
if rewritten, rewriteErr := h.rewritePyPIResponse(route, resp, requestPath(c)); rewriteErr == nil {
if hook != nil && hook.hasHooks && hook.def.RewriteResponse != nil {
if rewritten, rewriteErr := applyHookRewrite(hook, resp, requestPath(c)); rewriteErr == nil {
resp = rewritten
} else {
h.logger.WithError(rewriteErr).WithFields(logrus.Fields{
"action": "pypi_rewrite",
"action": "hook_rewrite",
"hub": route.Config.Name,
}).Warn("pypi_rewrite_failed")
}
} else if route.Config.Type == "composer" {
if rewritten, rewriteErr := h.rewriteComposerResponse(route, resp, requestPath(c)); rewriteErr == nil {
resp = rewritten
} else {
h.logger.WithError(rewriteErr).WithFields(logrus.Fields{
"action": "composer_rewrite",
"hub": route.Config.Name,
}).Warn("composer_rewrite_failed")
}).Warn("hook_rewrite_failed")
}
}
defer resp.Body.Close()
@@ -252,6 +256,42 @@ func (h *Handler) fetchAndStream(
return h.consumeUpstream(c, route, locator, resp, shouldStore, writer, requestID, started, ctx)
}
// applyHookRewrite buffers the upstream response and hands its status, headers
// and body to the module's RewriteResponse hook, returning a shallow copy of
// resp that carries whatever the hook produced.
//
// The hook sees a flattened header view that holds only the first value of
// each header. A nil header map or nil body returned by the hook means "keep
// what was passed in". A header whose first value the hook left unchanged
// keeps all of its original values, so multi-valued headers are not truncated
// merely because the response went through a hook.
//
// When no hook is configured, or resp is nil, resp is returned unchanged.
// When reading the body fails, the body is closed and (nil, err) is returned.
// When the hook itself fails, (nil, err) is returned, but resp.Body has
// already been re-armed with the buffered bytes so the caller can fall back to
// streaming the unmodified response.
func applyHookRewrite(hook *hookState, resp *http.Response, path string) (*http.Response, error) {
	if hook == nil || hook.def.RewriteResponse == nil || resp == nil {
		return resp, nil
	}
	body, err := io.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return nil, err
	}
	// The upstream body is now drained and closed. Restore a readable copy so
	// that resp stays usable if the hook rejects the response below.
	resp.Body = io.NopCloser(bytes.NewReader(body))

	headers := make(map[string]string, len(resp.Header))
	for key, values := range resp.Header {
		if len(values) > 0 {
			headers[key] = values[0]
		}
	}
	status, newHeaders, newBody, rewriteErr := hook.def.RewriteResponse(hook.ctx, resp.StatusCode, headers, body, path)
	if rewriteErr != nil {
		return nil, rewriteErr
	}
	if newHeaders == nil {
		newHeaders = headers
	}
	if newBody == nil {
		newBody = body
	}

	cloned := *resp
	cloned.StatusCode = status
	cloned.Header = make(http.Header, len(newHeaders))
	for key, value := range newHeaders {
		// The hook only ever saw the first value; if it left that value alone,
		// carry over every original value instead of collapsing to one.
		if prior := resp.Header.Values(key); len(prior) > 1 && prior[0] == value {
			cloned.Header[http.CanonicalHeaderKey(key)] = append([]string(nil), prior...)
			continue
		}
		cloned.Header.Set(key, value)
	}
	cloned.Body = io.NopCloser(bytes.NewReader(newBody))
	cloned.ContentLength = int64(len(newBody))
	return &cloned, nil
}
func (h *Handler) consumeUpstream(
c fiber.Ctx,
route *server.HubRoute,
@@ -335,6 +375,7 @@ func (h *Handler) retryOnAuthFailure(
started time.Time,
resp *http.Response,
upstreamURL *url.URL,
hook *hookState,
) (*http.Response, *url.URL, error) {
if !shouldRetryAuth(route, resp.StatusCode) {
return resp, upstreamURL, nil
@@ -354,30 +395,31 @@ func (h *Handler) retryOnAuthFailure(
return nil, upstreamURL, err
}
authHeader := "Bearer " + token
retryResp, retryURL, err := h.executeRequestWithAuth(c, route, authHeader)
retryResp, retryURL, err := h.executeRequestWithAuth(c, route, hook, authHeader)
if err != nil {
return nil, upstreamURL, err
}
return retryResp, retryURL, nil
}
retryResp, retryURL, err := h.executeRequest(c, route)
retryResp, retryURL, err := h.executeRequest(c, route, hook)
if err != nil {
return nil, upstreamURL, err
}
return retryResp, retryURL, nil
}
func (h *Handler) executeRequest(c fiber.Ctx, route *server.HubRoute) (*http.Response, *url.URL, error) {
return h.executeRequestWithAuth(c, route, "")
func (h *Handler) executeRequest(c fiber.Ctx, route *server.HubRoute, hook *hookState) (*http.Response, *url.URL, error) {
return h.executeRequestWithAuth(c, route, hook, "")
}
func (h *Handler) executeRequestWithAuth(
c fiber.Ctx,
route *server.HubRoute,
hook *hookState,
authHeader string,
) (*http.Response, *url.URL, error) {
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c, hook)
body := bytesReader(c.Body())
req, err := h.buildUpstreamRequest(c, upstreamURL, route, c.Method(), body, authHeader)
if err != nil {
@@ -469,6 +511,7 @@ func (h *Handler) logResult(
route.ModuleKey,
string(route.RolloutFlag),
cacheHit,
route.ModuleKey == hubmodule.DefaultModuleKey(),
)
fields["action"] = "proxy"
fields["upstream"] = upstream
@@ -506,47 +549,20 @@ func inferCachedContentType(route *server.HubRoute, locator cache.Locator) strin
return "application/x-tar"
}
if route != nil {
switch route.Config.Type {
case "docker":
if strings.Contains(clean, "/manifests/") {
return ""
}
if strings.Contains(clean, "/tags/list") {
return "application/json"
}
if strings.Contains(clean, "/blobs/") {
return "application/octet-stream"
}
case "npm":
if strings.HasSuffix(clean, ".json") {
return "application/json"
}
case "pypi":
if strings.Contains(clean, "/simple/") {
return "text/html"
}
}
}
return ""
}
func buildLocator(route *server.HubRoute, c fiber.Ctx) cache.Locator {
uri := c.Request().URI()
pathVal := string(uri.Path())
clean := normalizeRequestPath(route, pathVal)
if newPath, ok := applyPyPISimpleFallback(route, clean); ok {
clean = newPath
}
if newPath, ok := applyDockerHubNamespaceFallback(route, clean); ok {
clean = newPath
}
query := uri.QueryString()
if route != nil && route.Config.Type == "composer" && isComposerDistPath(clean) {
// composer dist URLs often embed per-request tokens; ignore query for cache key
query = nil
// resolveContentType picks the Content-Type for a cached entry. A module's
// ContentType hook is consulted first, with the query marker stripped from
// the locator path; if there is no such hook, or it returns an empty string,
// the built-in inference is used instead.
func resolveContentType(route *server.HubRoute, locator cache.Locator, hook *hookState) string {
	useHook := hook != nil && hook.hasHooks && hook.def.ContentType != nil
	if !useHook {
		return inferCachedContentType(route, locator)
	}
	hinted := hook.def.ContentType(hook.ctx, stripQueryMarker(locator.Path))
	if hinted == "" {
		return inferCachedContentType(route, locator)
	}
	return hinted
}
func buildLocator(route *server.HubRoute, c fiber.Ctx, clean string, rawQuery []byte) cache.Locator {
query := rawQuery
if len(query) > 0 {
sum := sha1.Sum(query)
clean = fmt.Sprintf("%s/__qs/%s", clean, hex.EncodeToString(sum[:]))
@@ -580,10 +596,7 @@ func stripQueryMarker(p string) string {
return p
}
func shouldSniffDockerManifest(route *server.HubRoute, locator cache.Locator) bool {
if route == nil || route.Config.Type != "docker" {
return false
}
func shouldSniffDockerManifest(locator cache.Locator) bool {
clean := stripQueryMarker(locator.Path)
return strings.Contains(clean, "/manifests/")
}
@@ -631,11 +644,7 @@ func normalizeRequestPath(route *server.HubRoute, raw string) string {
if raw == "" {
raw = "/"
}
hasSlash := strings.HasSuffix(raw, "/")
clean := path.Clean("/" + raw)
if route != nil && route.Config.Type == "pypi" && hasSlash && clean != "/" && !strings.HasSuffix(clean, "/") {
clean += "/"
}
return clean
}
@@ -646,42 +655,28 @@ func bytesReader(b []byte) io.Reader {
return bytes.NewReader(b)
}
func resolveUpstreamURL(route *server.HubRoute, base *url.URL, c fiber.Ctx) *url.URL {
func resolveUpstreamURL(route *server.HubRoute, base *url.URL, c fiber.Ctx, hook *hookState) *url.URL {
uri := c.Request().URI()
pathVal := string(uri.Path())
clean := normalizeRequestPath(route, pathVal)
if newPath, ok := applyPyPISimpleFallback(route, clean); ok {
clean = newPath
}
if newPath, ok := applyDockerHubNamespaceFallback(route, clean); ok {
clean = newPath
}
if route != nil && route.Config.Type == "pypi" && strings.HasPrefix(clean, "/files/") {
trimmed := strings.TrimPrefix(clean, "/files/")
parts := strings.SplitN(trimmed, "/", 3)
if len(parts) >= 3 {
scheme := parts[0]
host := parts[1]
rest := parts[2]
filesBase := &url.URL{Scheme: scheme, Host: host}
if !strings.HasPrefix(rest, "/") {
rest = "/" + rest
}
relative := &url.URL{Path: rest, RawPath: rest}
if query := string(uri.QueryString()); query != "" {
relative.RawQuery = query
}
return filesBase.ResolveReference(relative)
rawQuery := append([]byte(nil), uri.QueryString()...)
clean := normalizeRequestPath(route, string(uri.Path()))
if hook != nil {
if hook.clean != "" {
clean = hook.clean
}
}
if route != nil && route.Config.Type == "composer" && strings.HasPrefix(clean, "/dist/") {
if distTarget, ok := parseComposerDistURL(clean, string(uri.QueryString())); ok {
return distTarget
if hook.rawQuery != nil {
rawQuery = hook.rawQuery
}
if hook.hasHooks && hook.def.ResolveUpstream != nil {
if u := hook.def.ResolveUpstream(hook.ctx, base.String(), clean, rawQuery); u != "" {
if parsed, err := url.Parse(u); err == nil {
return parsed
}
}
}
}
relative := &url.URL{Path: clean, RawPath: clean}
if query := string(uri.QueryString()); query != "" {
relative.RawQuery = query
if len(rawQuery) > 0 {
relative.RawQuery = string(rawQuery)
}
return base.ResolveReference(relative)
}
@@ -718,85 +713,27 @@ type cachePolicy struct {
requireRevalidate bool
}
// determineCachePolicyWithHook computes the baseline cache policy and, when
// hooks are enabled and the module supplies a CachePolicy hook, lets that hook
// override all three policy flags.
//
// The hook is given the locator path with the query-string marker stripped,
// matching what the ContentType hook receives in resolveContentType. This
// lets suffix and prefix rules inside hooks behave the same whether or not the
// request carried a query string.
//
// ctx is the hook request context; it is passed through to the hook untouched.
func determineCachePolicyWithHook(route *server.HubRoute, locator cache.Locator, method string, def hooks.Hooks, enabled bool, ctx *hooks.RequestContext) cachePolicy {
	base := determineCachePolicy(route, locator, method)
	if !enabled || def.CachePolicy == nil {
		return base
	}
	updated := def.CachePolicy(ctx, stripQueryMarker(locator.Path), hooks.CachePolicy{
		AllowCache:        base.allowCache,
		AllowStore:        base.allowStore,
		RequireRevalidate: base.requireRevalidate,
	})
	base.allowCache = updated.AllowCache
	base.allowStore = updated.AllowStore
	base.requireRevalidate = updated.RequireRevalidate
	return base
}
func determineCachePolicy(route *server.HubRoute, locator cache.Locator, method string) cachePolicy {
if route == nil || method != http.MethodGet {
if method != http.MethodGet {
return cachePolicy{}
}
policy := cachePolicy{allowCache: true, allowStore: true}
path := stripQueryMarker(locator.Path)
switch route.Config.Type {
case "docker":
if path == "/v2" || path == "v2" || path == "/v2/" {
return cachePolicy{}
}
if strings.Contains(path, "/_catalog") {
return cachePolicy{}
}
if isDockerImmutablePath(path) {
return policy
}
policy.requireRevalidate = true
return policy
case "go":
if strings.Contains(path, "/@v/") &&
(strings.HasSuffix(path, ".zip") || strings.HasSuffix(path, ".mod") || strings.HasSuffix(path, ".info")) {
return policy
}
policy.requireRevalidate = true
return policy
case "npm":
if strings.Contains(path, "/-/") && strings.HasSuffix(path, ".tgz") {
return policy
}
policy.requireRevalidate = true
return policy
case "pypi":
if isPyPIDistribution(path) {
return policy
}
policy.requireRevalidate = true
return policy
case "composer":
if isComposerDistPath(path) {
return policy
}
if isComposerMetadataPath(path) {
policy.requireRevalidate = true
return policy
}
return cachePolicy{}
default:
return policy
}
}
func isDockerImmutablePath(path string) bool {
if strings.Contains(path, "/blobs/sha256:") {
return true
}
if strings.Contains(path, "/manifests/sha256:") {
return true
}
return false
}
func isPyPIDistribution(path string) bool {
switch {
case strings.HasSuffix(path, ".whl"):
return true
case strings.HasSuffix(path, ".tar.gz"):
return true
case strings.HasSuffix(path, ".tar.bz2"):
return true
case strings.HasSuffix(path, ".tgz"):
return true
case strings.HasSuffix(path, ".zip"):
return true
case strings.HasSuffix(path, ".egg"):
return true
default:
return false
}
return cachePolicy{allowCache: true, allowStore: true}
}
func isCacheableStatus(status int) bool {
@@ -808,13 +745,14 @@ func (h *Handler) isCacheFresh(
route *server.HubRoute,
locator cache.Locator,
entry cache.Entry,
hook *hookState,
) (bool, error) {
ctx := c.Context()
if ctx == nil {
ctx = context.Background()
}
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c, hook)
resp, err := h.revalidateRequest(c, route, upstreamURL, locator, "")
if err != nil {
return false, err
@@ -888,113 +826,6 @@ func extractModTime(header http.Header) time.Time {
return time.Now().UTC()
}
// applyDockerHubNamespaceFallback rewrites single-segment Docker Hub
// repository paths to the "library/" namespace, e.g. /v2/nginx/... becomes
// /v2/library/nginx/.... The boolean reports whether the path was rewritten;
// in every other case the input path is returned unchanged.
func applyDockerHubNamespaceFallback(route *server.HubRoute, path string) (string, bool) {
	if !isDockerHubRoute(route) {
		return path, false
	}
	repo, rest, ok := splitDockerRepoPath(path)
	switch {
	case !ok, repo == "":
		return path, false
	case repo == "library", strings.Contains(repo, "/"):
		// Already namespaced, nothing to add.
		return path, false
	}
	return "/v2/library/" + repo + rest, true
}
// isDockerHubRoute reports whether route is a docker-type hub whose upstream
// hostname is one of the known Docker Hub registry hosts (compared
// case-insensitively).
func isDockerHubRoute(route *server.HubRoute) bool {
	if route == nil || route.UpstreamURL == nil {
		return false
	}
	if route.Config.Type != "docker" {
		return false
	}
	host := strings.ToLower(route.UpstreamURL.Hostname())
	return host == "registry-1.docker.io" || host == "docker.io" || host == "index.docker.io"
}
// splitDockerRepoPath splits a registry API path of the form
// /v2/<repo>/<manifests|blobs|tags|referrers>/... into the repository name and
// the remainder starting at the API segment (with a leading slash).
//
// It returns ok=false when the path does not start with /v2/, contains an
// empty segment or a _catalog segment before the API segment, has no
// repository in front of the API segment, or has no API segment at all.
func splitDockerRepoPath(path string) (string, string, bool) {
	const prefix = "/v2/"
	if !strings.HasPrefix(path, prefix) {
		return "", "", false
	}
	tail := path[len(prefix):]
	if tail == "" || tail == "/" {
		return "", "", false
	}
	parts := strings.Split(tail, "/")
	for idx, part := range parts {
		if part == "" || part == "_catalog" {
			return "", "", false
		}
		isAPISegment := part == "manifests" || part == "blobs" || part == "tags" || part == "referrers"
		if !isAPISegment {
			continue
		}
		if idx == 0 {
			// API segment with no repository in front of it.
			return "", "", false
		}
		return strings.Join(parts[:idx], "/"), "/" + strings.Join(parts[idx:], "/"), true
	}
	return "", "", false
}
// applyPyPISimpleFallback maps a bare package path on a pypi hub onto the
// simple index, e.g. /requests becomes /simple/requests/. Paths already under
// /simple/ or /files/, paths ending in a listed distribution extension, the
// root path, and names starting with "_" are returned unchanged with false.
func applyPyPISimpleFallback(route *server.HubRoute, path string) (string, bool) {
	if route == nil || route.Config.Type != "pypi" {
		return path, false
	}
	for _, prefix := range []string{"/simple/", "/files/"} {
		if strings.HasPrefix(path, prefix) {
			return path, false
		}
	}
	for _, ext := range []string{".whl", ".tar.gz", ".tar.bz2", ".zip"} {
		if strings.HasSuffix(path, ext) {
			return path, false
		}
	}
	name := strings.Trim(path, "/")
	if name == "" || strings.HasPrefix(name, "_") {
		return path, false
	}
	return "/simple/" + name + "/", true
}
// parseComposerDistURL decodes a proxied dist path of the form
// /dist/<scheme>/<host>/<rest> into the upstream URL it stands for, carrying
// rawQuery over unchanged. It returns false when the path lacks the /dist/
// prefix, has fewer than three components after it, or has an empty scheme
// or host.
func parseComposerDistURL(path string, rawQuery string) (*url.URL, bool) {
	const distPrefix = "/dist/"
	if !strings.HasPrefix(path, distPrefix) {
		return nil, false
	}
	fields := strings.SplitN(path[len(distPrefix):], "/", 3)
	if len(fields) < 3 || fields[0] == "" || fields[1] == "" {
		return nil, false
	}
	// An empty remainder yields the root path "/".
	upstreamPath := "/" + fields[2]
	return &url.URL{
		Scheme:   fields[0],
		Host:     fields[1],
		Path:     upstreamPath,
		RawPath:  upstreamPath,
		RawQuery: rawQuery,
	}, true
}
type bearerChallenge struct {
Realm string
Service string
@@ -1130,6 +961,7 @@ func (h *Handler) logAuthRetry(route *server.HubRoute, upstream string, requestI
route.ModuleKey,
string(route.RolloutFlag),
false,
route.ModuleKey == hubmodule.DefaultModuleKey(),
)
fields["action"] = "proxy_retry"
fields["upstream"] = upstream
@@ -1150,6 +982,7 @@ func (h *Handler) logAuthFailure(route *server.HubRoute, upstream string, reques
route.ModuleKey,
string(route.RolloutFlag),
false,
route.ModuleKey == hubmodule.DefaultModuleKey(),
)
fields["action"] = "proxy"
fields["upstream"] = upstream
@@ -1204,20 +1037,3 @@ func normalizeETag(value string) string {
}
return strings.Trim(value, "\"")
}
// ensureProxyHubType returns nil for the hub types the proxy handles
// (docker, npm, go, pypi, composer) and an error naming the type otherwise.
func ensureProxyHubType(route *server.HubRoute) error {
	switch hubType := route.Config.Type; hubType {
	case "docker", "npm", "go", "pypi", "composer":
		return nil
	default:
		return fmt.Errorf("unsupported hub type: %s", hubType)
	}
}

View File

@@ -0,0 +1,80 @@
package proxy
import (
"net/url"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/valyala/fasthttp"
"github.com/any-hub/any-hub/internal/cache"
"github.com/any-hub/any-hub/internal/config"
"github.com/any-hub/any-hub/internal/proxy/hooks"
"github.com/any-hub/any-hub/internal/server"
)
// TestResolveUpstreamPrefersHook checks that a URL returned by the
// ResolveUpstream hook wins over the default base-plus-path resolution.
func TestResolveUpstreamPrefersHook(t *testing.T) {
	app := fiber.New()
	defer app.Shutdown()
	fctx := app.AcquireCtx(new(fasthttp.RequestCtx))
	defer app.ReleaseCtx(fctx)
	fctx.Request().SetRequestURI("/original/path?from=req")

	// The literal is a well-formed URL, so the parse error is deliberately ignored.
	base, _ := url.Parse("https://up.example")
	route := &server.HubRoute{
		Config:      config.HubConfig{Name: "demo", Type: "custom"},
		UpstreamURL: base,
	}

	passthrough := func(_ *hooks.RequestContext, clean string, rawQuery []byte) (string, []byte) {
		return clean, rawQuery
	}
	appendHooked := func(_ *hooks.RequestContext, upstream string, clean string, rawQuery []byte) string {
		return upstream + "/hooked"
	}
	state := &hookState{
		ctx: &hooks.RequestContext{},
		def: hooks.Hooks{
			NormalizePath:   passthrough,
			ResolveUpstream: appendHooked,
		},
		hasHooks: true,
		clean:    "/ignored",
		rawQuery: []byte("ignored=1"),
	}

	const want = "https://up.example/hooked"
	if got := resolveUpstreamURL(route, base, fctx, state).String(); got != want {
		t.Fatalf("expected hook override, got %s", got)
	}
}
// TestCachePolicyHookOverrides checks that flags returned by a CachePolicy
// hook replace the baseline policy computed for a GET request.
func TestCachePolicyHookOverrides(t *testing.T) {
	route := &server.HubRoute{
		Config: config.HubConfig{Name: "demo", Type: "npm"},
	}
	locator := cacheLocatorForTest("demo", "/a.tgz")

	disableCaching := func(_ *hooks.RequestContext, _ string, current hooks.CachePolicy) hooks.CachePolicy {
		current.AllowCache = false
		current.RequireRevalidate = false
		return current
	}
	def := hooks.Hooks{CachePolicy: disableCaching}
	reqCtx := &hooks.RequestContext{Method: fiber.MethodGet}

	got := determineCachePolicyWithHook(route, locator, fiber.MethodGet, def, true, reqCtx)
	if got.allowCache {
		t.Fatalf("expected hook to disable cache")
	}
	if got.requireRevalidate {
		t.Fatalf("expected hook to disable revalidate")
	}
}
// cacheLocatorForTest builds a cache.Locator for the given hub name and path.
func cacheLocatorForTest(hub, path string) cache.Locator {
	var loc cache.Locator
	loc.HubName = hub
	loc.Path = path
	return loc
}

View File

@@ -1,12 +1,5 @@
package hooks
import (
"net/http"
"net/url"
"strings"
"sync"
)
// CachePolicy mirrors the proxy cache policy structure.
type CachePolicy struct {
AllowCache bool
@@ -27,34 +20,9 @@ type RequestContext struct {
// Hooks describes customization points for module-specific behavior.
type Hooks struct {
NormalizePath func(ctx *RequestContext, cleanPath string) string
ResolveUpstream func(ctx *RequestContext, base *url.URL, cleanPath string, rawQuery []byte) *url.URL
RewriteResponse func(ctx *RequestContext, resp *http.Response, cleanPath string) (*http.Response, error)
NormalizePath func(ctx *RequestContext, cleanPath string, rawQuery []byte) (string, []byte)
ResolveUpstream func(ctx *RequestContext, baseURL string, path string, rawQuery []byte) string
RewriteResponse func(ctx *RequestContext, status int, headers map[string]string, body []byte, path string) (int, map[string]string, []byte, error)
CachePolicy func(ctx *RequestContext, locatorPath string, current CachePolicy) CachePolicy
ContentType func(ctx *RequestContext, locatorPath string) string
}
var registry sync.Map
// Register stores hooks under the lower-cased, trimmed module key,
// replacing any previous entry. A blank key is ignored.
func Register(moduleKey string, hooks Hooks) {
	if key := strings.ToLower(strings.TrimSpace(moduleKey)); key != "" {
		registry.Store(key, hooks)
	}
}
// For looks up the hooks registered for a module key (matched after trimming
// and lower-casing). The boolean is false when the key is blank or nothing
// usable is registered under it.
func For(moduleKey string) (Hooks, bool) {
	key := strings.ToLower(strings.TrimSpace(moduleKey))
	if key == "" {
		return Hooks{}, false
	}
	stored, found := registry.Load(key)
	if !found {
		return Hooks{}, false
	}
	// A failed assertion yields the zero Hooks value together with false.
	hooks, ok := stored.(Hooks)
	return hooks, ok
}

View File

@@ -0,0 +1,68 @@
package hooks
import (
"errors"
"strings"
"sync"
)
var registry sync.Map
// ErrDuplicateHook indicates a module key already has hooks registered.
var ErrDuplicateHook = errors.New("hook already registered")
// Register stores hooks under the normalized module key. It returns an error
// when the key is blank and ErrDuplicateHook when the key already has hooks;
// an existing registration is never overwritten.
func Register(moduleKey string, hooks Hooks) error {
	key := normalizeKey(moduleKey)
	if key == "" {
		return errors.New("module key required")
	}
	_, exists := registry.LoadOrStore(key, hooks)
	if !exists {
		return nil
	}
	return ErrDuplicateHook
}
// MustRegister is like Register but panics with the returned error when
// registration fails.
func MustRegister(moduleKey string, hooks Hooks) {
	err := Register(moduleKey, hooks)
	if err == nil {
		return
	}
	panic(err)
}
// Fetch returns the hooks registered for a module key, matched after
// normalization. The boolean is false when the key is blank or nothing usable
// is registered under it.
func Fetch(moduleKey string) (Hooks, bool) {
	key := normalizeKey(moduleKey)
	if key == "" {
		return Hooks{}, false
	}
	stored, found := registry.Load(key)
	if !found {
		return Hooks{}, false
	}
	// A failed assertion yields the zero Hooks value together with false.
	hooks, ok := stored.(Hooks)
	return hooks, ok
}
// Status reports "registered" when hooks exist for the module key and
// "missing" otherwise.
func Status(moduleKey string) string {
	_, found := Fetch(moduleKey)
	if !found {
		return "missing"
	}
	return "registered"
}
// Snapshot maps each non-blank key, in normalized form, to its registration
// status as reported by Status. Keys that normalize to the empty string are
// skipped.
func Snapshot(keys []string) map[string]string {
	statuses := make(map[string]string, len(keys))
	for i := range keys {
		normalized := normalizeKey(keys[i])
		if normalized == "" {
			continue
		}
		statuses[normalized] = Status(normalized)
	}
	return statuses
}
// normalizeKey canonicalizes a module key by trimming surrounding whitespace
// and lower-casing it.
func normalizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	return strings.ToLower(trimmed)
}

View File

@@ -0,0 +1,45 @@
package hooks
import (
"sync"
"testing"
)
// TestRegisterAndFetch covers the happy path: a registered key can be fetched
// and reports "registered", while an unknown key reports "missing".
func TestRegisterAndFetch(t *testing.T) {
	registry = sync.Map{} // start from an empty registry

	def := Hooks{ContentType: func(*RequestContext, string) string { return "ok" }}
	err := Register("test", def)
	if err != nil {
		t.Fatalf("register failed: %v", err)
	}
	_, found := Fetch("test")
	if !found {
		t.Fatalf("expected fetch ok")
	}
	if got := Status("test"); got != "registered" {
		t.Fatalf("expected registered status")
	}
	if got := Status("missing"); got != "missing" {
		t.Fatalf("expected missing status")
	}
}
// TestRegisterDuplicate checks that registering the same key twice yields
// ErrDuplicateHook on the second attempt.
func TestRegisterDuplicate(t *testing.T) {
	registry = sync.Map{} // start from an empty registry

	first := Register("dup", Hooks{})
	if first != nil {
		t.Fatalf("first register failed: %v", first)
	}
	second := Register("dup", Hooks{})
	if second != ErrDuplicateHook {
		t.Fatalf("expected ErrDuplicateHook, got %v", second)
	}
}
// TestSnapshot checks that Snapshot reports "registered" for a key with hooks
// and "missing" for one without.
func TestSnapshot(t *testing.T) {
	registry = sync.Map{} // start from an empty registry

	// Fail fast if the fixture cannot be set up, rather than reporting a
	// misleading status mismatch further down.
	if err := Register("a", Hooks{}); err != nil {
		t.Fatalf("register failed: %v", err)
	}
	snap := Snapshot([]string{"a", "b"})
	if snap["a"] != "registered" {
		t.Fatalf("expected a registered, got %s", snap["a"])
	}
	if snap["b"] != "missing" {
		t.Fatalf("expected b missing, got %s", snap["b"])
	}
}

View File

@@ -1,126 +0,0 @@
package proxy
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/html"
"github.com/any-hub/any-hub/internal/server"
)
// rewritePyPIResponse rewrites the body of a PyPI index response (the root
// path or anything under /simple) so that absolute file URLs point back at
// this hub's domain. Other paths and nil responses are returned untouched.
//
// On success the body, ContentLength and Content-Length header reflect the
// rewritten payload, Content-Type is replaced when the rewrite reports one,
// and Content-Encoding is removed. When the rewrite fails, the original bytes
// are put back on resp.Body and resp is returned together with the error.
func (h *Handler) rewritePyPIResponse(route *server.HubRoute, resp *http.Response, path string) (*http.Response, error) {
	if resp == nil {
		return resp, nil
	}
	isIndexPath := path == "/" || strings.HasPrefix(path, "/simple")
	if !isIndexPath {
		return resp, nil
	}

	original, err := io.ReadAll(resp.Body)
	if err != nil {
		return resp, err
	}
	resp.Body.Close()

	payload, mediaType, err := rewritePyPIBody(original, resp.Header.Get("Content-Type"), route.Config.Domain)
	if err != nil {
		resp.Body = io.NopCloser(bytes.NewReader(original))
		return resp, err
	}

	size := len(payload)
	resp.Body = io.NopCloser(bytes.NewReader(payload))
	resp.ContentLength = int64(size)
	resp.Header.Set("Content-Length", strconv.Itoa(size))
	if mediaType != "" {
		resp.Header.Set("Content-Type", mediaType)
	}
	resp.Header.Del("Content-Encoding")
	return resp, nil
}
func rewritePyPIBody(body []byte, contentType string, domain string) ([]byte, string, error) {
lowerCT := strings.ToLower(contentType)
if strings.Contains(lowerCT, "application/vnd.pypi.simple.v1+json") || strings.HasPrefix(strings.TrimSpace(string(body)), "{") {
data := map[string]interface{}{}
if err := json.Unmarshal(body, &data); err != nil {
return body, contentType, err
}
if files, ok := data["files"].([]interface{}); ok {
for _, entry := range files {
if fileMap, ok := entry.(map[string]interface{}); ok {
if urlValue, ok := fileMap["url"].(string); ok {
fileMap["url"] = rewritePyPIFileURL(domain, urlValue)
}
}
}
}
rewriteBytes, err := json.Marshal(data)
if err != nil {
return body, contentType, err
}
return rewriteBytes, "application/vnd.pypi.simple.v1+json", nil
}
rewrittenHTML, err := rewritePyPIHTML(body, domain)
if err != nil {
return body, contentType, err
}
return rewrittenHTML, "text/html; charset=utf-8", nil
}
// rewritePyPIHTML parses body as an HTML document, rewrites the link
// attributes of every element via rewriteHTMLNode, and returns the
// re-rendered document. A parse or render failure yields a nil slice and the
// underlying error.
func rewritePyPIHTML(body []byte, domain string) ([]byte, error) {
	doc, parseErr := html.Parse(bytes.NewReader(body))
	if parseErr != nil {
		return nil, parseErr
	}

	rewriteHTMLNode(doc, domain)

	out := new(bytes.Buffer)
	if renderErr := html.Render(out, doc); renderErr != nil {
		return nil, renderErr
	}
	return out.Bytes(), nil
}
// rewriteHTMLNode walks the tree rooted at n in document (pre-)order and
// applies rewriteHTMLAttributes to every element node it encounters.
func rewriteHTMLNode(n *html.Node, domain string) {
	// Explicit stack instead of recursion; children are pushed last-to-first
	// so they are popped, and therefore visited, first-to-last.
	pending := []*html.Node{n}
	for len(pending) > 0 {
		last := len(pending) - 1
		current := pending[last]
		pending = pending[:last]

		if current.Type == html.ElementNode {
			rewriteHTMLAttributes(current, domain)
		}
		for kid := current.LastChild; kid != nil; kid = kid.PrevSibling {
			pending = append(pending, kid)
		}
	}
}
// rewriteHTMLAttributes rewrites, in place, the value of each "href",
// "data-dist-info-metadata" and "data-core-metadata" attribute on n whose
// value starts with "http://" or "https://". Other attributes, and values
// that are not absolute http(s) URLs, are left untouched.
func rewriteHTMLAttributes(n *html.Node, domain string) {
	for idx := range n.Attr {
		attribute := &n.Attr[idx]

		isLinkAttr := attribute.Key == "href" ||
			attribute.Key == "data-dist-info-metadata" ||
			attribute.Key == "data-core-metadata"
		if !isLinkAttr {
			continue
		}

		isAbsolute := strings.HasPrefix(attribute.Val, "http://") ||
			strings.HasPrefix(attribute.Val, "https://")
		if !isAbsolute {
			continue
		}

		attribute.Val = rewritePyPIFileURL(domain, attribute.Val)
	}
}
// rewritePyPIFileURL maps an absolute upstream file URL onto this hub.
//
// A URL of the form <scheme>://<host><path>?<query>#<fragment> becomes
// https://<domain>/files/<scheme>/<host><path>?<query>#<fragment>.
//
// If original cannot be parsed, or has no scheme or no host (for example a
// relative link), it is returned unchanged.
func rewritePyPIFileURL(domain, original string) string {
	parsed, err := url.Parse(original)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return original
	}

	prefix := "/files/" + parsed.Scheme + "/" + parsed.Host
	newURL := url.URL{
		Scheme:   "https",
		Host:     domain,
		Path:     prefix + parsed.Path,
		RawQuery: parsed.RawQuery,
		Fragment: parsed.Fragment,
		// Carry the raw fragment along so an explicitly percent-encoded
		// fragment is emitted exactly as the upstream wrote it instead of
		// being re-encoded from its decoded form.
		RawFragment: parsed.RawFragment,
	}
	// Same idea for the path: RawPath is only set when the original escaping
	// differs from the default encoding of Path, and must be preserved then.
	if raw := parsed.RawPath; raw != "" {
		newURL.RawPath = prefix + raw
	}
	return newURL.String()
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/gofiber/fiber/v3"
"github.com/any-hub/any-hub/internal/hubmodule"
"github.com/any-hub/any-hub/internal/proxy/hooks"
"github.com/any-hub/any-hub/internal/server"
)
@@ -18,9 +19,11 @@ func RegisterModuleRoutes(app *fiber.App, registry *server.HubRegistry) {
}
app.Get("/-/modules", func(c fiber.Ctx) error {
hookStatus := hooks.Snapshot(hubmodule.Keys())
payload := fiber.Map{
"modules": encodeModules(hubmodule.List()),
"hubs": encodeHubBindings(registry.List()),
"modules": encodeModules(hubmodule.List(), hookStatus),
"hubs": encodeHubBindings(registry.List()),
"hook_registry": hookStatus,
}
return c.JSON(payload)
})
@@ -34,35 +37,39 @@ func RegisterModuleRoutes(app *fiber.App, registry *server.HubRegistry) {
if !ok {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "module_not_found"})
}
return c.JSON(encodeModule(meta))
encoded := encodeModule(meta)
encoded.HookStatus = hooks.Status(key)
return c.JSON(encoded)
})
}
type modulePayload struct {
Key string `json:"key"`
Description string `json:"description"`
Key string `json:"key"`
Description string `json:"description"`
MigrationState hubmodule.MigrationState `json:"migration_state"`
SupportedProtocols []string `json:"supported_protocols"`
CacheStrategy cacheStrategyPayload `json:"cache_strategy"`
SupportedProtocols []string `json:"supported_protocols"`
CacheStrategy cacheStrategyPayload `json:"cache_strategy"`
HookStatus string `json:"hook_status,omitempty"`
}
type cacheStrategyPayload struct {
TTLSeconds int64 `json:"ttl_seconds"`
ValidationMode string `json:"validation_mode"`
DiskLayout string `json:"disk_layout"`
RequiresMetadataFile bool `json:"requires_metadata_file"`
SupportsStreamingWrite bool `json:"supports_streaming_write"`
TTLSeconds int64 `json:"ttl_seconds"`
ValidationMode string `json:"validation_mode"`
DiskLayout string `json:"disk_layout"`
RequiresMetadataFile bool `json:"requires_metadata_file"`
SupportsStreamingWrite bool `json:"supports_streaming_write"`
}
type hubBindingPayload struct {
HubName string `json:"hub_name"`
ModuleKey string `json:"module_key"`
Domain string `json:"domain"`
Port int `json:"port"`
Rollout string `json:"rollout_flag"`
HubName string `json:"hub_name"`
ModuleKey string `json:"module_key"`
Domain string `json:"domain"`
Port int `json:"port"`
Rollout string `json:"rollout_flag"`
Legacy bool `json:"legacy_only"`
}
func encodeModules(mods []hubmodule.ModuleMetadata) []modulePayload {
func encodeModules(mods []hubmodule.ModuleMetadata, status map[string]string) []modulePayload {
if len(mods) == 0 {
return nil
}
@@ -71,7 +78,11 @@ func encodeModules(mods []hubmodule.ModuleMetadata) []modulePayload {
})
result := make([]modulePayload, 0, len(mods))
for _, meta := range mods {
result = append(result, encodeModule(meta))
item := encodeModule(meta)
if s, ok := status[meta.Key]; ok {
item.HookStatus = s
}
result = append(result, item)
}
return result
}
@@ -84,10 +95,10 @@ func encodeModule(meta hubmodule.ModuleMetadata) modulePayload {
MigrationState: meta.MigrationState,
SupportedProtocols: append([]string(nil), meta.SupportedProtocols...),
CacheStrategy: cacheStrategyPayload{
TTLSeconds: int64(strategy.TTLHint / time.Second),
ValidationMode: string(strategy.ValidationMode),
DiskLayout: strategy.DiskLayout,
RequiresMetadataFile: strategy.RequiresMetadataFile,
TTLSeconds: int64(strategy.TTLHint / time.Second),
ValidationMode: string(strategy.ValidationMode),
DiskLayout: strategy.DiskLayout,
RequiresMetadataFile: strategy.RequiresMetadataFile,
SupportsStreamingWrite: strategy.SupportsStreamingWrite,
},
}
@@ -108,6 +119,7 @@ func encodeHubBindings(routes []server.HubRoute) []hubBindingPayload {
Domain: route.Config.Domain,
Port: route.ListenPort,
Rollout: string(route.RolloutFlag),
Legacy: route.ModuleKey == hubmodule.DefaultModuleKey(),
})
}
return result

View File

@@ -0,0 +1,67 @@
package routes
import (
"testing"
"time"
"github.com/any-hub/any-hub/internal/hubmodule"
"github.com/any-hub/any-hub/internal/proxy/hooks"
)
// TestEncodeModulesAddsHookStatus feeds encodeModules two modules in reverse
// key order plus a status map that only knows module "a". It expects the
// result to be ordered by key, with the hook status filled in for "a" and
// left empty for "b".
func TestEncodeModulesAddsHookStatus(t *testing.T) {
	// newMeta builds the minimal metadata this test needs; only the key and
	// TTL differ between the two fixtures.
	newMeta := func(key string, ttl time.Duration) hubmodule.ModuleMetadata {
		return hubmodule.ModuleMetadata{
			Key: key,
			CacheStrategy: hubmodule.CacheStrategyProfile{
				TTLHint:        ttl,
				ValidationMode: hubmodule.ValidationModeNever,
				DiskLayout:     "flat",
			},
		}
	}

	modules := []hubmodule.ModuleMetadata{
		newMeta("b", time.Hour),
		newMeta("a", time.Minute),
	}
	status := map[string]string{"a": "registered"}

	encoded := encodeModules(modules, status)
	if len(encoded) != 2 {
		t.Fatalf("expected 2 modules, got %d", len(encoded))
	}

	first, second := encoded[0], encoded[1]
	if first.Key != "a" {
		t.Fatalf("expected sorted module key a first, got %s", first.Key)
	}
	if first.HookStatus != "registered" {
		t.Fatalf("expected hook status registered for a, got %s", first.HookStatus)
	}
	if second.Key != "b" {
		t.Fatalf("expected second module key b, got %s", second.Key)
	}
	if second.HookStatus != "" {
		t.Fatalf("expected empty hook status for b, got %s", second.HookStatus)
	}
}
// TestEncodeModuleAddsStatusForDetail mirrors what the module detail endpoint
// does: encode a single module, then attach the hook status looked up by key,
// and expects that status to be "registered".
func TestEncodeModuleAddsStatusForDetail(t *testing.T) {
	const moduleKey = "module-routes-test"

	// The result is intentionally ignored: presumably the only failure mode is
	// the key already being registered (e.g. on a repeated run in the same
	// process), in which case the status checked below is still "registered".
	_ = hooks.Register(moduleKey, hooks.Hooks{})

	strategy := hubmodule.CacheStrategyProfile{
		TTLHint:        time.Minute,
		ValidationMode: hubmodule.ValidationModeNever,
		DiskLayout:     "flat",
	}
	detail := encodeModule(hubmodule.ModuleMetadata{
		Key:           moduleKey,
		CacheStrategy: strategy,
	})
	detail.HookStatus = hooks.Status(moduleKey)

	if got := detail.HookStatus; got != "registered" {
		t.Fatalf("expected hook status registered, got %s", got)
	}
}