modify proxy handler
This commit is contained in:
@@ -1,46 +1,84 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/logging"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
// Forwarder 根据 HubRoute 的 module_key 选择对应的 ProxyHandler,默认回退到构造时注入的 handler。
|
||||
type Forwarder struct {
|
||||
defaultHandler server.ProxyHandler
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewForwarder 创建 Forwarder,defaultHandler 不能为空。
|
||||
func NewForwarder(defaultHandler server.ProxyHandler) *Forwarder {
|
||||
return &Forwarder{defaultHandler: defaultHandler}
|
||||
func NewForwarder(defaultHandler server.ProxyHandler, logger *logrus.Logger) *Forwarder {
|
||||
return &Forwarder{
|
||||
defaultHandler: defaultHandler,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
moduleHandlers sync.Map
|
||||
)
|
||||
|
||||
// RegisterModuleHandler 将特定 module_key 映射到 ProxyHandler,重复注册会覆盖旧值。
|
||||
// RegisterModuleHandler is kept for backward compatibility; it panics on invalid input.
|
||||
func RegisterModuleHandler(key string, handler server.ProxyHandler) {
|
||||
normalized := normalizeModuleKey(key)
|
||||
if normalized == "" || handler == nil {
|
||||
return
|
||||
}
|
||||
moduleHandlers.Store(normalized, handler)
|
||||
MustRegisterModule(ModuleRegistration{Key: key, Handler: handler})
|
||||
}
|
||||
|
||||
// Handle 实现 server.ProxyHandler,根据 route.ModuleKey 选择 handler。
|
||||
func (f *Forwarder) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
handler := f.lookup(route)
|
||||
if handler == nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, "proxy handler unavailable")
|
||||
return f.respondMissingHandler(c, route)
|
||||
}
|
||||
return f.invokeHandler(c, route, handler)
|
||||
}
|
||||
|
||||
func (f *Forwarder) respondMissingHandler(c fiber.Ctx, route *server.HubRoute) error {
|
||||
f.logModuleError(route, "module_handler_missing", nil)
|
||||
return c.Status(fiber.StatusInternalServerError).
|
||||
JSON(fiber.Map{"error": "module_handler_missing"})
|
||||
}
|
||||
|
||||
func (f *Forwarder) invokeHandler(c fiber.Ctx, route *server.HubRoute, handler server.ProxyHandler) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = f.respondHandlerPanic(c, route, r)
|
||||
}
|
||||
}()
|
||||
return handler.Handle(c, route)
|
||||
}
|
||||
|
||||
func (f *Forwarder) respondHandlerPanic(c fiber.Ctx, route *server.HubRoute, recovered interface{}) error {
|
||||
f.logModuleError(route, "module_handler_panic", fmt.Errorf("panic: %v", recovered))
|
||||
return c.Status(fiber.StatusInternalServerError).
|
||||
JSON(fiber.Map{"error": "module_handler_panic"})
|
||||
}
|
||||
|
||||
func (f *Forwarder) logModuleError(route *server.HubRoute, code string, err error) {
|
||||
if f.logger == nil {
|
||||
return
|
||||
}
|
||||
fields := f.routeFields(route)
|
||||
fields["action"] = "proxy"
|
||||
fields["error"] = code
|
||||
if err != nil {
|
||||
f.logger.WithFields(fields).Error(err.Error())
|
||||
return
|
||||
}
|
||||
f.logger.WithFields(fields).Error("module handler unavailable")
|
||||
}
|
||||
|
||||
func (f *Forwarder) lookup(route *server.HubRoute) server.ProxyHandler {
|
||||
if route != nil {
|
||||
if handler := lookupModuleHandler(route.ModuleKey); handler != nil {
|
||||
@@ -66,3 +104,26 @@ func lookupModuleHandler(key string) server.ProxyHandler {
|
||||
func normalizeModuleKey(key string) string {
|
||||
return strings.ToLower(strings.TrimSpace(key))
|
||||
}
|
||||
|
||||
func (f *Forwarder) routeFields(route *server.HubRoute) logrus.Fields {
|
||||
if route == nil {
|
||||
return logrus.Fields{
|
||||
"hub": "",
|
||||
"domain": "",
|
||||
"hub_type": "",
|
||||
"auth_mode": "",
|
||||
"cache_hit": false,
|
||||
"module_key": "",
|
||||
}
|
||||
}
|
||||
|
||||
return logging.RequestFields(
|
||||
route.Config.Name,
|
||||
route.Config.Domain,
|
||||
route.Config.Type,
|
||||
route.Config.AuthMode(),
|
||||
route.ModuleKey,
|
||||
string(route.RolloutFlag),
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
93
internal/proxy/forwarder_test.go
Normal file
93
internal/proxy/forwarder_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/config"
|
||||
"github.com/any-hub/any-hub/internal/hubmodule/legacy"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
func TestForwarderMissingHandler(t *testing.T) {
|
||||
app := fiber.New()
|
||||
defer app.Shutdown()
|
||||
|
||||
ctx := app.AcquireCtx(new(fasthttp.RequestCtx))
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
logger := logrus.New()
|
||||
logBuf := &bytes.Buffer{}
|
||||
logger.SetOutput(logBuf)
|
||||
|
||||
forwarder := NewForwarder(nil, logger)
|
||||
route := testRouteWithModule("missing-module")
|
||||
|
||||
if err := forwarder.Handle(ctx, route); err != nil {
|
||||
t.Fatalf("forwarder.Handle returned unexpected error: %v", err)
|
||||
}
|
||||
if status := ctx.Response().StatusCode(); status != fiber.StatusInternalServerError {
|
||||
t.Fatalf("expected 500 for missing handler, got %d", status)
|
||||
}
|
||||
if body := string(ctx.Response().Body()); !strings.Contains(body, "module_handler_missing") {
|
||||
t.Fatalf("expected error body to mention module_handler_missing, got %s", body)
|
||||
}
|
||||
if !strings.Contains(logBuf.String(), "module_handler_missing") {
|
||||
t.Fatalf("expected log to mention module_handler_missing, got %s", logBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwarderHandlerPanic(t *testing.T) {
|
||||
const moduleKey = "panic-module"
|
||||
moduleHandlers.Delete(normalizeModuleKey(moduleKey))
|
||||
defer moduleHandlers.Delete(normalizeModuleKey(moduleKey))
|
||||
|
||||
MustRegisterModule(ModuleRegistration{
|
||||
Key: moduleKey,
|
||||
Handler: server.ProxyHandlerFunc(func(fiber.Ctx, *server.HubRoute) error {
|
||||
panic("boom")
|
||||
}),
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
defer app.Shutdown()
|
||||
ctx := app.AcquireCtx(new(fasthttp.RequestCtx))
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
logger := logrus.New()
|
||||
logBuf := &bytes.Buffer{}
|
||||
logger.SetOutput(logBuf)
|
||||
|
||||
forwarder := NewForwarder(nil, logger)
|
||||
route := testRouteWithModule(moduleKey)
|
||||
|
||||
if err := forwarder.Handle(ctx, route); err != nil {
|
||||
t.Fatalf("forwarder.Handle returned unexpected error: %v", err)
|
||||
}
|
||||
if status := ctx.Response().StatusCode(); status != fiber.StatusInternalServerError {
|
||||
t.Fatalf("expected 500 for handler panic, got %d", status)
|
||||
}
|
||||
if body := string(ctx.Response().Body()); !strings.Contains(body, "module_handler_panic") {
|
||||
t.Fatalf("expected error body to mention module_handler_panic, got %s", body)
|
||||
}
|
||||
if !strings.Contains(logBuf.String(), "module_handler_panic") {
|
||||
t.Fatalf("expected log to mention module_handler_panic, got %s", logBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func testRouteWithModule(moduleKey string) *server.HubRoute {
|
||||
return &server.HubRoute{
|
||||
Config: config.HubConfig{
|
||||
Name: "test",
|
||||
Domain: "test.local",
|
||||
Type: "custom",
|
||||
},
|
||||
ModuleKey: moduleKey,
|
||||
RolloutFlag: legacy.RolloutModular,
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/any-hub/any-hub/internal/cache"
|
||||
"github.com/any-hub/any-hub/internal/hubmodule"
|
||||
"github.com/any-hub/any-hub/internal/logging"
|
||||
"github.com/any-hub/any-hub/internal/proxy/hooks"
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
@@ -48,8 +49,10 @@ func NewHandler(client *http.Client, logger *logrus.Logger, store cache.Store) *
|
||||
func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
started := time.Now()
|
||||
requestID := server.RequestID(c)
|
||||
locator := buildLocator(route, c)
|
||||
policy := determineCachePolicy(route, locator, c.Method())
|
||||
reqCtx := newRequestContext(route, c.Method())
|
||||
moduleHooks, _ := hooks.For(route.ModuleKey)
|
||||
locator := buildLocator(route, c, reqCtx, moduleHooks)
|
||||
policy := determineCachePolicy(route, locator, c.Method(), reqCtx, moduleHooks)
|
||||
strategyWriter := cache.NewStrategyWriter(h.store, route.CacheStrategy)
|
||||
|
||||
if err := ensureProxyHubType(route); err != nil {
|
||||
@@ -106,7 +109,7 @@ func (h *Handler) Handle(c fiber.Ctx, route *server.HubRoute) error {
|
||||
cached.Reader.Close()
|
||||
}
|
||||
|
||||
return h.fetchAndStream(c, route, locator, policy, strategyWriter, requestID, started, ctx)
|
||||
return h.fetchAndStream(c, route, locator, policy, strategyWriter, requestID, started, ctx, reqCtx, moduleHooks)
|
||||
}
|
||||
|
||||
func (h *Handler) serveCache(
|
||||
|
||||
60
internal/proxy/hooks/hooks.go
Normal file
60
internal/proxy/hooks/hooks.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CachePolicy mirrors the proxy cache policy structure.
|
||||
type CachePolicy struct {
|
||||
AllowCache bool
|
||||
AllowStore bool
|
||||
RequireRevalidate bool
|
||||
}
|
||||
|
||||
// RequestContext exposes route/request details without importing server internals.
|
||||
type RequestContext struct {
|
||||
HubName string
|
||||
Domain string
|
||||
HubType string
|
||||
ModuleKey string
|
||||
RolloutFlag string
|
||||
UpstreamHost string
|
||||
Method string
|
||||
}
|
||||
|
||||
// Hooks describes customization points for module-specific behavior.
|
||||
type Hooks struct {
|
||||
NormalizePath func(ctx *RequestContext, cleanPath string) string
|
||||
ResolveUpstream func(ctx *RequestContext, base *url.URL, cleanPath string, rawQuery []byte) *url.URL
|
||||
RewriteResponse func(ctx *RequestContext, resp *http.Response, cleanPath string) (*http.Response, error)
|
||||
CachePolicy func(ctx *RequestContext, locatorPath string, current CachePolicy) CachePolicy
|
||||
ContentType func(ctx *RequestContext, locatorPath string) string
|
||||
}
|
||||
|
||||
var registry sync.Map
|
||||
|
||||
// Register stores hooks for the given module key.
|
||||
func Register(moduleKey string, hooks Hooks) {
|
||||
key := strings.ToLower(strings.TrimSpace(moduleKey))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
registry.Store(key, hooks)
|
||||
}
|
||||
|
||||
// For retrieves hooks associated with a module key.
|
||||
func For(moduleKey string) (Hooks, bool) {
|
||||
key := strings.ToLower(strings.TrimSpace(moduleKey))
|
||||
if key == "" {
|
||||
return Hooks{}, false
|
||||
}
|
||||
if value, ok := registry.Load(key); ok {
|
||||
if hooks, ok := value.(Hooks); ok {
|
||||
return hooks, true
|
||||
}
|
||||
}
|
||||
return Hooks{}, false
|
||||
}
|
||||
56
internal/proxy/module_contract.go
Normal file
56
internal/proxy/module_contract.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/any-hub/any-hub/internal/server"
|
||||
)
|
||||
|
||||
// ModuleHandler is the runtime contract each hubmodule must provide to serve requests.
|
||||
// It aligns with server.ProxyHandler so existing handlers remain compatible.
|
||||
type ModuleHandler = server.ProxyHandler
|
||||
|
||||
// ModuleRegistration captures a module_key and its handler for safe registration.
|
||||
// Future registration flows can validate this struct before wiring into the dispatcher.
|
||||
type ModuleRegistration struct {
|
||||
Key string
|
||||
Handler ModuleHandler
|
||||
}
|
||||
|
||||
// ErrModuleHandlerExists indicates a handler has already been registered for the key.
|
||||
var ErrModuleHandlerExists = errors.New("module handler already registered")
|
||||
|
||||
// Validate ensures both key and handler are present before registration.
|
||||
func (r ModuleRegistration) Validate() error {
|
||||
if strings.TrimSpace(r.Key) == "" {
|
||||
return errors.New("module key required")
|
||||
}
|
||||
if r.Handler == nil {
|
||||
return errors.New("module handler required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterModule registers validated metadata/runtime handler pair.
|
||||
func RegisterModule(reg ModuleRegistration) error {
|
||||
if err := reg.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
normalized := normalizeModuleKey(reg.Key)
|
||||
if normalized == "" {
|
||||
return errors.New("module key required")
|
||||
}
|
||||
if _, loaded := moduleHandlers.LoadOrStore(normalized, reg.Handler); loaded {
|
||||
return fmt.Errorf("%w: %s", ErrModuleHandlerExists, normalized)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustRegisterModule panics when registration fails; suitable for module init().
|
||||
func MustRegisterModule(reg ModuleRegistration) {
|
||||
if err := RegisterModule(reg); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,9 @@ func NewApp(opts AppOptions) (*fiber.App, error) {
|
||||
app.Use(requestContextMiddleware(opts))
|
||||
|
||||
app.All("/*", func(c fiber.Ctx) error {
|
||||
if isDiagnosticsPath(string(c.Request().URI().Path())) {
|
||||
return c.Next()
|
||||
}
|
||||
route, _ := getRouteFromContext(c)
|
||||
if route == nil {
|
||||
return renderHostUnmapped(c, opts.Logger, "", opts.ListenPort)
|
||||
@@ -88,9 +91,6 @@ func requestContextMiddleware(opts AppOptions) fiber.Handler {
|
||||
if !ok {
|
||||
return renderHostUnmapped(c, opts.Logger, rawHost, opts.ListenPort)
|
||||
}
|
||||
if err := ensureRouterHubType(route); err != nil {
|
||||
return renderTypeUnsupported(c, opts.Logger, route, err)
|
||||
}
|
||||
|
||||
c.Locals(contextKeyRoute, route)
|
||||
return c.Next()
|
||||
@@ -140,38 +140,6 @@ func RequestID(c fiber.Ctx) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func ensureRouterHubType(route *HubRoute) error {
|
||||
switch route.Config.Type {
|
||||
case "docker":
|
||||
return nil
|
||||
case "npm":
|
||||
return nil
|
||||
case "go":
|
||||
return nil
|
||||
case "pypi":
|
||||
return nil
|
||||
case "composer":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported hub type: %s", route.Config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func renderTypeUnsupported(c fiber.Ctx, logger *logrus.Logger, route *HubRoute, err error) error {
|
||||
fields := logrus.Fields{
|
||||
"action": "hub_type_check",
|
||||
"hub": route.Config.Name,
|
||||
"hub_type": route.Config.Type,
|
||||
"module_key": route.ModuleKey,
|
||||
"rollout_flag": string(route.RolloutFlag),
|
||||
"error": "hub_type_unsupported",
|
||||
}
|
||||
logger.WithFields(fields).Error(err.Error())
|
||||
return c.Status(fiber.StatusNotImplemented).JSON(fiber.Map{
|
||||
"error": "hub_type_unsupported",
|
||||
})
|
||||
}
|
||||
|
||||
func isDiagnosticsPath(path string) bool {
|
||||
return strings.HasPrefix(path, "/-/")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user