334 lines
8.6 KiB
Go
334 lines
8.6 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"runtime/debug"
|
|
"strings"
|
|
"time"
|
|
|
|
"quyun/v2/app/errorx"
|
|
"quyun/v2/providers/storage"
|
|
|
|
logrus "github.com/sirupsen/logrus"
|
|
"go.ipao.vip/atom/container"
|
|
"go.ipao.vip/atom/opt"
|
|
"go.uber.org/dig"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/fiber/v3/middleware/compress"
|
|
"github.com/gofiber/fiber/v3/middleware/cors"
|
|
"github.com/gofiber/fiber/v3/middleware/helmet"
|
|
"github.com/gofiber/fiber/v3/middleware/limiter"
|
|
"github.com/gofiber/fiber/v3/middleware/logger"
|
|
"github.com/gofiber/fiber/v3/middleware/recover"
|
|
"github.com/gofiber/fiber/v3/middleware/requestid"
|
|
"github.com/samber/lo"
|
|
)
|
|
|
|
func DefaultProvider() container.ProviderContainer {
|
|
return container.ProviderContainer{
|
|
Provider: Provide,
|
|
Options: []opt.Option{
|
|
opt.Prefix(DefaultPrefix),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Service is the HTTP server provider. It owns the fiber engine, the
// configuration the engine was built from, and the checks behind the
// /healthz and /readyz endpoints.
type Service struct {
	// conf is the provider configuration unmarshalled in Provide; it supplies
	// the listen address and the optional TLS settings.
	conf *Config

	// Engine is the fiber application that routes and middleware are
	// registered on.
	Engine *fiber.App

	// healthCheck backs /healthz; a nil check is treated as healthy.
	healthCheck func(context.Context) error

	// readyCheck backs /readyz; a nil check is treated as ready.
	readyCheck func(context.Context) error
}
|
|
|
|
var errTLSCertKeyRequired = errors.New("tls cert and key must be set")
|
|
|
|
func (svc *Service) listenerConfig() fiber.ListenConfig {
|
|
listenConfig := fiber.ListenConfig{
|
|
EnablePrintRoutes: true,
|
|
// DisableStartupMessage: true,
|
|
}
|
|
|
|
if svc.conf.TLS != nil {
|
|
if svc.conf.TLS.Cert == "" || svc.conf.TLS.Key == "" {
|
|
panic(errTLSCertKeyRequired)
|
|
}
|
|
listenConfig.CertFile = svc.conf.TLS.Cert
|
|
listenConfig.CertKeyFile = svc.conf.TLS.Key
|
|
}
|
|
container.AddCloseAble(func() {
|
|
svc.Engine.ShutdownWithTimeout(time.Second * 10)
|
|
})
|
|
|
|
return listenConfig
|
|
}
|
|
|
|
func (svc *Service) Listener(ln net.Listener) error {
|
|
return svc.Engine.Listener(ln, svc.listenerConfig())
|
|
}
|
|
|
|
func (svc *Service) Serve(ctx context.Context) error {
|
|
ln, err := net.Listen("tcp4", svc.conf.Address())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Run the server in a goroutine so we can listen for context cancellation
|
|
serverErr := make(chan error, 1)
|
|
go func() {
|
|
serverErr <- svc.Engine.Listener(ln, svc.listenerConfig())
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
// Shutdown the server gracefully
|
|
if shutdownErr := svc.Engine.Shutdown(); shutdownErr != nil {
|
|
return shutdownErr
|
|
}
|
|
// treat context cancellation as graceful shutdown
|
|
return nil
|
|
case err := <-serverErr:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func Provide(opts ...opt.Option) error {
|
|
o := opt.New(opts...)
|
|
var config Config
|
|
if err := o.UnmarshalConfig(&config); err != nil {
|
|
return err
|
|
}
|
|
|
|
return container.Container.Provide(func(params struct {
|
|
dig.In
|
|
DB *sql.DB `optional:"true"`
|
|
Storage *storage.Storage `optional:"true"`
|
|
}) (*Service, error) {
|
|
engine := fiber.New(fiber.Config{
|
|
StrictRouting: true,
|
|
CaseSensitive: true,
|
|
BodyLimit: 10 * 1024 * 1024, // 10 MiB
|
|
ReadTimeout: 10 * time.Second,
|
|
WriteTimeout: 10 * time.Second,
|
|
IdleTimeout: 60 * time.Second,
|
|
ProxyHeader: fiber.HeaderXForwardedFor,
|
|
EnableIPValidation: true,
|
|
})
|
|
|
|
engine.Use(requestid.New())
|
|
|
|
engine.Use(recover.New(recover.Config{
|
|
EnableStackTrace: true,
|
|
StackTraceHandler: func(c fiber.Ctx, e any) {
|
|
rid := c.Get(fiber.HeaderXRequestID)
|
|
logrus.WithField("request_id", rid).Error(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack()))
|
|
},
|
|
}))
|
|
|
|
// basic security + compression
|
|
engine.Use(helmet.New())
|
|
engine.Use(compress.New(compress.Config{Level: compress.LevelDefault}))
|
|
|
|
// optional CORS based on config
|
|
if config.Cors != nil {
|
|
corsCfg := buildCORSConfig(config.Cors)
|
|
if corsCfg != nil {
|
|
engine.Use(cors.New(*corsCfg))
|
|
}
|
|
}
|
|
|
|
engine.Use(logger.New(logger.Config{
|
|
Format: `${time} [${ip}] ${method} ${status} ${path} ${latency} rid=${locals:requestid} "${ua}"\n`,
|
|
TimeFormat: time.RFC3339,
|
|
TimeZone: "Asia/Shanghai",
|
|
}))
|
|
|
|
// rate limit (by tenant code or IP)
|
|
if config.RateLimit != nil && config.RateLimit.Enabled {
|
|
limitMax := config.RateLimit.Max
|
|
if limitMax <= 0 {
|
|
limitMax = 120
|
|
}
|
|
windowSeconds := config.RateLimit.WindowSeconds
|
|
if windowSeconds <= 0 {
|
|
windowSeconds = 60
|
|
}
|
|
message := strings.TrimSpace(config.RateLimit.Message)
|
|
|
|
var limiterStorage fiber.Storage
|
|
if config.RateLimit.Redis != nil {
|
|
storage, err := newRedisLimiterStorage(config.RateLimit.Redis)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
limiterStorage = storage
|
|
container.AddCloseAble(func() { _ = storage.Close() })
|
|
}
|
|
|
|
skipPrefixes := append([]string{"/healthz", "/readyz"}, config.RateLimit.SkipPaths...)
|
|
engine.Use(limiter.New(limiter.Config{
|
|
Max: limitMax,
|
|
Expiration: time.Duration(windowSeconds) * time.Second,
|
|
Storage: limiterStorage,
|
|
LimitReached: func(c fiber.Ctx) error {
|
|
appErr := errorx.ErrRateLimitExceeded
|
|
if message != "" {
|
|
appErr = appErr.WithMsg(message)
|
|
}
|
|
|
|
return errorx.SendError(c, appErr)
|
|
},
|
|
Next: func(requestCtx fiber.Ctx) bool {
|
|
path := requestCtx.Path()
|
|
for _, prefix := range skipPrefixes {
|
|
if prefix == "" {
|
|
continue
|
|
}
|
|
if strings.HasPrefix(path, prefix) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
},
|
|
KeyGenerator: func(requestCtx fiber.Ctx) string {
|
|
if strings.HasPrefix(requestCtx.Path(), "/t/") {
|
|
if tenantCode := strings.TrimSpace(requestCtx.Params("tenantCode")); tenantCode != "" {
|
|
return "tenant:" + tenantCode
|
|
}
|
|
}
|
|
|
|
return requestCtx.IP()
|
|
},
|
|
}))
|
|
}
|
|
|
|
service := &Service{
|
|
Engine: engine,
|
|
conf: &config,
|
|
}
|
|
service.healthCheck = service.buildHealthCheck()
|
|
service.readyCheck = service.buildReadyCheck(params.DB, params.Storage)
|
|
engine.Get("/healthz", service.handleHealthz)
|
|
engine.Get("/readyz", service.handleReadyz)
|
|
|
|
engine.Hooks().OnPostShutdown(func(err error) error {
|
|
if err != nil {
|
|
logrus.Error("http server shutdown error: ", err)
|
|
}
|
|
logrus.Info("http server has shutdown success")
|
|
|
|
return nil
|
|
})
|
|
|
|
return service, nil
|
|
}, o.DiOptions()...)
|
|
}
|
|
|
|
// buildCORSConfig converts provider Cors config into fiber cors.Config
|
|
func (svc *Service) buildHealthCheck() func(context.Context) error {
|
|
return func(_ context.Context) error {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (svc *Service) buildReadyCheck(db *sql.DB, store *storage.Storage) func(context.Context) error {
|
|
var dbPing func(context.Context) error
|
|
if db != nil {
|
|
dbPing = func(ctx context.Context) error {
|
|
pingCtx, cancel := context.WithTimeout(ctx, 1500*time.Millisecond)
|
|
defer cancel()
|
|
|
|
return db.PingContext(pingCtx)
|
|
}
|
|
}
|
|
|
|
return newReadyCheck(dbPing, store)
|
|
}
|
|
|
|
func newReadyCheck(dbPing func(context.Context) error, store *storage.Storage) func(context.Context) error {
|
|
return func(ctx context.Context) error {
|
|
if dbPing != nil {
|
|
if err := dbPing(ctx); err != nil {
|
|
return errorx.ErrServiceUnavailable.WithCause(err).WithMsg("database not ready")
|
|
}
|
|
}
|
|
if store != nil && store.Config != nil && strings.EqualFold(strings.TrimSpace(store.Config.Type), "s3") && store.Config.CheckOnBoot {
|
|
if strings.TrimSpace(store.Config.Endpoint) == "" || strings.TrimSpace(store.Config.Bucket) == "" {
|
|
return errorx.ErrServiceUnavailable.WithMsg("storage not ready")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (svc *Service) handleHealthz(c fiber.Ctx) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
if svc.healthCheck != nil {
|
|
if err := svc.healthCheck(ctx); err != nil {
|
|
return errorx.SendError(c, err)
|
|
}
|
|
}
|
|
|
|
return c.SendStatus(fiber.StatusNoContent)
|
|
}
|
|
|
|
func (svc *Service) handleReadyz(c fiber.Ctx) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
if svc.readyCheck != nil {
|
|
if err := svc.readyCheck(ctx); err != nil {
|
|
return errorx.SendError(c, err)
|
|
}
|
|
}
|
|
|
|
return c.SendStatus(fiber.StatusNoContent)
|
|
}
|
|
|
|
func buildCORSConfig(c *Cors) *cors.Config {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
if c.Mode == "disabled" {
|
|
return nil
|
|
}
|
|
var (
|
|
origins []string
|
|
headers []string
|
|
methods []string
|
|
exposes []string
|
|
allowCreds bool
|
|
)
|
|
for _, whitelistItem := range c.Whitelist {
|
|
if whitelistItem.AllowOrigin != "" {
|
|
origins = append(origins, whitelistItem.AllowOrigin)
|
|
}
|
|
if whitelistItem.AllowHeaders != "" {
|
|
headers = append(headers, whitelistItem.AllowHeaders)
|
|
}
|
|
if whitelistItem.AllowMethods != "" {
|
|
methods = append(methods, whitelistItem.AllowMethods)
|
|
}
|
|
if whitelistItem.ExposeHeaders != "" {
|
|
exposes = append(exposes, whitelistItem.ExposeHeaders)
|
|
}
|
|
allowCreds = allowCreds || whitelistItem.AllowCredentials
|
|
}
|
|
|
|
cfg := cors.Config{
|
|
AllowOrigins: lo.Uniq(origins),
|
|
AllowHeaders: lo.Uniq(headers),
|
|
AllowMethods: lo.Uniq(methods),
|
|
ExposeHeaders: lo.Uniq(exposes),
|
|
AllowCredentials: allowCreds,
|
|
}
|
|
|
|
return &cfg
|
|
}
|