update
This commit is contained in:
@@ -1,291 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
func (h *Handler) rewriteComposerResponse(route *server.HubRoute, resp *http.Response, path string) (*http.Response, error) {
|
||||
if resp == nil || route == nil || route.Config.Type != "composer" {
|
||||
return resp, nil
|
||||
}
|
||||
if path == "/packages.json" {
|
||||
return rewriteComposerRoot(resp, route.Config.Domain)
|
||||
}
|
||||
if !isComposerMetadataPath(path) {
|
||||
return resp, nil
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
rewritten, changed, err := rewriteComposerMetadata(body, route.Config.Domain)
|
||||
if err != nil {
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return resp, err
|
||||
}
|
||||
if !changed {
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(rewritten))
|
||||
resp.ContentLength = int64(len(rewritten))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(rewritten)))
|
||||
resp.Header.Set("Content-Type", "application/json")
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Etag")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func rewriteComposerRoot(resp *http.Response, domain string) (*http.Response, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
data, changed, err := rewriteComposerRootBody(body, domain)
|
||||
if err != nil {
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return resp, err
|
||||
}
|
||||
if !changed {
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(data))
|
||||
resp.ContentLength = int64(len(data))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(data)))
|
||||
resp.Header.Set("Content-Type", "application/json")
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Etag")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func rewriteComposerMetadata(body []byte, domain string) ([]byte, bool, error) {
|
||||
type packagesRoot struct {
|
||||
Packages map[string]json.RawMessage `json:"packages"`
|
||||
}
|
||||
var root packagesRoot
|
||||
if err := json.Unmarshal(body, &root); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(root.Packages) == 0 {
|
||||
return body, false, nil
|
||||
}
|
||||
|
||||
changed := false
|
||||
for name, raw := range root.Packages {
|
||||
updated, rewritten, err := rewriteComposerPackagesPayload(raw, domain, name)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if rewritten {
|
||||
root.Packages[name] = updated
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return body, false, nil
|
||||
}
|
||||
data, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return data, true, nil
|
||||
}
|
||||
|
||||
func rewriteComposerPackagesPayload(raw json.RawMessage, domain string, packageName string) (json.RawMessage, bool, error) {
|
||||
var asArray []map[string]any
|
||||
if err := json.Unmarshal(raw, &asArray); err == nil {
|
||||
rewrote := rewriteComposerVersionSlice(asArray, domain, packageName)
|
||||
if !rewrote {
|
||||
return raw, false, nil
|
||||
}
|
||||
data, err := json.Marshal(asArray)
|
||||
return data, true, err
|
||||
}
|
||||
|
||||
var asMap map[string]map[string]any
|
||||
if err := json.Unmarshal(raw, &asMap); err == nil {
|
||||
rewrote := rewriteComposerVersionMap(asMap, domain, packageName)
|
||||
if !rewrote {
|
||||
return raw, false, nil
|
||||
}
|
||||
data, err := json.Marshal(asMap)
|
||||
return data, true, err
|
||||
}
|
||||
|
||||
return raw, false, nil
|
||||
}
|
||||
|
||||
func rewriteComposerVersionSlice(items []map[string]any, domain string, packageName string) bool {
|
||||
changed := false
|
||||
for _, entry := range items {
|
||||
if rewriteComposerVersion(entry, domain, packageName) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func rewriteComposerVersionMap(items map[string]map[string]any, domain string, packageName string) bool {
|
||||
changed := false
|
||||
for _, entry := range items {
|
||||
if rewriteComposerVersion(entry, domain, packageName) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func rewriteComposerVersion(entry map[string]any, domain string, packageName string) bool {
|
||||
if entry == nil {
|
||||
return false
|
||||
}
|
||||
changed := false
|
||||
if packageName != "" {
|
||||
if name, _ := entry["name"].(string); strings.TrimSpace(name) == "" {
|
||||
entry["name"] = packageName
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
distVal, ok := entry["dist"].(map[string]any)
|
||||
if !ok {
|
||||
return changed
|
||||
}
|
||||
urlValue, ok := distVal["url"].(string)
|
||||
if !ok || urlValue == "" {
|
||||
return changed
|
||||
}
|
||||
rewritten := rewriteComposerDistURL(domain, urlValue)
|
||||
if rewritten == urlValue {
|
||||
return changed
|
||||
}
|
||||
distVal["url"] = rewritten
|
||||
return true
|
||||
}
|
||||
|
||||
func rewriteComposerDistURL(domain, original string) string {
|
||||
parsed, err := url.Parse(original)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return original
|
||||
}
|
||||
prefix := fmt.Sprintf("/dist/%s/%s", parsed.Scheme, parsed.Host)
|
||||
newURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: domain,
|
||||
Path: prefix + parsed.Path,
|
||||
RawQuery: parsed.RawQuery,
|
||||
Fragment: parsed.Fragment,
|
||||
}
|
||||
if raw := parsed.RawPath; raw != "" {
|
||||
newURL.RawPath = prefix + raw
|
||||
}
|
||||
return newURL.String()
|
||||
}
|
||||
|
||||
func isComposerMetadataPath(path string) bool {
|
||||
switch {
|
||||
case path == "/packages.json":
|
||||
return true
|
||||
case strings.HasPrefix(path, "/p2/"):
|
||||
return true
|
||||
case strings.HasPrefix(path, "/p/"):
|
||||
return true
|
||||
case strings.HasPrefix(path, "/provider-"):
|
||||
return true
|
||||
case strings.HasPrefix(path, "/providers/"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isComposerDistPath(path string) bool {
|
||||
return strings.HasPrefix(path, "/dist/")
|
||||
}
|
||||
|
||||
func rewriteComposerAbsolute(domain, raw string) string {
|
||||
if raw == "" {
|
||||
return raw
|
||||
}
|
||||
if strings.HasPrefix(raw, "//") {
|
||||
return "https://" + domain + strings.TrimPrefix(raw, "//")
|
||||
}
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
parsed.Host = domain
|
||||
parsed.Scheme = "https"
|
||||
return parsed.String()
|
||||
}
|
||||
pathVal := raw
|
||||
if !strings.HasPrefix(pathVal, "/") {
|
||||
pathVal = "/" + pathVal
|
||||
}
|
||||
return fmt.Sprintf("https://%s%s", domain, pathVal)
|
||||
}
|
||||
|
||||
func rewriteComposerRootBody(body []byte, domain string) ([]byte, bool, error) {
|
||||
var root map[string]any
|
||||
if err := json.Unmarshal(body, &root); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
changed := false
|
||||
for _, key := range []string{"metadata-url", "providers-api", "providers-url", "notify-batch"} {
|
||||
if raw, ok := root[key].(string); ok && raw != "" {
|
||||
newVal := rewriteComposerAbsolute(domain, raw)
|
||||
if newVal != raw {
|
||||
root[key] = newVal
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if includes, ok := root["provider-includes"].(map[string]any); ok {
|
||||
for file, hashVal := range includes {
|
||||
pathVal := file
|
||||
if rawPath, ok := hashVal.(map[string]any); ok {
|
||||
if urlValue, ok := rawPath["url"].(string); ok {
|
||||
pathVal = urlValue
|
||||
}
|
||||
}
|
||||
newPath := rewriteComposerAbsolute(domain, pathVal)
|
||||
if newPath != pathVal {
|
||||
changed = true
|
||||
}
|
||||
if rawPath, ok := hashVal.(map[string]any); ok {
|
||||
rawPath["url"] = newPath
|
||||
includes[file] = rawPath
|
||||
} else {
|
||||
includes[file] = newPath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return body, false, nil
|
||||
}
|
||||
|
||||
data, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return data, true, nil
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
func TestApplyDockerHubNamespaceFallback(t *testing.T) {
|
||||
route := dockerHubRoute(t, "https://registry-1.docker.io")
|
||||
|
||||
path, changed := applyDockerHubNamespaceFallback(route, "/v2/nginx/manifests/latest")
|
||||
if !changed {
|
||||
t.Fatalf("expected fallback to apply")
|
||||
}
|
||||
if path != "/v2/library/nginx/manifests/latest" {
|
||||
t.Fatalf("unexpected normalized path: %s", path)
|
||||
}
|
||||
|
||||
path, changed = applyDockerHubNamespaceFallback(route, "/v2/library/nginx/manifests/latest")
|
||||
if changed {
|
||||
t.Fatalf("expected no changes for already-namespaced repo")
|
||||
}
|
||||
|
||||
path, changed = applyDockerHubNamespaceFallback(route, "/v2/rogee/nginx/manifests/latest")
|
||||
if changed {
|
||||
t.Fatalf("expected no changes for custom namespace")
|
||||
}
|
||||
|
||||
path, changed = applyDockerHubNamespaceFallback(route, "/v2/_catalog")
|
||||
if changed {
|
||||
t.Fatalf("expected no changes for _catalog endpoint")
|
||||
}
|
||||
|
||||
otherRoute := dockerHubRoute(t, "https://registry.example.com")
|
||||
path, changed = applyDockerHubNamespaceFallback(otherRoute, "/v2/nginx/manifests/latest")
|
||||
if changed || path != "/v2/nginx/manifests/latest" {
|
||||
t.Fatalf("expected no changes for non-docker-hub upstream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitDockerRepoPath(t *testing.T) {
|
||||
repo, rest, ok := splitDockerRepoPath("/v2/library/nginx/manifests/latest")
|
||||
if !ok || repo != "library/nginx" || rest != "/manifests/latest" {
|
||||
t.Fatalf("unexpected split result repo=%s rest=%s ok=%v", repo, rest, ok)
|
||||
}
|
||||
|
||||
if _, _, ok := splitDockerRepoPath("/v2/_catalog"); ok {
|
||||
t.Fatalf("expected catalog path to be ignored")
|
||||
}
|
||||
}
|
||||
|
||||
func dockerHubRoute(t *testing.T, upstream string) *server.HubRoute {
|
||||
t.Helper()
|
||||
parsed, err := url.Parse(upstream)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid upstream: %v", err)
|
||||
}
|
||||
return &server.HubRoute{
|
||||
Config: config.HubConfig{
|
||||
Name: "docker",
|
||||
Type: "docker",
|
||||
},
|
||||
UpstreamURL: parsed,
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/hubmodule"
|
||||
"github.com/any-hub/any-hub/internal/logging"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
@@ -37,39 +38,48 @@ func RegisterModuleHandler(key string, handler server.ProxyHandler) {
|
||||
|
||||
// Handle 实现 server.ProxyHandler,根据 route.ModuleKey 选择 handler。
|
||||
func (f *Forwarder) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
requestID := server.RequestID(c)
|
||||
handler := f.lookup(route)
|
||||
if handler == nil {
|
||||
return f.respondMissingHandler(c, route)
|
||||
return f.respondMissingHandler(c, route, requestID)
|
||||
}
|
||||
return f.invokeHandler(c, route, handler)
|
||||
return f.invokeHandler(c, route, handler, requestID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) respondMissingHandler(c fiber.Ctx, route *server.HubRoute) error {
|
||||
f.logModuleError(route, "module_handler_missing", nil)
|
||||
func (f *Forwarder) respondMissingHandler(c fiber.Ctx, route *server.HubRoute, requestID string) error {
|
||||
f.logModuleError(route, "module_handler_missing", nil, requestID)
|
||||
setRequestIDHeader(c, requestID)
|
||||
return c.Status(fiber.StatusInternalServerError).
|
||||
JSON(fiber.Map{"error": "module_handler_missing"})
|
||||
}
|
||||
|
||||
func (f *Forwarder) invokeHandler(c fiber.Ctx, route *server.HubRoute, handler server.ProxyHandler) (err error) {
|
||||
func (f *Forwarder) invokeHandler(c fiber.Ctx, route *server.HubRoute, handler server.ProxyHandler, requestID string) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = f.respondHandlerPanic(c, route, r)
|
||||
err = f.respondHandlerPanic(c, route, r, requestID)
|
||||
}
|
||||
}()
|
||||
return handler.Handle(c, route)
|
||||
}
|
||||
|
||||
func (f *Forwarder) respondHandlerPanic(c fiber.Ctx, route *server.HubRoute, recovered interface{}) error {
|
||||
f.logModuleError(route, "module_handler_panic", fmt.Errorf("panic: %v", recovered))
|
||||
func (f *Forwarder) respondHandlerPanic(c fiber.Ctx, route *server.HubRoute, recovered interface{}, requestID string) error {
|
||||
f.logModuleError(route, "module_handler_panic", fmt.Errorf("panic: %v", recovered), requestID)
|
||||
setRequestIDHeader(c, requestID)
|
||||
return c.Status(fiber.StatusInternalServerError).
|
||||
JSON(fiber.Map{"error": "module_handler_panic"})
|
||||
}
|
||||
|
||||
func (f *Forwarder) logModuleError(route *server.HubRoute, code string, err error) {
|
||||
func setRequestIDHeader(c fiber.Ctx, requestID string) {
|
||||
if requestID != "" {
|
||||
c.Set("X-Request-ID", requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) logModuleError(route *server.HubRoute, code string, err error, requestID string) {
|
||||
if f.logger == nil {
|
||||
return
|
||||
}
|
||||
fields := f.routeFields(route)
|
||||
fields := f.routeFields(route, requestID)
|
||||
fields["action"] = "proxy"
|
||||
fields["error"] = code
|
||||
if err != nil {
|
||||
@@ -105,7 +115,7 @@ func normalizeModuleKey(key string) string {
|
||||
return strings.ToLower(strings.TrimSpace(key))
|
||||
}
|
||||
|
||||
func (f *Forwarder) routeFields(route *server.HubRoute) logrus.Fields {
|
||||
func (f *Forwarder) routeFields(route *server.HubRoute, requestID string) logrus.Fields {
|
||||
if route == nil {
|
||||
return logrus.Fields{
|
||||
"hub": "",
|
||||
@@ -117,7 +127,7 @@ func (f *Forwarder) routeFields(route *server.HubRoute) logrus.Fields {
|
||||
}
|
||||
}
|
||||
|
||||
return logging.RequestFields(
|
||||
fields := logging.RequestFields(
|
||||
route.Config.Name,
|
||||
route.Config.Domain,
|
||||
route.Config.Type,
|
||||
@@ -125,5 +135,10 @@ func (f *Forwarder) routeFields(route *server.HubRoute) logrus.Fields {
|
||||
route.ModuleKey,
|
||||
string(route.RolloutFlag),
|
||||
false,
|
||||
route.ModuleKey == hubmodule.DefaultModuleKey(),
|
||||
)
|
||||
if requestID != "" {
|
||||
fields["request_id"] = requestID
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -14,12 +14,15 @@ import (
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
const requestIDKey = "_anyhub_request_id"
|
||||
|
||||
func TestForwarderMissingHandler(t *testing.T) {
|
||||
app := fiber.New()
|
||||
defer app.Shutdown()
|
||||
|
||||
ctx := app.AcquireCtx(new(fasthttp.RequestCtx))
|
||||
defer app.ReleaseCtx(ctx)
|
||||
ctx.Locals(requestIDKey, "missing-req")
|
||||
|
||||
logger := logrus.New()
|
||||
logBuf := &bytes.Buffer{}
|
||||
@@ -40,6 +43,12 @@ func TestForwarderMissingHandler(t *testing.T) {
|
||||
if !strings.Contains(logBuf.String(), "module_handler_missing") {
|
||||
t.Fatalf("expected log to mention module_handler_missing, got %s", logBuf.String())
|
||||
}
|
||||
if got := string(ctx.Response().Header.Peek("X-Request-ID")); got != "missing-req" {
|
||||
t.Fatalf("expected request id header missing-req, got %s", got)
|
||||
}
|
||||
if !strings.Contains(logBuf.String(), "missing-req") {
|
||||
t.Fatalf("expected log to include request id, got %s", logBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwarderHandlerPanic(t *testing.T) {
|
||||
@@ -58,6 +67,7 @@ func TestForwarderHandlerPanic(t *testing.T) {
|
||||
defer app.Shutdown()
|
||||
ctx := app.AcquireCtx(new(fasthttp.RequestCtx))
|
||||
defer app.ReleaseCtx(ctx)
|
||||
ctx.Locals(requestIDKey, "panic-req")
|
||||
|
||||
logger := logrus.New()
|
||||
logBuf := &bytes.Buffer{}
|
||||
@@ -78,6 +88,12 @@ func TestForwarderHandlerPanic(t *testing.T) {
|
||||
if !strings.Contains(logBuf.String(), "module_handler_panic") {
|
||||
t.Fatalf("expected log to mention module_handler_panic, got %s", logBuf.String())
|
||||
}
|
||||
if got := string(ctx.Response().Header.Peek("X-Request-ID")); got != "panic-req" {
|
||||
t.Fatalf("expected request id header panic-req, got %s", got)
|
||||
}
|
||||
if !strings.Contains(logBuf.String(), "panic-req") {
|
||||
t.Fatalf("expected log to include panic request id, got %s", logBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func testRouteWithModule(moduleKey string) *server.HubRoute {
|
||||
|
||||
@@ -36,6 +36,14 @@ type Handler struct {
|
||||
etags sync.Map // key: hub+path, value: etag/digest string
|
||||
}
|
||||
|
||||
type hookState struct {
|
||||
ctx *hooks.RequestContext
|
||||
def hooks.Hooks
|
||||
hasHooks bool
|
||||
clean string
|
||||
rawQuery []byte
|
||||
}
|
||||
|
||||
// NewHandler constructs a proxy handler with shared HTTP client/logger/store.
|
||||
func NewHandler(client *http.Client, logger *logrus.Logger, store cache.Store) *Handler {
|
||||
return &Handler{
|
||||
@@ -45,23 +53,58 @@ func NewHandler(client *http.Client, logger *logrus.Logger, store cache.Store) *
|
||||
}
|
||||
}
|
||||
|
||||
func buildHookContext(route *server.HubRoute, c fiber.Ctx) *hooks.RequestContext {
|
||||
if route == nil {
|
||||
return &hooks.RequestContext{Method: c.Method()}
|
||||
}
|
||||
baseHost := ""
|
||||
if route.UpstreamURL != nil {
|
||||
baseHost = route.UpstreamURL.Host
|
||||
}
|
||||
return &hooks.RequestContext{
|
||||
HubName: route.Config.Name,
|
||||
Domain: route.Config.Domain,
|
||||
HubType: route.Config.Type,
|
||||
ModuleKey: route.ModuleKey,
|
||||
RolloutFlag: string(route.RolloutFlag),
|
||||
UpstreamHost: baseHost,
|
||||
Method: c.Method(),
|
||||
}
|
||||
}
|
||||
|
||||
func hasHook(def hooks.Hooks) bool {
|
||||
return def.NormalizePath != nil ||
|
||||
def.ResolveUpstream != nil ||
|
||||
def.RewriteResponse != nil ||
|
||||
def.CachePolicy != nil ||
|
||||
def.ContentType != nil
|
||||
}
|
||||
|
||||
// Handle 执行缓存查找、条件回源和最终 streaming 逻辑,任何阶段出错都会输出结构化日志。
|
||||
func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
started := time.Now()
|
||||
requestID := server.RequestID(c)
|
||||
reqCtx := newRequestContext(route, c.Method())
|
||||
moduleHooks, _ := hooks.For(route.ModuleKey)
|
||||
locator := buildLocator(route, c, reqCtx, moduleHooks)
|
||||
policy := determineCachePolicy(route, locator, c.Method(), reqCtx, moduleHooks)
|
||||
strategyWriter := cache.NewStrategyWriter(h.store, route.CacheStrategy)
|
||||
|
||||
if err := ensureProxyHubType(route); err != nil {
|
||||
h.logger.WithFields(logrus.Fields{
|
||||
"hub": route.Config.Name,
|
||||
"module_key": route.ModuleKey,
|
||||
}).WithError(err).Error("hub_type_unsupported")
|
||||
return h.writeError(c, fiber.StatusNotImplemented, "hub_type_unsupported")
|
||||
hooksDef, ok := hooks.Fetch(route.ModuleKey)
|
||||
hookCtx := buildHookContext(route, c)
|
||||
rawQuery := append([]byte(nil), c.Request().URI().QueryString()...)
|
||||
cleanPath := normalizeRequestPath(route, string(c.Request().URI().Path()))
|
||||
if hasHook(hooksDef) && hooksDef.NormalizePath != nil {
|
||||
newPath, newQuery := hooksDef.NormalizePath(hookCtx, cleanPath, rawQuery)
|
||||
if newPath != "" {
|
||||
cleanPath = newPath
|
||||
}
|
||||
rawQuery = newQuery
|
||||
}
|
||||
locator := buildLocator(route, c, cleanPath, rawQuery)
|
||||
policy := determineCachePolicyWithHook(route, locator, c.Method(), hooksDef, ok, hookCtx)
|
||||
hookState := hookState{
|
||||
ctx: hookCtx,
|
||||
def: hooksDef,
|
||||
hasHooks: ok && hasHook(hooksDef),
|
||||
clean: cleanPath,
|
||||
rawQuery: rawQuery,
|
||||
}
|
||||
strategyWriter := cache.NewStrategyWriter(h.store, route.CacheStrategy)
|
||||
|
||||
ctx := c.Context()
|
||||
if ctx == nil {
|
||||
@@ -89,7 +132,7 @@ func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
if strategyWriter.ShouldBypassValidation(cached.Entry) {
|
||||
serve = true
|
||||
} else if strategyWriter.SupportsValidation() {
|
||||
fresh, err := h.isCacheFresh(c, route, locator, cached.Entry)
|
||||
fresh, err := h.isCacheFresh(c, route, locator, cached.Entry, &hookState)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).
|
||||
WithFields(logrus.Fields{"hub": route.Config.Name, "module_key": route.ModuleKey}).
|
||||
@@ -104,12 +147,12 @@ func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
}
|
||||
if serve {
|
||||
defer cached.Reader.Close()
|
||||
return h.serveCache(c, route, cached, requestID, started)
|
||||
return h.serveCache(c, route, cached, requestID, started, &hookState)
|
||||
}
|
||||
cached.Reader.Close()
|
||||
}
|
||||
|
||||
return h.fetchAndStream(c, route, locator, policy, strategyWriter, requestID, started, ctx, reqCtx, moduleHooks)
|
||||
return h.fetchAndStream(c, route, locator, policy, strategyWriter, requestID, started, ctx, &hookState)
|
||||
}
|
||||
|
||||
func (h *Handler) serveCache(
|
||||
@@ -118,6 +161,7 @@ func (h *Handler) serveCache(
|
||||
result *cache.ReadResult,
|
||||
requestID string,
|
||||
started time.Time,
|
||||
hook *hookState,
|
||||
) error {
|
||||
var readSeeker io.ReadSeeker
|
||||
switch reader := result.Reader.(type) {
|
||||
@@ -130,44 +174,12 @@ func (h *Handler) serveCache(
|
||||
|
||||
method := c.Method()
|
||||
|
||||
contentType := inferCachedContentType(route, result.Entry.Locator)
|
||||
if contentType == "" && shouldSniffDockerManifest(route, result.Entry.Locator) {
|
||||
contentType := resolveContentType(route, result.Entry.Locator, hook)
|
||||
if contentType == "" && shouldSniffDockerManifest(result.Entry.Locator) {
|
||||
if sniffed := sniffDockerManifestContentType(readSeeker); sniffed != "" {
|
||||
contentType = sniffed
|
||||
}
|
||||
}
|
||||
if route != nil && route.Config.Type == "composer" && isComposerMetadataPath(stripQueryMarker(result.Entry.Locator.Path)) {
|
||||
body, err := io.ReadAll(result.Reader)
|
||||
result.Reader.Close()
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("read cache failed: %v", err))
|
||||
}
|
||||
rewritten := body
|
||||
if stripQueryMarker(result.Entry.Locator.Path) == "/packages.json" {
|
||||
if data, changed, err := rewriteComposerRootBody(body, route.Config.Domain); err == nil && changed {
|
||||
rewritten = data
|
||||
}
|
||||
} else {
|
||||
if data, changed, err := rewriteComposerMetadata(body, route.Config.Domain); err == nil && changed {
|
||||
rewritten = data
|
||||
}
|
||||
}
|
||||
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("X-Any-Hub-Upstream", route.UpstreamURL.String())
|
||||
c.Set("X-Any-Hub-Cache-Hit", "true")
|
||||
if requestID != "" {
|
||||
c.Set("X-Request-ID", requestID)
|
||||
}
|
||||
c.Status(fiber.StatusOK)
|
||||
c.Response().Header.SetContentLength(len(rewritten))
|
||||
_, err = c.Response().BodyWriter().Write(rewritten)
|
||||
h.logResult(route, route.UpstreamURL.String(), requestID, fiber.StatusOK, true, started, err)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("read cache failed: %v", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if contentType != "" {
|
||||
c.Set("Content-Type", contentType)
|
||||
} else {
|
||||
@@ -214,35 +226,27 @@ func (h *Handler) fetchAndStream(
|
||||
requestID string,
|
||||
started time.Time,
|
||||
ctx context.Context,
|
||||
hook *hookState,
|
||||
) error {
|
||||
resp, upstreamURL, err := h.executeRequest(c, route)
|
||||
resp, upstreamURL, err := h.executeRequest(c, route, hook)
|
||||
if err != nil {
|
||||
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
|
||||
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
|
||||
}
|
||||
|
||||
resp, upstreamURL, err = h.retryOnAuthFailure(c, route, requestID, started, resp, upstreamURL)
|
||||
resp, upstreamURL, err = h.retryOnAuthFailure(c, route, requestID, started, resp, upstreamURL, hook)
|
||||
if err != nil {
|
||||
h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err)
|
||||
return h.writeError(c, fiber.StatusBadGateway, "upstream_failed")
|
||||
}
|
||||
if route.Config.Type == "pypi" {
|
||||
if rewritten, rewriteErr := h.rewritePyPIResponse(route, resp, requestPath(c)); rewriteErr == nil {
|
||||
if hook != nil && hook.hasHooks && hook.def.RewriteResponse != nil {
|
||||
if rewritten, rewriteErr := applyHookRewrite(hook, resp, requestPath(c)); rewriteErr == nil {
|
||||
resp = rewritten
|
||||
} else {
|
||||
h.logger.WithError(rewriteErr).WithFields(logrus.Fields{
|
||||
"action": "pypi_rewrite",
|
||||
"action": "hook_rewrite",
|
||||
"hub": route.Config.Name,
|
||||
}).Warn("pypi_rewrite_failed")
|
||||
}
|
||||
} else if route.Config.Type == "composer" {
|
||||
if rewritten, rewriteErr := h.rewriteComposerResponse(route, resp, requestPath(c)); rewriteErr == nil {
|
||||
resp = rewritten
|
||||
} else {
|
||||
h.logger.WithError(rewriteErr).WithFields(logrus.Fields{
|
||||
"action": "composer_rewrite",
|
||||
"hub": route.Config.Name,
|
||||
}).Warn("composer_rewrite_failed")
|
||||
}).Warn("hook_rewrite_failed")
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -252,6 +256,42 @@ func (h *Handler) fetchAndStream(
|
||||
return h.consumeUpstream(c, route, locator, resp, shouldStore, writer, requestID, started, ctx)
|
||||
}
|
||||
|
||||
func applyHookRewrite(hook *hookState, resp *http.Response, path string) (*http.Response, error) {
|
||||
if hook == nil || hook.def.RewriteResponse == nil {
|
||||
return resp, nil
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers := make(map[string]string, len(resp.Header))
|
||||
for key, values := range resp.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
status, newHeaders, newBody, rewriteErr := hook.def.RewriteResponse(hook.ctx, resp.StatusCode, headers, body, path)
|
||||
if rewriteErr != nil {
|
||||
return nil, rewriteErr
|
||||
}
|
||||
if newHeaders == nil {
|
||||
newHeaders = headers
|
||||
}
|
||||
if newBody == nil {
|
||||
newBody = body
|
||||
}
|
||||
cloned := *resp
|
||||
cloned.StatusCode = status
|
||||
cloned.Header = make(http.Header, len(newHeaders))
|
||||
for key, value := range newHeaders {
|
||||
cloned.Header.Set(key, value)
|
||||
}
|
||||
cloned.Body = io.NopCloser(bytes.NewReader(newBody))
|
||||
cloned.ContentLength = int64(len(newBody))
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func (h *Handler) consumeUpstream(
|
||||
c fiber.Ctx,
|
||||
route *server.HubRoute,
|
||||
@@ -335,6 +375,7 @@ func (h *Handler) retryOnAuthFailure(
|
||||
started time.Time,
|
||||
resp *http.Response,
|
||||
upstreamURL *url.URL,
|
||||
hook *hookState,
|
||||
) (*http.Response, *url.URL, error) {
|
||||
if !shouldRetryAuth(route, resp.StatusCode) {
|
||||
return resp, upstreamURL, nil
|
||||
@@ -354,30 +395,31 @@ func (h *Handler) retryOnAuthFailure(
|
||||
return nil, upstreamURL, err
|
||||
}
|
||||
authHeader := "Bearer " + token
|
||||
retryResp, retryURL, err := h.executeRequestWithAuth(c, route, authHeader)
|
||||
retryResp, retryURL, err := h.executeRequestWithAuth(c, route, hook, authHeader)
|
||||
if err != nil {
|
||||
return nil, upstreamURL, err
|
||||
}
|
||||
return retryResp, retryURL, nil
|
||||
}
|
||||
|
||||
retryResp, retryURL, err := h.executeRequest(c, route)
|
||||
retryResp, retryURL, err := h.executeRequest(c, route, hook)
|
||||
if err != nil {
|
||||
return nil, upstreamURL, err
|
||||
}
|
||||
return retryResp, retryURL, nil
|
||||
}
|
||||
|
||||
func (h *Handler) executeRequest(c fiber.Ctx, route *server.HubRoute) (*http.Response, *url.URL, error) {
|
||||
return h.executeRequestWithAuth(c, route, "")
|
||||
func (h *Handler) executeRequest(c fiber.Ctx, route *server.HubRoute, hook *hookState) (*http.Response, *url.URL, error) {
|
||||
return h.executeRequestWithAuth(c, route, hook, "")
|
||||
}
|
||||
|
||||
func (h *Handler) executeRequestWithAuth(
|
||||
c fiber.Ctx,
|
||||
route *server.HubRoute,
|
||||
hook *hookState,
|
||||
authHeader string,
|
||||
) (*http.Response, *url.URL, error) {
|
||||
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
|
||||
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c, hook)
|
||||
body := bytesReader(c.Body())
|
||||
req, err := h.buildUpstreamRequest(c, upstreamURL, route, c.Method(), body, authHeader)
|
||||
if err != nil {
|
||||
@@ -469,6 +511,7 @@ func (h *Handler) logResult(
|
||||
route.ModuleKey,
|
||||
string(route.RolloutFlag),
|
||||
cacheHit,
|
||||
route.ModuleKey == hubmodule.DefaultModuleKey(),
|
||||
)
|
||||
fields["action"] = "proxy"
|
||||
fields["upstream"] = upstream
|
||||
@@ -506,47 +549,20 @@ func inferCachedContentType(route *server.HubRoute, locator cache.Locator) strin
|
||||
return "application/x-tar"
|
||||
}
|
||||
|
||||
if route != nil {
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
if strings.Contains(clean, "/manifests/") {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(clean, "/tags/list") {
|
||||
return "application/json"
|
||||
}
|
||||
if strings.Contains(clean, "/blobs/") {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
case "npm":
|
||||
if strings.HasSuffix(clean, ".json") {
|
||||
return "application/json"
|
||||
}
|
||||
case "pypi":
|
||||
if strings.Contains(clean, "/simple/") {
|
||||
return "text/html"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildLocator(route *server.HubRoute, c fiber.Ctx) cache.Locator {
|
||||
uri := c.Request().URI()
|
||||
pathVal := string(uri.Path())
|
||||
clean := normalizeRequestPath(route, pathVal)
|
||||
if newPath, ok := applyPyPISimpleFallback(route, clean); ok {
|
||||
clean = newPath
|
||||
}
|
||||
if newPath, ok := applyDockerHubNamespaceFallback(route, clean); ok {
|
||||
clean = newPath
|
||||
}
|
||||
query := uri.QueryString()
|
||||
if route != nil && route.Config.Type == "composer" && isComposerDistPath(clean) {
|
||||
// composer dist URLs often embed per-request tokens; ignore query for cache key
|
||||
query = nil
|
||||
func resolveContentType(route *server.HubRoute, locator cache.Locator, hook *hookState) string {
|
||||
if hook != nil && hook.hasHooks && hook.def.ContentType != nil {
|
||||
if ct := hook.def.ContentType(hook.ctx, stripQueryMarker(locator.Path)); ct != "" {
|
||||
return ct
|
||||
}
|
||||
}
|
||||
return inferCachedContentType(route, locator)
|
||||
}
|
||||
|
||||
func buildLocator(route *server.HubRoute, c fiber.Ctx, clean string, rawQuery []byte) cache.Locator {
|
||||
query := rawQuery
|
||||
if len(query) > 0 {
|
||||
sum := sha1.Sum(query)
|
||||
clean = fmt.Sprintf("%s/__qs/%s", clean, hex.EncodeToString(sum[:]))
|
||||
@@ -580,10 +596,7 @@ func stripQueryMarker(p string) string {
|
||||
return p
|
||||
}
|
||||
|
||||
func shouldSniffDockerManifest(route *server.HubRoute, locator cache.Locator) bool {
|
||||
if route == nil || route.Config.Type != "docker" {
|
||||
return false
|
||||
}
|
||||
func shouldSniffDockerManifest(locator cache.Locator) bool {
|
||||
clean := stripQueryMarker(locator.Path)
|
||||
return strings.Contains(clean, "/manifests/")
|
||||
}
|
||||
@@ -631,11 +644,7 @@ func normalizeRequestPath(route *server.HubRoute, raw string) string {
|
||||
if raw == "" {
|
||||
raw = "/"
|
||||
}
|
||||
hasSlash := strings.HasSuffix(raw, "/")
|
||||
clean := path.Clean("/" + raw)
|
||||
if route != nil && route.Config.Type == "pypi" && hasSlash && clean != "/" && !strings.HasSuffix(clean, "/") {
|
||||
clean += "/"
|
||||
}
|
||||
return clean
|
||||
}
|
||||
|
||||
@@ -646,42 +655,28 @@ func bytesReader(b []byte) io.Reader {
|
||||
return bytes.NewReader(b)
|
||||
}
|
||||
|
||||
func resolveUpstreamURL(route *server.HubRoute, base *url.URL, c fiber.Ctx) *url.URL {
|
||||
func resolveUpstreamURL(route *server.HubRoute, base *url.URL, c fiber.Ctx, hook *hookState) *url.URL {
|
||||
uri := c.Request().URI()
|
||||
pathVal := string(uri.Path())
|
||||
clean := normalizeRequestPath(route, pathVal)
|
||||
if newPath, ok := applyPyPISimpleFallback(route, clean); ok {
|
||||
clean = newPath
|
||||
}
|
||||
if newPath, ok := applyDockerHubNamespaceFallback(route, clean); ok {
|
||||
clean = newPath
|
||||
}
|
||||
if route != nil && route.Config.Type == "pypi" && strings.HasPrefix(clean, "/files/") {
|
||||
trimmed := strings.TrimPrefix(clean, "/files/")
|
||||
parts := strings.SplitN(trimmed, "/", 3)
|
||||
if len(parts) >= 3 {
|
||||
scheme := parts[0]
|
||||
host := parts[1]
|
||||
rest := parts[2]
|
||||
filesBase := &url.URL{Scheme: scheme, Host: host}
|
||||
if !strings.HasPrefix(rest, "/") {
|
||||
rest = "/" + rest
|
||||
}
|
||||
relative := &url.URL{Path: rest, RawPath: rest}
|
||||
if query := string(uri.QueryString()); query != "" {
|
||||
relative.RawQuery = query
|
||||
}
|
||||
return filesBase.ResolveReference(relative)
|
||||
rawQuery := append([]byte(nil), uri.QueryString()...)
|
||||
clean := normalizeRequestPath(route, string(uri.Path()))
|
||||
if hook != nil {
|
||||
if hook.clean != "" {
|
||||
clean = hook.clean
|
||||
}
|
||||
}
|
||||
if route != nil && route.Config.Type == "composer" && strings.HasPrefix(clean, "/dist/") {
|
||||
if distTarget, ok := parseComposerDistURL(clean, string(uri.QueryString())); ok {
|
||||
return distTarget
|
||||
if hook.rawQuery != nil {
|
||||
rawQuery = hook.rawQuery
|
||||
}
|
||||
if hook.hasHooks && hook.def.ResolveUpstream != nil {
|
||||
if u := hook.def.ResolveUpstream(hook.ctx, base.String(), clean, rawQuery); u != "" {
|
||||
if parsed, err := url.Parse(u); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
relative := &url.URL{Path: clean, RawPath: clean}
|
||||
if query := string(uri.QueryString()); query != "" {
|
||||
relative.RawQuery = query
|
||||
if len(rawQuery) > 0 {
|
||||
relative.RawQuery = string(rawQuery)
|
||||
}
|
||||
return base.ResolveReference(relative)
|
||||
}
|
||||
@@ -718,85 +713,27 @@ type cachePolicy struct {
|
||||
requireRevalidate bool
|
||||
}
|
||||
|
||||
func determineCachePolicyWithHook(route *server.HubRoute, locator cache.Locator, method string, def hooks.Hooks, enabled bool, ctx *hooks.RequestContext) cachePolicy {
|
||||
base := determineCachePolicy(route, locator, method)
|
||||
if !enabled || def.CachePolicy == nil {
|
||||
return base
|
||||
}
|
||||
updated := def.CachePolicy(ctx, locator.Path, hooks.CachePolicy{
|
||||
AllowCache: base.allowCache,
|
||||
AllowStore: base.allowStore,
|
||||
RequireRevalidate: base.requireRevalidate,
|
||||
})
|
||||
base.allowCache = updated.AllowCache
|
||||
base.allowStore = updated.AllowStore
|
||||
base.requireRevalidate = updated.RequireRevalidate
|
||||
return base
|
||||
}
|
||||
|
||||
func determineCachePolicy(route *server.HubRoute, locator cache.Locator, method string) cachePolicy {
|
||||
if route == nil || method != http.MethodGet {
|
||||
if method != http.MethodGet {
|
||||
return cachePolicy{}
|
||||
}
|
||||
policy := cachePolicy{allowCache: true, allowStore: true}
|
||||
path := stripQueryMarker(locator.Path)
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
if path == "/v2" || path == "v2" || path == "/v2/" {
|
||||
return cachePolicy{}
|
||||
}
|
||||
if strings.Contains(path, "/_catalog") {
|
||||
return cachePolicy{}
|
||||
}
|
||||
if isDockerImmutablePath(path) {
|
||||
return policy
|
||||
}
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
case "go":
|
||||
if strings.Contains(path, "/@v/") &&
|
||||
(strings.HasSuffix(path, ".zip") || strings.HasSuffix(path, ".mod") || strings.HasSuffix(path, ".info")) {
|
||||
return policy
|
||||
}
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
case "npm":
|
||||
if strings.Contains(path, "/-/") && strings.HasSuffix(path, ".tgz") {
|
||||
return policy
|
||||
}
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
case "pypi":
|
||||
if isPyPIDistribution(path) {
|
||||
return policy
|
||||
}
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
case "composer":
|
||||
if isComposerDistPath(path) {
|
||||
return policy
|
||||
}
|
||||
if isComposerMetadataPath(path) {
|
||||
policy.requireRevalidate = true
|
||||
return policy
|
||||
}
|
||||
return cachePolicy{}
|
||||
default:
|
||||
return policy
|
||||
}
|
||||
}
|
||||
|
||||
func isDockerImmutablePath(path string) bool {
|
||||
if strings.Contains(path, "/blobs/sha256:") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(path, "/manifests/sha256:") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isPyPIDistribution(path string) bool {
|
||||
switch {
|
||||
case strings.HasSuffix(path, ".whl"):
|
||||
return true
|
||||
case strings.HasSuffix(path, ".tar.gz"):
|
||||
return true
|
||||
case strings.HasSuffix(path, ".tar.bz2"):
|
||||
return true
|
||||
case strings.HasSuffix(path, ".tgz"):
|
||||
return true
|
||||
case strings.HasSuffix(path, ".zip"):
|
||||
return true
|
||||
case strings.HasSuffix(path, ".egg"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return cachePolicy{allowCache: true, allowStore: true}
|
||||
}
|
||||
|
||||
func isCacheableStatus(status int) bool {
|
||||
@@ -808,13 +745,14 @@ func (h *Handler) isCacheFresh(
|
||||
route *server.HubRoute,
|
||||
locator cache.Locator,
|
||||
entry cache.Entry,
|
||||
hook *hookState,
|
||||
) (bool, error) {
|
||||
ctx := c.Context()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c)
|
||||
upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c, hook)
|
||||
resp, err := h.revalidateRequest(c, route, upstreamURL, locator, "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -888,113 +826,6 @@ func extractModTime(header http.Header) time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func applyDockerHubNamespaceFallback(route *server.HubRoute, path string) (string, bool) {
|
||||
if !isDockerHubRoute(route) {
|
||||
return path, false
|
||||
}
|
||||
repo, rest, ok := splitDockerRepoPath(path)
|
||||
if !ok || repo == "" {
|
||||
return path, false
|
||||
}
|
||||
if repo == "library" || strings.Contains(repo, "/") {
|
||||
return path, false
|
||||
}
|
||||
normalized := "/v2/library/" + repo + rest
|
||||
return normalized, true
|
||||
}
|
||||
|
||||
func isDockerHubRoute(route *server.HubRoute) bool {
|
||||
if route == nil || route.Config.Type != "docker" || route.UpstreamURL == nil {
|
||||
return false
|
||||
}
|
||||
host := strings.ToLower(route.UpstreamURL.Hostname())
|
||||
switch host {
|
||||
case "registry-1.docker.io", "docker.io", "index.docker.io":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func splitDockerRepoPath(path string) (string, string, bool) {
|
||||
if !strings.HasPrefix(path, "/v2/") {
|
||||
return "", "", false
|
||||
}
|
||||
suffix := strings.TrimPrefix(path, "/v2/")
|
||||
if suffix == "" || suffix == "/" {
|
||||
return "", "", false
|
||||
}
|
||||
segments := strings.Split(suffix, "/")
|
||||
var repoSegments []string
|
||||
for i, seg := range segments {
|
||||
if seg == "" {
|
||||
return "", "", false
|
||||
}
|
||||
switch seg {
|
||||
case "manifests", "blobs", "tags", "referrers":
|
||||
if len(repoSegments) == 0 {
|
||||
return "", "", false
|
||||
}
|
||||
rest := "/" + strings.Join(segments[i:], "/")
|
||||
return strings.Join(repoSegments, "/"), rest, true
|
||||
case "_catalog":
|
||||
return "", "", false
|
||||
}
|
||||
repoSegments = append(repoSegments, seg)
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func applyPyPISimpleFallback(route *server.HubRoute, path string) (string, bool) {
|
||||
if route == nil || route.Config.Type != "pypi" {
|
||||
return path, false
|
||||
}
|
||||
if strings.HasPrefix(path, "/simple/") || strings.HasPrefix(path, "/files/") {
|
||||
return path, false
|
||||
}
|
||||
if strings.HasSuffix(path, ".whl") || strings.HasSuffix(path, ".tar.gz") || strings.HasSuffix(path, ".tar.bz2") ||
|
||||
strings.HasSuffix(path, ".zip") {
|
||||
return path, false
|
||||
}
|
||||
trimmed := strings.Trim(path, "/")
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "_") {
|
||||
return path, false
|
||||
}
|
||||
return "/simple/" + trimmed + "/", true
|
||||
}
|
||||
|
||||
func parseComposerDistURL(path string, rawQuery string) (*url.URL, bool) {
|
||||
if !strings.HasPrefix(path, "/dist/") {
|
||||
return nil, false
|
||||
}
|
||||
trimmed := strings.TrimPrefix(path, "/dist/")
|
||||
parts := strings.SplitN(trimmed, "/", 3)
|
||||
if len(parts) < 3 {
|
||||
return nil, false
|
||||
}
|
||||
scheme := parts[0]
|
||||
host := parts[1]
|
||||
rest := parts[2]
|
||||
if scheme == "" || host == "" {
|
||||
return nil, false
|
||||
}
|
||||
if rest == "" {
|
||||
rest = "/"
|
||||
} else {
|
||||
rest = "/" + rest
|
||||
}
|
||||
target := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Path: rest,
|
||||
RawPath: rest,
|
||||
}
|
||||
if rawQuery != "" {
|
||||
target.RawQuery = rawQuery
|
||||
}
|
||||
return target, true
|
||||
}
|
||||
|
||||
type bearerChallenge struct {
|
||||
Realm string
|
||||
Service string
|
||||
@@ -1130,6 +961,7 @@ func (h *Handler) logAuthRetry(route *server.HubRoute, upstream string, requestI
|
||||
route.ModuleKey,
|
||||
string(route.RolloutFlag),
|
||||
false,
|
||||
route.ModuleKey == hubmodule.DefaultModuleKey(),
|
||||
)
|
||||
fields["action"] = "proxy_retry"
|
||||
fields["upstream"] = upstream
|
||||
@@ -1150,6 +982,7 @@ func (h *Handler) logAuthFailure(route *server.HubRoute, upstream string, reques
|
||||
route.ModuleKey,
|
||||
string(route.RolloutFlag),
|
||||
false,
|
||||
route.ModuleKey == hubmodule.DefaultModuleKey(),
|
||||
)
|
||||
fields["action"] = "proxy"
|
||||
fields["upstream"] = upstream
|
||||
@@ -1204,20 +1037,3 @@ func normalizeETag(value string) string {
|
||||
}
|
||||
return strings.Trim(value, "\"")
|
||||
}
|
||||
|
||||
func ensureProxyHubType(route *server.HubRoute) error {
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
return nil
|
||||
case "npm":
|
||||
return nil
|
||||
case "go":
|
||||
return nil
|
||||
case "pypi":
|
||||
return nil
|
||||
case "composer":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported hub type: %s", route.Config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
80
internal/proxy/handler_hook_test.go
Normal file
80
internal/proxy/handler_hook_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/cache"
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
"github.com/any-hub/any-hub/internal/proxy/hooks"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
func TestResolveUpstreamPrefersHook(t *testing.T) {
|
||||
app := fiber.New()
|
||||
defer app.Shutdown()
|
||||
|
||||
ctx := app.AcquireCtx(new(fasthttp.RequestCtx))
|
||||
defer app.ReleaseCtx(ctx)
|
||||
ctx.Request().SetRequestURI("/original/path?from=req")
|
||||
|
||||
base, _ := url.Parse("https://up.example")
|
||||
route := &server.HubRoute{
|
||||
Config: config.HubConfig{
|
||||
Name: "demo",
|
||||
Type: "custom",
|
||||
},
|
||||
UpstreamURL: base,
|
||||
}
|
||||
hook := &hookState{
|
||||
ctx: &hooks.RequestContext{},
|
||||
def: hooks.Hooks{
|
||||
NormalizePath: func(_ *hooks.RequestContext, clean string, rawQuery []byte) (string, []byte) {
|
||||
return clean, rawQuery
|
||||
},
|
||||
ResolveUpstream: func(_ *hooks.RequestContext, upstream string, clean string, rawQuery []byte) string {
|
||||
return upstream + "/hooked"
|
||||
},
|
||||
},
|
||||
hasHooks: true,
|
||||
clean: "/ignored",
|
||||
rawQuery: []byte("ignored=1"),
|
||||
}
|
||||
|
||||
target := resolveUpstreamURL(route, base, ctx, hook)
|
||||
if target.String() != "https://up.example/hooked" {
|
||||
t.Fatalf("expected hook override, got %s", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachePolicyHookOverrides(t *testing.T) {
|
||||
route := &server.HubRoute{
|
||||
Config: config.HubConfig{
|
||||
Name: "demo",
|
||||
Type: "npm",
|
||||
},
|
||||
}
|
||||
locator := cacheLocatorForTest("demo", "/a.tgz")
|
||||
hook := hooks.Hooks{
|
||||
CachePolicy: func(_ *hooks.RequestContext, _ string, current hooks.CachePolicy) hooks.CachePolicy {
|
||||
current.AllowCache = false
|
||||
current.RequireRevalidate = false
|
||||
return current
|
||||
},
|
||||
}
|
||||
ctx := &hooks.RequestContext{Method: fiber.MethodGet}
|
||||
policy := determineCachePolicyWithHook(route, locator, fiber.MethodGet, hook, true, ctx)
|
||||
if policy.allowCache {
|
||||
t.Fatalf("expected hook to disable cache")
|
||||
}
|
||||
if policy.requireRevalidate {
|
||||
t.Fatalf("expected hook to disable revalidate")
|
||||
}
|
||||
}
|
||||
|
||||
func cacheLocatorForTest(hub, path string) cache.Locator {
|
||||
return cache.Locator{HubName: hub, Path: path}
|
||||
}
|
||||
@@ -1,12 +1,5 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CachePolicy mirrors the proxy cache policy structure.
|
||||
type CachePolicy struct {
|
||||
AllowCache bool
|
||||
@@ -27,34 +20,9 @@ type RequestContext struct {
|
||||
|
||||
// Hooks describes customization points for module-specific behavior.
|
||||
type Hooks struct {
|
||||
NormalizePath func(ctx *RequestContext, cleanPath string) string
|
||||
ResolveUpstream func(ctx *RequestContext, base *url.URL, cleanPath string, rawQuery []byte) *url.URL
|
||||
RewriteResponse func(ctx *RequestContext, resp *http.Response, cleanPath string) (*http.Response, error)
|
||||
NormalizePath func(ctx *RequestContext, cleanPath string, rawQuery []byte) (string, []byte)
|
||||
ResolveUpstream func(ctx *RequestContext, baseURL string, path string, rawQuery []byte) string
|
||||
RewriteResponse func(ctx *RequestContext, status int, headers map[string]string, body []byte, path string) (int, map[string]string, []byte, error)
|
||||
CachePolicy func(ctx *RequestContext, locatorPath string, current CachePolicy) CachePolicy
|
||||
ContentType func(ctx *RequestContext, locatorPath string) string
|
||||
}
|
||||
|
||||
var registry sync.Map
|
||||
|
||||
// Register stores hooks for the given module key.
|
||||
func Register(moduleKey string, hooks Hooks) {
|
||||
key := strings.ToLower(strings.TrimSpace(moduleKey))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
registry.Store(key, hooks)
|
||||
}
|
||||
|
||||
// For retrieves hooks associated with a module key.
|
||||
func For(moduleKey string) (Hooks, bool) {
|
||||
key := strings.ToLower(strings.TrimSpace(moduleKey))
|
||||
if key == "" {
|
||||
return Hooks{}, false
|
||||
}
|
||||
if value, ok := registry.Load(key); ok {
|
||||
if hooks, ok := value.(Hooks); ok {
|
||||
return hooks, true
|
||||
}
|
||||
}
|
||||
return Hooks{}, false
|
||||
}
|
||||
|
||||
68
internal/proxy/hooks/registry.go
Normal file
68
internal/proxy/hooks/registry.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var registry sync.Map
|
||||
|
||||
// ErrDuplicateHook indicates a module key already has hooks registered.
|
||||
var ErrDuplicateHook = errors.New("hook already registered")
|
||||
|
||||
// Register stores hooks for the given module key.
|
||||
func Register(moduleKey string, hooks Hooks) error {
|
||||
key := normalizeKey(moduleKey)
|
||||
if key == "" {
|
||||
return errors.New("module key required")
|
||||
}
|
||||
if _, loaded := registry.LoadOrStore(key, hooks); loaded {
|
||||
return ErrDuplicateHook
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustRegister panics on registration failure.
|
||||
func MustRegister(moduleKey string, hooks Hooks) {
|
||||
if err := Register(moduleKey, hooks); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch retrieves hooks associated with a module key.
|
||||
func Fetch(moduleKey string) (Hooks, bool) {
|
||||
key := normalizeKey(moduleKey)
|
||||
if key == "" {
|
||||
return Hooks{}, false
|
||||
}
|
||||
if value, ok := registry.Load(key); ok {
|
||||
if hooks, ok := value.(Hooks); ok {
|
||||
return hooks, true
|
||||
}
|
||||
}
|
||||
return Hooks{}, false
|
||||
}
|
||||
|
||||
// Status returns hook registration status for a module key.
|
||||
func Status(moduleKey string) string {
|
||||
if _, ok := Fetch(moduleKey); ok {
|
||||
return "registered"
|
||||
}
|
||||
return "missing"
|
||||
}
|
||||
|
||||
// Snapshot returns status for a list of module keys.
|
||||
func Snapshot(keys []string) map[string]string {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if normalized := normalizeKey(key); normalized != "" {
|
||||
out[normalized] = Status(normalized)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeKey(key string) string {
|
||||
return strings.ToLower(strings.TrimSpace(key))
|
||||
}
|
||||
45
internal/proxy/hooks/registry_test.go
Normal file
45
internal/proxy/hooks/registry_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegisterAndFetch(t *testing.T) {
|
||||
registry = sync.Map{}
|
||||
h := Hooks{ContentType: func(*RequestContext, string) string { return "ok" }}
|
||||
if err := Register("test", h); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
if _, ok := Fetch("test"); !ok {
|
||||
t.Fatalf("expected fetch ok")
|
||||
}
|
||||
if Status("test") != "registered" {
|
||||
t.Fatalf("expected registered status")
|
||||
}
|
||||
if Status("missing") != "missing" {
|
||||
t.Fatalf("expected missing status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterDuplicate(t *testing.T) {
|
||||
registry = sync.Map{}
|
||||
if err := Register("dup", Hooks{}); err != nil {
|
||||
t.Fatalf("first register failed: %v", err)
|
||||
}
|
||||
if err := Register("dup", Hooks{}); err != ErrDuplicateHook {
|
||||
t.Fatalf("expected ErrDuplicateHook, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot(t *testing.T) {
|
||||
registry = sync.Map{}
|
||||
_ = Register("a", Hooks{})
|
||||
snap := Snapshot([]string{"a", "b"})
|
||||
if snap["a"] != "registered" {
|
||||
t.Fatalf("expected a registered, got %s", snap["a"])
|
||||
}
|
||||
if snap["b"] != "missing" {
|
||||
t.Fatalf("expected b missing, got %s", snap["b"])
|
||||
}
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
func (h *Handler) rewritePyPIResponse(route *server.HubRoute, resp *http.Response, path string) (*http.Response, error) {
|
||||
if resp == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if !strings.HasPrefix(path, "/simple") && path != "/" {
|
||||
return resp, nil
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
rewritten, contentType, err := rewritePyPIBody(bodyBytes, resp.Header.Get("Content-Type"), route.Config.Domain)
|
||||
if err != nil {
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(rewritten))
|
||||
resp.ContentLength = int64(len(rewritten))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(rewritten)))
|
||||
if contentType != "" {
|
||||
resp.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
resp.Header.Del("Content-Encoding")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func rewritePyPIBody(body []byte, contentType string, domain string) ([]byte, string, error) {
|
||||
lowerCT := strings.ToLower(contentType)
|
||||
if strings.Contains(lowerCT, "application/vnd.pypi.simple.v1+json") || strings.HasPrefix(strings.TrimSpace(string(body)), "{") {
|
||||
data := map[string]interface{}{}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return body, contentType, err
|
||||
}
|
||||
if files, ok := data["files"].([]interface{}); ok {
|
||||
for _, entry := range files {
|
||||
if fileMap, ok := entry.(map[string]interface{}); ok {
|
||||
if urlValue, ok := fileMap["url"].(string); ok {
|
||||
fileMap["url"] = rewritePyPIFileURL(domain, urlValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rewriteBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return body, contentType, err
|
||||
}
|
||||
return rewriteBytes, "application/vnd.pypi.simple.v1+json", nil
|
||||
}
|
||||
|
||||
rewrittenHTML, err := rewritePyPIHTML(body, domain)
|
||||
if err != nil {
|
||||
return body, contentType, err
|
||||
}
|
||||
return rewrittenHTML, "text/html; charset=utf-8", nil
|
||||
}
|
||||
|
||||
func rewritePyPIHTML(body []byte, domain string) ([]byte, error) {
|
||||
node, err := html.Parse(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rewriteHTMLNode(node, domain)
|
||||
var buf bytes.Buffer
|
||||
if err := html.Render(&buf, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func rewriteHTMLNode(n *html.Node, domain string) {
|
||||
if n.Type == html.ElementNode {
|
||||
rewriteHTMLAttributes(n, domain)
|
||||
}
|
||||
for child := n.FirstChild; child != nil; child = child.NextSibling {
|
||||
rewriteHTMLNode(child, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteHTMLAttributes(n *html.Node, domain string) {
|
||||
for i, attr := range n.Attr {
|
||||
switch attr.Key {
|
||||
case "href", "data-dist-info-metadata", "data-core-metadata":
|
||||
if strings.HasPrefix(attr.Val, "http://") || strings.HasPrefix(attr.Val, "https://") {
|
||||
n.Attr[i].Val = rewritePyPIFileURL(domain, attr.Val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func rewritePyPIFileURL(domain, original string) string {
|
||||
parsed, err := url.Parse(original)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return original
|
||||
}
|
||||
prefix := "/files/" + parsed.Scheme + "/" + parsed.Host
|
||||
newURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: domain,
|
||||
Path: prefix + parsed.Path,
|
||||
RawQuery: parsed.RawQuery,
|
||||
Fragment: parsed.Fragment,
|
||||
}
|
||||
if raw := parsed.RawPath; raw != "" {
|
||||
newURL.RawPath = prefix + raw
|
||||
}
|
||||
return newURL.String()
|
||||
}
|
||||
Reference in New Issue
Block a user