From 467cabe23861e44cccccdea8b0c173c705fb0c24 Mon Sep 17 00:00:00 2001 From: Rogee Date: Sat, 15 Nov 2025 23:03:02 +0800 Subject: [PATCH] fix pypi --- internal/proxy/handler.go | 213 ++++++++++++++++---- internal/proxy/pypi_rewrite.go | 126 ++++++++++++ specs/004-modular-proxy-cache/quickstart.md | 2 +- tests/integration/pypi_proxy_test.go | 26 ++- 4 files changed, 320 insertions(+), 47 deletions(-) create mode 100644 internal/proxy/pypi_rewrite.go diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 0502242..5cee602 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -107,7 +107,13 @@ func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error { return h.fetchAndStream(c, route, locator, policy, strategyWriter, requestID, started, ctx) } -func (h *Handler) serveCache(c fiber.Ctx, route *server.HubRoute, result *cache.ReadResult, requestID string, started time.Time) error { +func (h *Handler) serveCache( + c fiber.Ctx, + route *server.HubRoute, + result *cache.ReadResult, + requestID string, + started time.Time, +) error { if seeker, ok := result.Reader.(io.Seeker); ok { _, _ = seeker.Seek(0, io.SeekStart) } @@ -152,7 +158,16 @@ func (h *Handler) serveCache(c fiber.Ctx, route *server.HubRoute, result *cache. 
return nil } -func (h *Handler) fetchAndStream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, policy cachePolicy, writer cache.StrategyWriter, requestID string, started time.Time, ctx context.Context) error { +func (h *Handler) fetchAndStream( + c fiber.Ctx, + route *server.HubRoute, + locator cache.Locator, + policy cachePolicy, + writer cache.StrategyWriter, + requestID string, + started time.Time, + ctx context.Context, +) error { resp, upstreamURL, err := h.executeRequest(c, route) if err != nil { h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err) @@ -164,13 +179,34 @@ func (h *Handler) fetchAndStream(c fiber.Ctx, route *server.HubRoute, locator ca h.logResult(route, upstreamURL.String(), requestID, 0, false, started, err) return h.writeError(c, fiber.StatusBadGateway, "upstream_failed") } + if route.Config.Type == "pypi" { + if rewritten, rewriteErr := h.rewritePyPIResponse(route, resp, requestPath(c)); rewriteErr == nil { + resp = rewritten + } else { + h.logger.WithError(rewriteErr).WithFields(logrus.Fields{ + "action": "pypi_rewrite", + "hub": route.Config.Name, + }).Warn("pypi_rewrite_failed") + } + } defer resp.Body.Close() - shouldStore := policy.allowStore && writer.Enabled() && isCacheableStatus(resp.StatusCode) && c.Method() == http.MethodGet + shouldStore := policy.allowStore && writer.Enabled() && isCacheableStatus(resp.StatusCode) && + c.Method() == http.MethodGet return h.consumeUpstream(c, route, locator, resp, shouldStore, writer, requestID, started, ctx) } -func (h *Handler) consumeUpstream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, resp *http.Response, shouldStore bool, writer cache.StrategyWriter, requestID string, started time.Time, ctx context.Context) error { +func (h *Handler) consumeUpstream( + c fiber.Ctx, + route *server.HubRoute, + locator cache.Locator, + resp *http.Response, + shouldStore bool, + writer cache.StrategyWriter, + requestID string, + started time.Time, + ctx 
context.Context, +) error { upstreamURL := resp.Request.URL.String() method := c.Method() authFailure := isAuthFailure(resp.StatusCode) && route.Config.HasCredentials() @@ -204,7 +240,17 @@ func (h *Handler) consumeUpstream(c fiber.Ctx, route *server.HubRoute, locator c return nil } -func (h *Handler) cacheAndStream(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, resp *http.Response, writer cache.StrategyWriter, requestID string, started time.Time, ctx context.Context, upstreamURL string) error { +func (h *Handler) cacheAndStream( + c fiber.Ctx, + route *server.HubRoute, + locator cache.Locator, + resp *http.Response, + writer cache.StrategyWriter, + requestID string, + started time.Time, + ctx context.Context, + upstreamURL string, +) error { copyResponseHeaders(c, resp.Header) c.Set("X-Any-Hub-Upstream", upstreamURL) c.Set("X-Any-Hub-Cache-Hit", "false") @@ -225,7 +271,14 @@ func (h *Handler) cacheAndStream(c fiber.Ctx, route *server.HubRoute, locator ca return nil } -func (h *Handler) retryOnAuthFailure(c fiber.Ctx, route *server.HubRoute, requestID string, started time.Time, resp *http.Response, upstreamURL *url.URL) (*http.Response, *url.URL, error) { +func (h *Handler) retryOnAuthFailure( + c fiber.Ctx, + route *server.HubRoute, + requestID string, + started time.Time, + resp *http.Response, + upstreamURL *url.URL, +) (*http.Response, *url.URL, error) { if !shouldRetryAuth(route, resp.StatusCode) { return resp, upstreamURL, nil } @@ -262,7 +315,11 @@ func (h *Handler) executeRequest(c fiber.Ctx, route *server.HubRoute) (*http.Res return h.executeRequestWithAuth(c, route, "") } -func (h *Handler) executeRequestWithAuth(c fiber.Ctx, route *server.HubRoute, authHeader string) (*http.Response, *url.URL, error) { +func (h *Handler) executeRequestWithAuth( + c fiber.Ctx, + route *server.HubRoute, + authHeader string, +) (*http.Response, *url.URL, error) { upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c) body := bytesReader(c.Body()) req, 
err := h.buildUpstreamRequest(c, upstreamURL, route, c.Method(), body, authHeader) @@ -274,7 +331,14 @@ func (h *Handler) executeRequestWithAuth(c fiber.Ctx, route *server.HubRoute, au return resp, upstreamURL, err } -func (h *Handler) buildUpstreamRequest(c fiber.Ctx, upstream *url.URL, route *server.HubRoute, method string, body io.Reader, overrideAuth string) (*http.Request, error) { +func (h *Handler) buildUpstreamRequest( + c fiber.Ctx, + upstream *url.URL, + route *server.HubRoute, + method string, + body io.Reader, + overrideAuth string, +) (*http.Request, error) { ctx := c.Context() if ctx == nil { ctx = context.Background() @@ -331,8 +395,24 @@ func (h *Handler) writeError(c fiber.Ctx, status int, code string) error { return c.Status(status).JSON(fiber.Map{"error": code}) } -func (h *Handler) logResult(route *server.HubRoute, upstream string, requestID string, status int, cacheHit bool, started time.Time, err error) { - fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), route.ModuleKey, string(route.RolloutFlag), cacheHit) +func (h *Handler) logResult( + route *server.HubRoute, + upstream string, + requestID string, + status int, + cacheHit bool, + started time.Time, + err error, +) { + fields := logging.RequestFields( + route.Config.Name, + route.Config.Domain, + route.Config.Type, + route.Config.AuthMode(), + route.ModuleKey, + string(route.RolloutFlag), + cacheHit, + ) fields["action"] = "proxy" fields["upstream"] = upstream fields["upstream_status"] = status @@ -481,6 +561,24 @@ func resolveUpstreamURL(route *server.HubRoute, base *url.URL, c fiber.Ctx) *url if newPath, ok := applyDockerHubNamespaceFallback(route, clean); ok { clean = newPath } + if route != nil && route.Config.Type == "pypi" && strings.HasPrefix(clean, "/files/") { + trimmed := strings.TrimPrefix(clean, "/files/") + parts := strings.SplitN(trimmed, "/", 3) + if len(parts) >= 3 { + scheme := parts[0] + host := parts[1] + 
rest := parts[2] + filesBase := &url.URL{Scheme: scheme, Host: host} + if !strings.HasPrefix(rest, "/") { + rest = "/" + rest + } + relative := &url.URL{Path: rest, RawPath: rest} + if query := string(uri.QueryString()); query != "" { + relative.RawQuery = query + } + return filesBase.ResolveReference(relative) + } + } relative := &url.URL{Path: clean, RawPath: clean} if query := string(uri.QueryString()); query != "" { relative.RawQuery = query @@ -515,8 +613,8 @@ func routePort(route *server.HubRoute) string { } type cachePolicy struct { - allowCache bool - allowStore bool + allowCache bool + allowStore bool requireRevalidate bool } @@ -527,10 +625,10 @@ func determineCachePolicy(route *server.HubRoute, locator cache.Locator, method policy := cachePolicy{allowCache: true, allowStore: true} path := stripQueryMarker(locator.Path) switch route.Config.Type { -case "docker": - if path == "/v2" || path == "v2" || path == "/v2/" { - return cachePolicy{} - } + case "docker": + if path == "/v2" || path == "v2" || path == "/v2/" { + return cachePolicy{} + } if strings.Contains(path, "/_catalog") { return cachePolicy{} } @@ -539,27 +637,28 @@ case "docker": } policy.requireRevalidate = true return policy -case "go": - if strings.Contains(path, "/@v/") && (strings.HasSuffix(path, ".zip") || strings.HasSuffix(path, ".mod") || strings.HasSuffix(path, ".info")) { + case "go": + if strings.Contains(path, "/@v/") && + (strings.HasSuffix(path, ".zip") || strings.HasSuffix(path, ".mod") || strings.HasSuffix(path, ".info")) { + return policy + } + policy.requireRevalidate = true + return policy + case "npm": + if strings.Contains(path, "/-/") && strings.HasSuffix(path, ".tgz") { + return policy + } + policy.requireRevalidate = true + return policy + case "pypi": + if isPyPIDistribution(path) { + return policy + } + policy.requireRevalidate = true + return policy + default: return policy } - policy.requireRevalidate = true - return policy -case "npm": - if strings.Contains(path, 
"/-/") && strings.HasSuffix(path, ".tgz") { - return policy - } - policy.requireRevalidate = true - return policy -case "pypi": - if isPyPIDistribution(path) { - return policy - } - policy.requireRevalidate = true - return policy -default: - return policy -} } func isDockerImmutablePath(path string) bool { @@ -595,7 +694,12 @@ func isCacheableStatus(status int) bool { return status == http.StatusOK } -func (h *Handler) isCacheFresh(c fiber.Ctx, route *server.HubRoute, locator cache.Locator, entry cache.Entry) (bool, error) { +func (h *Handler) isCacheFresh( + c fiber.Ctx, + route *server.HubRoute, + locator cache.Locator, + entry cache.Entry, +) (bool, error) { ctx := c.Context() if ctx == nil { ctx = context.Background() @@ -700,10 +804,11 @@ func applyPyPISimpleFallback(route *server.HubRoute, path string) (string, bool) if route == nil || route.Config.Type != "pypi" { return path, false } - if strings.HasPrefix(path, "/simple/") || strings.HasPrefix(path, "/packages/") { + if strings.HasPrefix(path, "/simple/") || strings.HasPrefix(path, "/files/") { return path, false } - if strings.HasSuffix(path, ".whl") || strings.HasSuffix(path, ".tar.gz") || strings.HasSuffix(path, ".tar.bz2") || strings.HasSuffix(path, ".zip") { + if strings.HasSuffix(path, ".whl") || strings.HasSuffix(path, ".tar.gz") || strings.HasSuffix(path, ".tar.bz2") || + strings.HasSuffix(path, ".zip") { return path, false } trimmed := strings.Trim(path, "/") @@ -761,7 +866,11 @@ func parseAuthParams(input string) map[string]string { return params } -func (h *Handler) fetchBearerToken(ctx context.Context, challenge bearerChallenge, route *server.HubRoute) (string, error) { +func (h *Handler) fetchBearerToken( + ctx context.Context, + challenge bearerChallenge, + route *server.HubRoute, +) (string, error) { if challenge.Realm == "" { return "", errors.New("bearer realm missing") } @@ -794,7 +903,11 @@ func (h *Handler) fetchBearerToken(ctx context.Context, challenge bearerChalleng if 
resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) - return "", fmt.Errorf("token request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body))) + return "", fmt.Errorf( + "token request failed: status=%d body=%s", + resp.StatusCode, + strings.TrimSpace(string(body)), + ) } var tokenResp struct { @@ -832,7 +945,15 @@ func isAuthFailure(status int) bool { } func (h *Handler) logAuthRetry(route *server.HubRoute, upstream string, requestID string, status int) { - fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), route.ModuleKey, string(route.RolloutFlag), false) + fields := logging.RequestFields( + route.Config.Name, + route.Config.Domain, + route.Config.Type, + route.Config.AuthMode(), + route.ModuleKey, + string(route.RolloutFlag), + false, + ) fields["action"] = "proxy_retry" fields["upstream"] = upstream fields["upstream_status"] = status @@ -844,7 +965,15 @@ func (h *Handler) logAuthRetry(route *server.HubRoute, upstream string, requestI } func (h *Handler) logAuthFailure(route *server.HubRoute, upstream string, requestID string, status int) { - fields := logging.RequestFields(route.Config.Name, route.Config.Domain, route.Config.Type, route.Config.AuthMode(), route.ModuleKey, string(route.RolloutFlag), false) + fields := logging.RequestFields( + route.Config.Name, + route.Config.Domain, + route.Config.Type, + route.Config.AuthMode(), + route.ModuleKey, + string(route.RolloutFlag), + false, + ) fields["action"] = "proxy" fields["upstream"] = upstream fields["upstream_status"] = status diff --git a/internal/proxy/pypi_rewrite.go b/internal/proxy/pypi_rewrite.go new file mode 100644 index 0000000..47c6456 --- /dev/null +++ b/internal/proxy/pypi_rewrite.go @@ -0,0 +1,126 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "golang.org/x/net/html" + + 
"github.com/any-hub/any-hub/internal/server" +) + +func (h *Handler) rewritePyPIResponse(route *server.HubRoute, resp *http.Response, path string) (*http.Response, error) { + if resp == nil { + return resp, nil + } + if !strings.HasPrefix(path, "/simple") && path != "/" { + return resp, nil + } + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return resp, err + } + resp.Body.Close() + + rewritten, contentType, err := rewritePyPIBody(bodyBytes, resp.Header.Get("Content-Type"), route.Config.Domain) + if err != nil { + resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + return resp, err + } + + resp.Body = io.NopCloser(bytes.NewReader(rewritten)) + resp.ContentLength = int64(len(rewritten)) + resp.Header.Set("Content-Length", strconv.Itoa(len(rewritten))) + if contentType != "" { + resp.Header.Set("Content-Type", contentType) + } + resp.Header.Del("Content-Encoding") + return resp, nil +} + +func rewritePyPIBody(body []byte, contentType string, domain string) ([]byte, string, error) { + lowerCT := strings.ToLower(contentType) + if strings.Contains(lowerCT, "application/vnd.pypi.simple.v1+json") || strings.HasPrefix(strings.TrimSpace(string(body)), "{") { + data := map[string]interface{}{} + if err := json.Unmarshal(body, &data); err != nil { + return body, contentType, err + } + if files, ok := data["files"].([]interface{}); ok { + for _, entry := range files { + if fileMap, ok := entry.(map[string]interface{}); ok { + if urlValue, ok := fileMap["url"].(string); ok { + fileMap["url"] = rewritePyPIFileURL(domain, urlValue) + } + } + } + } + rewriteBytes, err := json.Marshal(data) + if err != nil { + return body, contentType, err + } + return rewriteBytes, "application/vnd.pypi.simple.v1+json", nil + } + + rewrittenHTML, err := rewritePyPIHTML(body, domain) + if err != nil { + return body, contentType, err + } + return rewrittenHTML, "text/html; charset=utf-8", nil +} + +func rewritePyPIHTML(body []byte, domain string) ([]byte, error) { + node, err := 
html.Parse(bytes.NewReader(body)) + if err != nil { + return nil, err + } + rewriteHTMLNode(node, domain) + var buf bytes.Buffer + if err := html.Render(&buf, node); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func rewriteHTMLNode(n *html.Node, domain string) { + if n.Type == html.ElementNode { + rewriteHTMLAttributes(n, domain) + } + for child := n.FirstChild; child != nil; child = child.NextSibling { + rewriteHTMLNode(child, domain) + } +} + +func rewriteHTMLAttributes(n *html.Node, domain string) { + for i, attr := range n.Attr { + switch attr.Key { + case "href", "data-dist-info-metadata", "data-core-metadata": + if strings.HasPrefix(attr.Val, "http://") || strings.HasPrefix(attr.Val, "https://") { + n.Attr[i].Val = rewritePyPIFileURL(domain, attr.Val) + } + } + } +} + +func rewritePyPIFileURL(domain, original string) string { + parsed, err := url.Parse(original) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return original + } + prefix := "/files/" + parsed.Scheme + "/" + parsed.Host + newURL := url.URL{ + Scheme: "https", + Host: domain, + Path: prefix + parsed.Path, + RawQuery: parsed.RawQuery, + Fragment: parsed.Fragment, + } + if raw := parsed.RawPath; raw != "" { + newURL.RawPath = prefix + raw + } + return newURL.String() +} diff --git a/specs/004-modular-proxy-cache/quickstart.md b/specs/004-modular-proxy-cache/quickstart.md index 6f0b828..11c4b96 100644 --- a/specs/004-modular-proxy-cache/quickstart.md +++ b/specs/004-modular-proxy-cache/quickstart.md @@ -20,7 +20,7 @@ ## 4. Run and Verify 1. Start the binary: `go run ./cmd/any-hub --config ./config.toml`. 2. Use `curl -H "Host: " http://127.0.0.1:/` to produce traffic, then hit `curl http://127.0.0.1:/-/modules` and confirm the hub binding points to your module with the expected `rollout_flag`. -3. Inspect `./storage//` to confirm the cached files mirror the upstream path (no suffix). 
When a path also has child entries (e.g., `/pkg` metadata plus `/pkg/-/...` tarballs), the metadata payload is stored in a `__content` file under that directory so both artifacts can coexist. Verify TTL overrides are propagated. +3. Inspect `./storage//` to confirm the cached files mirror the upstream path (no suffix). When a path also has child entries (e.g., `/pkg` metadata plus `/pkg/-/...` tarballs), the metadata payload is stored in a `__content` file under that directory so both artifacts can coexist. PyPI Simple responses rewrite distribution links to `/files/<scheme>/<host>/<path>` so that wheels/tarballs are fetched through the proxy and cached alongside the HTML/JSON index. Verify TTL overrides are propagated. 4. Monitor `logs/any-hub.log` (or the sample `logs/module_migration_sample.log`) to verify each entry exposes `module_key` + `rollout_flag`. Example: ```json {"action":"proxy","hub":"testhub","module_key":"testhub","rollout_flag":"dual","cache_hit":false,"upstream_status":200} diff --git a/tests/integration/pypi_proxy_test.go b/tests/integration/pypi_proxy_test.go index c5ece8e..e8a3fbc 100644 --- a/tests/integration/pypi_proxy_test.go +++ b/tests/integration/pypi_proxy_test.go @@ -2,10 +2,13 @@ package integration import ( "context" + "fmt" "io" "net" "net/http" "net/http/httptest" + "net/url" + "strings" "sync" "testing" "time" @@ -83,7 +86,11 @@ func TestPyPICachePolicies(t *testing.T) { if resp.Header.Get("X-Any-Hub-Cache-Hit") != "false" { t.Fatalf("expected miss for first simple request") } + body, _ := io.ReadAll(resp.Body) resp.Body.Close() + if !strings.Contains(string(body), "/files/") { + t.Fatalf("simple response should rewrite file links, got %s", string(body)) + } resp2 := doRequest(simplePath) if resp2.Header.Get("X-Any-Hub-Cache-Hit") != "true" { @@ -109,7 +116,12 @@ func TestPyPICachePolicies(t *testing.T) { t.Fatalf("expected second HEAD before refresh, got %d", stub.simpleHeadHits) } - wheelPath := "/packages/foo/foo-1.0-py3-none-any.whl" + wheelURL 
:= fmt.Sprintf("%s/packages/foo/foo-1.0-py3-none-any.whl", stub.URL) + parsedWheel, err := url.Parse(wheelURL) + if err != nil { + t.Fatalf("wheel url parse: %v", err) + } + wheelPath := fmt.Sprintf("/files/%s/%s%s", parsedWheel.Scheme, parsedWheel.Host, parsedWheel.Path) respWheel := doRequest(wheelPath) if respWheel.StatusCode != fiber.StatusOK { t.Fatalf("expected 200 for wheel, got %d", respWheel.StatusCode) @@ -151,19 +163,20 @@ type pypiStub struct { simpleBody []byte wheelBody []byte lastSimpleMod string + wheelPath string } func newPyPIStub(t *testing.T) *pypiStub { t.Helper() stub := &pypiStub{ - simpleBody: []byte("ok"), + wheelPath: "/packages/foo/foo-1.0-py3-none-any.whl", wheelBody: []byte("wheel-bytes"), lastSimpleMod: time.Now().UTC().Format(http.TimeFormat), } mux := http.NewServeMux() mux.HandleFunc("/simple/pkg/", stub.handleSimple) - mux.HandleFunc("/packages/foo/foo-1.0-py3-none-any.whl", stub.handleWheel) + mux.HandleFunc(stub.wheelPath, stub.handleWheel) listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -174,6 +187,7 @@ func newPyPIStub(t *testing.T) *pypiStub { stub.server = server stub.listener = listener stub.URL = "http://" + listener.Addr().String() + stub.simpleBody = stub.defaultSimpleHTML() go func() { _ = server.Serve(listener) @@ -224,9 +238,13 @@ func (s *pypiStub) handleWheel(w http.ResponseWriter, r *http.Request) { func (s *pypiStub) UpdateSimple(body []byte) { s.mu.Lock() - defer s.mu.Unlock() s.simpleBody = append([]byte(nil), body...) s.lastSimpleMod = time.Now().UTC().Format(http.TimeFormat) + s.mu.Unlock() +} + +func (s *pypiStub) defaultSimpleHTML() []byte { + return []byte(fmt.Sprintf(`<html><body><a href="%s%s">wheel</a></body></html>`, s.URL, s.wheelPath)) } func (s *pypiStub) Close() {