feat: add http rate limiting
This commit is contained in:
@@ -15,6 +15,7 @@ type Config struct {
|
||||
BaseURI *string
|
||||
Tls *Tls
|
||||
Cors *Cors
|
||||
RateLimit *RateLimit
|
||||
}
|
||||
|
||||
type Tls struct {
|
||||
@@ -27,6 +28,14 @@ type Cors struct {
|
||||
Whitelist []Whitelist
|
||||
}
|
||||
|
||||
type RateLimit struct {
|
||||
Enabled bool
|
||||
Max int
|
||||
WindowSeconds int
|
||||
Message string
|
||||
SkipPaths []string
|
||||
}
|
||||
|
||||
type Whitelist struct {
|
||||
AllowOrigin string
|
||||
AllowHeaders string
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -16,10 +17,13 @@ import (
|
||||
"github.com/gofiber/fiber/v3/middleware/compress"
|
||||
"github.com/gofiber/fiber/v3/middleware/cors"
|
||||
"github.com/gofiber/fiber/v3/middleware/helmet"
|
||||
"github.com/gofiber/fiber/v3/middleware/limiter"
|
||||
"github.com/gofiber/fiber/v3/middleware/logger"
|
||||
"github.com/gofiber/fiber/v3/middleware/recover"
|
||||
"github.com/gofiber/fiber/v3/middleware/requestid"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"quyun/v2/app/errorx"
|
||||
)
|
||||
|
||||
func DefaultProvider() container.ProviderContainer {
|
||||
@@ -137,8 +141,51 @@ func Provide(opts ...opt.Option) error {
|
||||
TimeZone: "Asia/Shanghai",
|
||||
}))
|
||||
|
||||
// rate limit (enable standard headers; adjust Max via config if needed)
|
||||
// engine.Use(limiter.New(limiter.Config{Max: 0}))
|
||||
// rate limit (by tenant code or IP)
|
||||
if config.RateLimit != nil && config.RateLimit.Enabled {
|
||||
max := config.RateLimit.Max
|
||||
if max <= 0 {
|
||||
max = 120
|
||||
}
|
||||
windowSeconds := config.RateLimit.WindowSeconds
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 60
|
||||
}
|
||||
message := strings.TrimSpace(config.RateLimit.Message)
|
||||
|
||||
skipPrefixes := append([]string{"/healthz", "/readyz"}, config.RateLimit.SkipPaths...)
|
||||
engine.Use(limiter.New(limiter.Config{
|
||||
Max: max,
|
||||
Expiration: time.Duration(windowSeconds) * time.Second,
|
||||
LimitReached: func(c fiber.Ctx) error {
|
||||
appErr := errorx.ErrRateLimitExceeded
|
||||
if message != "" {
|
||||
appErr = appErr.WithMsg(message)
|
||||
}
|
||||
return errorx.SendError(c, appErr)
|
||||
},
|
||||
Next: func(c fiber.Ctx) bool {
|
||||
path := c.Path()
|
||||
for _, prefix := range skipPrefixes {
|
||||
if prefix == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
KeyGenerator: func(c fiber.Ctx) string {
|
||||
if strings.HasPrefix(c.Path(), "/t/") {
|
||||
if tenantCode := strings.TrimSpace(c.Params("tenantCode")); tenantCode != "" {
|
||||
return "tenant:" + tenantCode
|
||||
}
|
||||
}
|
||||
return c.IP()
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// static files (Fiber v3 Static helper moved; enable via filesystem middleware later)
|
||||
// if config.StaticRoute != nil && config.StaticPath != nil { ... }
|
||||
|
||||
Reference in New Issue
Block a user