feat: add http rate limiting

This commit is contained in:
2026-01-17 09:47:49 +08:00
parent 4f2b8ea3ad
commit c399a65d83
4 changed files with 68 additions and 16 deletions

View File

@@ -15,6 +15,7 @@ type Config struct {
BaseURI *string
Tls *Tls
Cors *Cors
RateLimit *RateLimit
}
type Tls struct {
@@ -27,6 +28,14 @@ type Cors struct {
Whitelist []Whitelist
}
type RateLimit struct {
Enabled bool
Max int
WindowSeconds int
Message string
SkipPaths []string
}
type Whitelist struct {
AllowOrigin string
AllowHeaders string

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"runtime/debug"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -16,10 +17,13 @@ import (
"github.com/gofiber/fiber/v3/middleware/compress"
"github.com/gofiber/fiber/v3/middleware/cors"
"github.com/gofiber/fiber/v3/middleware/helmet"
"github.com/gofiber/fiber/v3/middleware/limiter"
"github.com/gofiber/fiber/v3/middleware/logger"
"github.com/gofiber/fiber/v3/middleware/recover"
"github.com/gofiber/fiber/v3/middleware/requestid"
"github.com/samber/lo"
"quyun/v2/app/errorx"
)
func DefaultProvider() container.ProviderContainer {
@@ -137,8 +141,51 @@ func Provide(opts ...opt.Option) error {
TimeZone: "Asia/Shanghai",
}))
// rate limit (enable standard headers; adjust Max via config if needed)
// engine.Use(limiter.New(limiter.Config{Max: 0}))
// rate limit (by tenant code or IP)
if config.RateLimit != nil && config.RateLimit.Enabled {
max := config.RateLimit.Max
if max <= 0 {
max = 120
}
windowSeconds := config.RateLimit.WindowSeconds
if windowSeconds <= 0 {
windowSeconds = 60
}
message := strings.TrimSpace(config.RateLimit.Message)
skipPrefixes := append([]string{"/healthz", "/readyz"}, config.RateLimit.SkipPaths...)
engine.Use(limiter.New(limiter.Config{
Max: max,
Expiration: time.Duration(windowSeconds) * time.Second,
LimitReached: func(c fiber.Ctx) error {
appErr := errorx.ErrRateLimitExceeded
if message != "" {
appErr = appErr.WithMsg(message)
}
return errorx.SendError(c, appErr)
},
Next: func(c fiber.Ctx) bool {
path := c.Path()
for _, prefix := range skipPrefixes {
if prefix == "" {
continue
}
if strings.HasPrefix(path, prefix) {
return true
}
}
return false
},
KeyGenerator: func(c fiber.Ctx) string {
if strings.HasPrefix(c.Path(), "/t/") {
if tenantCode := strings.TrimSpace(c.Params("tenantCode")); tenantCode != "" {
return "tenant:" + tenantCode
}
}
return c.IP()
},
}))
}
// static files (Fiber v3 Static helper moved; enable via filesystem middleware later)
// if config.StaticRoute != nil && config.StaticPath != nil { ... }