From 1db9badec86e1a7d10276f3e5c22eeb1d788beea Mon Sep 17 00:00:00 2001 From: Rogee Date: Mon, 17 Nov 2025 09:15:35 +0800 Subject: [PATCH] fix: docker cache --- internal/proxy/handler.go | 91 +++++++++++++++++++++- tests/integration/credential_proxy_test.go | 50 ++++++++++++ 2 files changed, 137 insertions(+), 4 deletions(-) diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 5cee602..393f77c 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -14,6 +14,7 @@ import ( "net/url" "path" "strings" + "sync" "time" "github.com/gofiber/fiber/v3" @@ -31,6 +32,7 @@ type Handler struct { client *http.Client logger *logrus.Logger store cache.Store + etags sync.Map // key: hub+path, value: etag/digest string } // NewHandler constructs a proxy handler with shared HTTP client/logger/store. @@ -267,6 +269,7 @@ func (h *Handler) cacheAndStream( if err != nil { return fiber.NewError(fiber.StatusBadGateway, fmt.Sprintf("cache_write_failed: %v", err)) } + h.rememberETag(route, locator, resp) _ = entry return nil } @@ -706,19 +709,37 @@ func (h *Handler) isCacheFresh( } upstreamURL := resolveUpstreamURL(route, route.UpstreamURL, c) - req, err := h.buildUpstreamRequest(c, upstreamURL, route, http.MethodHead, http.NoBody, "") + resp, err := h.revalidateRequest(c, route, upstreamURL, locator, "") if err != nil { return false, err } - resp, err := h.doRequest(req, route) - if err != nil { - return false, err + if shouldRetryAuth(route, resp.StatusCode) { + challenge, ok := parseBearerChallenge(resp.Header.Values("Www-Authenticate")) + resp.Body.Close() + + authHeader := "" + if ok { + token, err := h.fetchBearerToken(ctx, challenge, route) + if err != nil { + return false, err + } + authHeader = "Bearer " + token + } + + resp, err = h.revalidateRequest(c, route, upstreamURL, locator, authHeader) + if err != nil { + return false, err + } } + defer resp.Body.Close() switch resp.StatusCode { + case http.StatusNotModified: + return true, nil case http.StatusOK: + h.rememberETag(route, locator, resp) remote := extractModTime(resp.Header) if !remote.After(entry.ModTime.Add(time.Second)) { return true, nil @@ -728,12 +749,30 @@ func (h *Handler) isCacheFresh( if h.store != nil { _ = h.store.Remove(ctx, locator) } + h.forgetETag(route, locator) return false, nil default: return false, nil } } +func (h *Handler) revalidateRequest( + c fiber.Ctx, + route *server.HubRoute, + upstreamURL *url.URL, + locator cache.Locator, + overrideAuth string, +) (*http.Response, error) { + req, err := h.buildUpstreamRequest(c, upstreamURL, route, http.MethodHead, http.NoBody, overrideAuth) + if err != nil { + return nil, err + } + if etag := h.cachedETag(route, locator); etag != "" { + req.Header.Set("If-None-Match", etag) + } + return h.doRequest(req, route) +} + func extractModTime(header http.Header) time.Time { if last := header.Get("Last-Modified"); last != "" { if parsed, err := http.ParseTime(last); err == nil { @@ -984,6 +1023,50 @@ func (h *Handler) logAuthFailure(route *server.HubRoute, upstream string, reques h.logger.WithFields(fields).Error("proxy_auth_failed") } +func (h *Handler) rememberETag(route *server.HubRoute, locator cache.Locator, resp *http.Response) { + if resp == nil { + return + } + etag := resp.Header.Get("Docker-Content-Digest") + if etag == "" { + etag = resp.Header.Get("Etag") + } + etag = normalizeETag(etag) + if etag == "" { + return + } + h.etags.Store(h.locatorKey(route, locator), etag) +} + +func (h *Handler) cachedETag(route *server.HubRoute, locator cache.Locator) string { + if value, ok := h.etags.Load(h.locatorKey(route, locator)); ok { + if etag, ok := value.(string); ok { + return etag + } + } + return "" +} + +func (h *Handler) forgetETag(route *server.HubRoute, locator cache.Locator) { + h.etags.Delete(h.locatorKey(route, locator)) +} + +func (h *Handler) locatorKey(route *server.HubRoute, locator cache.Locator) string { + hub := locator.HubName + if route != nil && route.Config.Name != "" { + hub = route.Config.Name + } + return hub + "::" + locator.Path +} + +func normalizeETag(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + return strings.Trim(value, "\"") +} + func ensureProxyHubType(route *server.HubRoute) error { switch route.Config.Type { case "docker": diff --git a/tests/integration/credential_proxy_test.go b/tests/integration/credential_proxy_test.go index fa72244..f7bb124 100644 --- a/tests/integration/credential_proxy_test.go +++ b/tests/integration/credential_proxy_test.go @@ -164,6 +164,50 @@ func TestDockerProxyHandlesBearerTokenExchange(t *testing.T) { } } +func TestDockerProxyCachesAfterBearerRevalidation(t *testing.T) { + stub := newDockerBearerStub(t, "ci-user", "ci-pass") + defer stub.Close() + + app := newDockerProxyApp(t, stub) + + req := httptest.NewRequest("GET", "http://docker.hub.local/v2/library/alpine/manifests/latest", nil) + req.Host = "docker.hub.local" + resp, err := app.Test(req) + if err != nil { + t.Fatalf("app.Test failed: %v", err) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200 after token exchange, got %d (body=%s)", resp.StatusCode, string(body)) + } + if resp.Header.Get("X-Any-Hub-Cache-Hit") != "false" { + t.Fatalf("expected first request to miss cache") + } + resp.Body.Close() + + req2 := httptest.NewRequest("GET", "http://docker.hub.local/v2/library/alpine/manifests/latest", nil) + req2.Host = "docker.hub.local" + resp2, err := app.Test(req2) + if err != nil { + t.Fatalf("app.Test failed: %v", err) + } + if resp2.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp2.Body) + t.Fatalf("expected 200 after cache revalidation, got %d (body=%s)", resp2.StatusCode, string(body)) + } + if resp2.Header.Get("X-Any-Hub-Cache-Hit") != "true" { + t.Fatalf("expected second request to be served from cache") + } + resp2.Body.Close() + + if hits := stub.ManifestHits(); hits != 4 { + t.Fatalf("expected 4 manifest hits (2 GET + 2 HEAD), got %d", hits) + } + if tokens := stub.TokenHits(); tokens != 2 { + t.Fatalf("expected token endpoint to be called twice, got %d", tokens) + } +} + func performCredentialRequest(t *testing.T, app *fiber.App) *http.Response { t.Helper() req := httptest.NewRequest("GET", "http://secure.hub.local/private/data", nil) @@ -427,6 +471,7 @@ type dockerBearerStub struct { tokenAuth string manifestHits int tokenHits int + lastModified time.Time } func newDockerBearerStub(t *testing.T, username, password string) *dockerBearerStub { @@ -436,6 +481,7 @@ func newDockerBearerStub(t *testing.T, username, password string) *dockerBearerS password: password, expectedBasic: "Basic " + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password))), tokenValue: "test-token", + lastModified: time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC), } mux := http.NewServeMux() @@ -470,7 +516,11 @@ func (s *dockerBearerStub) handleManifest(w http.ResponseWriter, r *http.Request if success { w.Header().Set("Content-Type", "application/json") + w.Header().Set("Last-Modified", s.lastModified.Format(http.TimeFormat)) w.WriteHeader(http.StatusOK) + if r.Method == http.MethodHead { + return + } _, _ = w.Write([]byte(`{"schemaVersion":2}`)) return }