diff --git a/backend/config.full.toml b/backend/config.full.toml index a782cdc..99c273e 100644 --- a/backend/config.full.toml +++ b/backend/config.full.toml @@ -45,6 +45,13 @@ Port = 8080 # HTTP 监听端口 # WindowSeconds = 60 # 窗口大小(秒) # Message = "Too Many Requests" # SkipPaths = ["/healthz", "/readyz"] +# +# [Http.RateLimit.Redis] +# Addrs = ["127.0.0.1:6379"] +# Username = "" +# Password = "" +# DB = 0 +# Prefix = "rl:" # ========================= # Connection Multiplexer (providers/cmux) # 用于同端口同时暴露 HTTP + gRPC:cmux -> 分发到 Http/Grpc diff --git a/backend/providers/http/config.go b/backend/providers/http/config.go index 8a91ed1..47eb761 100644 --- a/backend/providers/http/config.go +++ b/backend/providers/http/config.go @@ -34,6 +34,15 @@ type RateLimit struct { WindowSeconds int Message string SkipPaths []string + Redis *RateLimitRedis +} + +type RateLimitRedis struct { + Addrs []string + Username string + Password string + DB int + Prefix string } type Whitelist struct { diff --git a/backend/providers/http/engine.go b/backend/providers/http/engine.go index 9453451..749e6af 100644 --- a/backend/providers/http/engine.go +++ b/backend/providers/http/engine.go @@ -153,10 +153,21 @@ func Provide(opts ...opt.Option) error { } message := strings.TrimSpace(config.RateLimit.Message) + var limiterStorage fiber.Storage + if config.RateLimit.Redis != nil { + storage, err := newRedisLimiterStorage(config.RateLimit.Redis) + if err != nil { + return nil, err + } + limiterStorage = storage + container.AddCloseAble(func() { _ = storage.Close() }) + } + skipPrefixes := append([]string{"/healthz", "/readyz"}, config.RateLimit.SkipPaths...) engine.Use(limiter.New(limiter.Config{ Max: max, Expiration: time.Duration(windowSeconds) * time.Second, + Storage: limiterStorage, LimitReached: func(c fiber.Ctx) error { appErr := errorx.ErrRateLimitExceeded if message != "" { diff --git a/backend/providers/http/limiter_storage_redis.go b/backend/providers/http/limiter_storage_redis.go new file mode 100644 index 0000000..4d49db9 --- /dev/null +++ b/backend/providers/http/limiter_storage_redis.go @@ -0,0 +1,107 @@ +package http + +import ( + "context" + "errors" + "strings" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/redis/go-redis/v9" +) + +type redisLimiterStorage struct { + client redis.UniversalClient + prefix string +} + +func newRedisLimiterStorage(config *RateLimitRedis) (fiber.Storage, error) { + if config == nil { + return nil, errors.New("rate limit redis config is nil") + } + if len(config.Addrs) == 0 { + return nil, errors.New("rate limit redis addrs is empty") + } + + client := redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: config.Addrs, + Username: config.Username, + Password: config.Password, + DB: config.DB, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if err := client.Ping(ctx).Err(); err != nil { + _ = client.Close() + return nil, err + } + + prefix := strings.TrimSpace(config.Prefix) + return &redisLimiterStorage{ + client: client, + prefix: prefix, + }, nil +} + +func (s *redisLimiterStorage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + if s == nil || key == "" { + return nil, nil + } + val, err := s.client.Get(ctx, s.key(key)).Bytes() + if errors.Is(err, redis.Nil) { + return nil, nil + } + return val, err +} + +func (s *redisLimiterStorage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *redisLimiterStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + if s == nil || key == "" || len(val) == 0 { + return nil + } + return s.client.Set(ctx, s.key(key), val, exp).Err() +} + +func (s *redisLimiterStorage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +func (s *redisLimiterStorage) DeleteWithContext(ctx context.Context, key string) error { + if s == nil || key == "" { + return nil + } + return s.client.Del(ctx, s.key(key)).Err() +} + +func (s *redisLimiterStorage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *redisLimiterStorage) ResetWithContext(ctx context.Context) error { + if s == nil { + return nil + } + return s.client.FlushDB(ctx).Err() +} + +func (s *redisLimiterStorage) Reset() error { + return s.ResetWithContext(context.Background()) +} + +func (s *redisLimiterStorage) Close() error { + if s == nil { + return nil + } + return s.client.Close() +} + +func (s *redisLimiterStorage) key(raw string) string { + if s.prefix == "" { + return raw + } + return s.prefix + raw +}