From 6c6d5273ac49564f3045c139da8beead1b9bc462 Mon Sep 17 00:00:00 2001 From: Rogee Date: Mon, 17 Nov 2025 09:18:46 +0800 Subject: [PATCH] complete docker cache --- internal/proxy/handler.go | 48 ++++++++++++++++++++-- tests/integration/credential_proxy_test.go | 13 +++++- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 393f77c..3300ad8 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -116,13 +116,23 @@ func (h *Handler) serveCache( requestID string, started time.Time, ) error { - if seeker, ok := result.Reader.(io.Seeker); ok { - _, _ = seeker.Seek(0, io.SeekStart) + var readSeeker io.ReadSeeker + switch reader := result.Reader.(type) { + case io.ReadSeeker: + readSeeker = reader + _, _ = readSeeker.Seek(0, io.SeekStart) + case io.Seeker: + _, _ = reader.Seek(0, io.SeekStart) } method := c.Method() contentType := inferCachedContentType(route, result.Entry.Locator) + if contentType == "" && shouldSniffDockerManifest(route, result.Entry.Locator) { + if sniffed := sniffDockerManifestContentType(readSeeker); sniffed != "" { + contentType = sniffed + } + } if contentType != "" { c.Set("Content-Type", contentType) } else { @@ -454,7 +464,7 @@ func inferCachedContentType(route *server.HubRoute, locator cache.Locator) strin switch route.Config.Type { case "docker": if strings.Contains(clean, "/manifests/") { - return "application/vnd.docker.distribution.manifest.v2+json" + return "" } if strings.Contains(clean, "/tags/list") { return "application/json" @@ -520,6 +530,38 @@ func stripQueryMarker(p string) string { return p } +func shouldSniffDockerManifest(route *server.HubRoute, locator cache.Locator) bool { + if route == nil || route.Config.Type != "docker" { + return false + } + clean := stripQueryMarker(locator.Path) + return strings.Contains(clean, "/manifests/") +} + +func sniffDockerManifestContentType(reader io.ReadSeeker) string { + if reader == nil { + return "" + } + const maxInspectBytes = 512 * 1024 + if _, err := reader.Seek(0, io.SeekStart); err != nil { + return "" + } + data, err := io.ReadAll(io.LimitReader(reader, maxInspectBytes)) + if _, seekErr := reader.Seek(0, io.SeekStart); seekErr != nil { + return "" + } + if err != nil && !errors.Is(err, io.EOF) { + return "" + } + var manifest struct { + MediaType string `json:"mediaType"` + } + if err := json.Unmarshal(data, &manifest); err != nil { + return "" + } + return strings.TrimSpace(manifest.MediaType) +} + func requestPath(c fiber.Ctx) string { if c == nil { return "/" diff --git a/tests/integration/credential_proxy_test.go b/tests/integration/credential_proxy_test.go index f7bb124..9343817 100644 --- a/tests/integration/credential_proxy_test.go +++ b/tests/integration/credential_proxy_test.go @@ -131,6 +131,8 @@ func TestCredentialProxy(t *testing.T) { }) } +const dockerManifestContentType = "application/vnd.oci.image.index.v1+json" + func TestDockerProxyHandlesBearerTokenExchange(t *testing.T) { stub := newDockerBearerStub(t, "ci-user", "ci-pass") defer stub.Close() @@ -183,6 +185,9 @@ func TestDockerProxyCachesAfterBearerRevalidation(t *testing.T) { if resp.Header.Get("X-Any-Hub-Cache-Hit") != "false" { t.Fatalf("expected first request to miss cache") } + if resp.Header.Get("Content-Type") != dockerManifestContentType { + t.Fatalf("expected upstream content type %s, got %s", dockerManifestContentType, resp.Header.Get("Content-Type")) + } resp.Body.Close() req2 := httptest.NewRequest("GET", "http://docker.hub.local/v2/library/alpine/manifests/latest", nil) @@ -198,6 +203,9 @@ func TestDockerProxyCachesAfterBearerRevalidation(t *testing.T) { if resp2.Header.Get("X-Any-Hub-Cache-Hit") != "true" { t.Fatalf("expected second request to be served from cache") } + if resp2.Header.Get("Content-Type") != dockerManifestContentType { + t.Fatalf("expected cached content type %s, got %s", dockerManifestContentType, resp2.Header.Get("Content-Type")) + } resp2.Body.Close() if hits := stub.ManifestHits(); hits != 4 { @@ -515,13 +523,14 @@ func (s *dockerBearerStub) handleManifest(w http.ResponseWriter, r *http.Request s.mu.Unlock() if success { - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", dockerManifestContentType) w.Header().Set("Last-Modified", s.lastModified.Format(http.TimeFormat)) w.WriteHeader(http.StatusOK) if r.Method == http.MethodHead { return } - _, _ = w.Write([]byte(`{"schemaVersion":2}`)) + payload := fmt.Sprintf(`{"schemaVersion":2,"mediaType":"%s"}`, dockerManifestContentType) + _, _ = w.Write([]byte(payload)) return }