chore: stabilize lint and verify builds
This commit is contained in:
@@ -1,18 +1,24 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.ipao.vip/atom/container"
|
||||
"go.ipao.vip/atom/opt"
|
||||
)
|
||||
|
||||
func Provide(opts ...opt.Option) error {
|
||||
o := opt.New(opts...)
|
||||
options := opt.New(opts...)
|
||||
var config Config
|
||||
if err := o.UnmarshalConfig(&config); err != nil {
|
||||
return err
|
||||
if err := options.UnmarshalConfig(&config); err != nil {
|
||||
return fmt.Errorf("unmarshal app config: %w", err)
|
||||
}
|
||||
|
||||
return container.Container.Provide(func() (*Config, error) {
|
||||
if err := container.Container.Provide(func() (*Config, error) {
|
||||
return &config, nil
|
||||
}, o.DiOptions()...)
|
||||
}, options.DiOptions()...); err != nil {
|
||||
return fmt.Errorf("provide app config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func DefaultProvider() container.ProviderContainer {
|
||||
|
||||
// swagger:enum AppMode
|
||||
// ENUM(development, release, test)
|
||||
type AppMode string
|
||||
type AppMode string //nolint:revive // keep enum name stable with generated helpers
|
||||
|
||||
type Config struct {
|
||||
Mode AppMode
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"quyun/v2/providers/grpc"
|
||||
"quyun/v2/providers/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
"github.com/soheilhy/cmux"
|
||||
"go.ipao.vip/atom/container"
|
||||
"go.ipao.vip/atom/opt"
|
||||
@@ -35,11 +35,12 @@ func (h *Config) Address() string {
|
||||
if h.Host == nil {
|
||||
return fmt.Sprintf(":%d", h.Port)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%d", *h.Host, h.Port)
|
||||
}
|
||||
|
||||
type CMux struct {
|
||||
Http *http.Service
|
||||
HTTP *http.Service
|
||||
Grpc *grpc.Grpc
|
||||
Mux cmux.CMux
|
||||
Base net.Listener
|
||||
@@ -54,7 +55,7 @@ func (c *CMux) Serve() error {
|
||||
if c.Base != nil && c.Base.Addr() != nil {
|
||||
addr = c.Base.Addr().String()
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"addr": addr,
|
||||
}).Info("cmux starting")
|
||||
|
||||
@@ -70,24 +71,27 @@ func (c *CMux) Serve() error {
|
||||
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() error {
|
||||
log.WithField("addr", addr).Info("grpc serving via cmux")
|
||||
logrus.WithField("addr", addr).Info("grpc serving via cmux")
|
||||
|
||||
err := c.Grpc.ServeWithListener(grpcL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("grpc server exited with error")
|
||||
logrus.WithError(err).Error("grpc server exited with error")
|
||||
} else {
|
||||
log.Info("grpc server exited")
|
||||
logrus.Info("grpc server exited")
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
log.WithField("addr", addr).Info("http serving via cmux")
|
||||
err := c.Http.Listener(httpL)
|
||||
logrus.WithField("addr", addr).Info("http serving via cmux")
|
||||
err := c.HTTP.Listener(httpL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("http server exited with error")
|
||||
logrus.WithError(err).Error("http server exited with error")
|
||||
} else {
|
||||
log.Info("http server exited")
|
||||
logrus.Info("http server exited")
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
@@ -95,15 +99,17 @@ func (c *CMux) Serve() error {
|
||||
eg.Go(func() error {
|
||||
err := c.Mux.Serve()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("cmux exited with error")
|
||||
logrus.WithError(err).Error("cmux exited with error")
|
||||
} else {
|
||||
log.Info("cmux exited")
|
||||
logrus.Info("cmux exited")
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
err := eg.Wait()
|
||||
if err == nil {
|
||||
log.Info("cmux and sub-servers exited cleanly")
|
||||
logrus.Info("cmux and sub-servers exited cleanly")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -17,20 +17,21 @@ func Provide(opts ...opt.Option) error {
|
||||
if err := o.UnmarshalConfig(&config); err != nil {
|
||||
return err
|
||||
}
|
||||
return container.Container.Provide(func(http *http.Service, grpc *grpc.Grpc) (*CMux, error) {
|
||||
l, err := net.Listen("tcp", config.Address())
|
||||
|
||||
return container.Container.Provide(func(httpSvc *http.Service, grpcSvc *grpc.Grpc) (*CMux, error) {
|
||||
listener, err := net.Listen("tcp", config.Address())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mux := &CMux{
|
||||
Http: http,
|
||||
Grpc: grpc,
|
||||
Mux: cmux.New(l),
|
||||
Base: l,
|
||||
HTTP: httpSvc,
|
||||
Grpc: grpcSvc,
|
||||
Mux: cmux.New(listener),
|
||||
Base: listener,
|
||||
}
|
||||
// Ensure cmux stops accepting new connections on shutdown
|
||||
container.AddCloseAble(func() { _ = l.Close() })
|
||||
container.AddCloseAble(func() { _ = listener.Close() })
|
||||
|
||||
return mux, nil
|
||||
}, o.DiOptions()...)
|
||||
|
||||
@@ -6,7 +6,7 @@ const (
|
||||
Go contracts.Channel = "go"
|
||||
Kafka contracts.Channel = "kafka"
|
||||
Redis contracts.Channel = "redis"
|
||||
Sql contracts.Channel = "sql"
|
||||
SQL contracts.Channel = "sql"
|
||||
)
|
||||
|
||||
type DefaultPublishTo struct{}
|
||||
|
||||
@@ -2,6 +2,7 @@ package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill"
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
@@ -22,12 +23,12 @@ func DefaultProvider() container.ProviderContainer {
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Sql *ConfigSql
|
||||
SQL *ConfigSQL
|
||||
Kafka *ConfigKafka
|
||||
Redis *ConfigRedis
|
||||
}
|
||||
|
||||
type ConfigSql struct {
|
||||
type ConfigSQL struct {
|
||||
ConsumerGroup string
|
||||
}
|
||||
|
||||
@@ -50,24 +51,26 @@ type PubSub struct {
|
||||
|
||||
func (ps *PubSub) Serve(ctx context.Context) error {
|
||||
if err := ps.Router.Run(ctx); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("run event router: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// publish
|
||||
func (ps *PubSub) Publish(e contracts.EventPublisher) error {
|
||||
if e == nil {
|
||||
func (ps *PubSub) Publish(event contracts.EventPublisher) error {
|
||||
if event == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload, err := e.Marshal()
|
||||
payload, err := event.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("marshal event payload: %w", err)
|
||||
}
|
||||
|
||||
msg := message.NewMessage(watermill.NewUUID(), payload)
|
||||
return ps.getPublisher(e.Channel()).Publish(e.Topic(), msg)
|
||||
|
||||
return ps.getPublisher(event.Channel()).Publish(event.Topic(), msg)
|
||||
}
|
||||
|
||||
// getPublisher returns the publisher for the specified channel.
|
||||
@@ -75,6 +78,7 @@ func (ps *PubSub) getPublisher(channel contracts.Channel) message.Publisher {
|
||||
if pub, ok := ps.publishers[channel]; ok {
|
||||
return pub
|
||||
}
|
||||
|
||||
return ps.publishers[Go]
|
||||
}
|
||||
|
||||
@@ -82,6 +86,7 @@ func (ps *PubSub) getSubscriber(channel contracts.Channel) message.Subscriber {
|
||||
if sub, ok := ps.subscribers[channel]; ok {
|
||||
return sub
|
||||
}
|
||||
|
||||
return ps.subscribers[Go]
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package event
|
||||
|
||||
import (
|
||||
"github.com/ThreeDotsLabs/watermill"
|
||||
"github.com/sirupsen/logrus"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// LogrusLoggerAdapter is a watermill logger adapter for logrus.
|
||||
|
||||
@@ -78,7 +78,7 @@ func ProvideChannel(opts ...opt.Option) error {
|
||||
publishers[Redis] = redisPublisher
|
||||
}
|
||||
|
||||
if config.Sql == nil {
|
||||
if config.SQL == nil {
|
||||
var db *sqlDB.DB
|
||||
sqlPublisher, err := sql.NewPublisher(db, sql.PublisherConfig{
|
||||
SchemaAdapter: sql.DefaultPostgreSQLSchema{},
|
||||
@@ -87,16 +87,16 @@ func ProvideChannel(opts ...opt.Option) error {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
publishers[Sql] = sqlPublisher
|
||||
publishers[SQL] = sqlPublisher
|
||||
|
||||
sqlSubscriber, err := sql.NewSubscriber(db, sql.SubscriberConfig{
|
||||
SchemaAdapter: sql.DefaultPostgreSQLSchema{},
|
||||
ConsumerGroup: config.Sql.ConsumerGroup,
|
||||
ConsumerGroup: config.SQL.ConsumerGroup,
|
||||
}, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscribers[Sql] = sqlSubscriber
|
||||
subscribers[SQL] = sqlSubscriber
|
||||
}
|
||||
|
||||
router, err := message.NewRouter(message.RouterConfig{}, logger)
|
||||
|
||||
@@ -36,15 +36,16 @@ type Config struct {
|
||||
ShutdownTimeoutSeconds uint
|
||||
}
|
||||
|
||||
func (h *Config) Address() string {
|
||||
if h.Port == 0 {
|
||||
h.Port = 8081
|
||||
func (cfg *Config) Address() string {
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 8081
|
||||
}
|
||||
|
||||
if h.Host == nil {
|
||||
return fmt.Sprintf(":%d", h.Port)
|
||||
if cfg.Host == nil {
|
||||
return fmt.Sprintf(":%d", cfg.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", *h.Host, h.Port)
|
||||
|
||||
return fmt.Sprintf("%s:%d", *cfg.Host, cfg.Port)
|
||||
}
|
||||
|
||||
type Grpc struct {
|
||||
@@ -56,45 +57,45 @@ type Grpc struct {
|
||||
streamInterceptors []grpc.StreamServerInterceptor
|
||||
}
|
||||
|
||||
func (g *Grpc) Init() error {
|
||||
func (grpcServer *Grpc) Init() error {
|
||||
// merge options and build interceptor chains if provided
|
||||
var srvOpts []grpc.ServerOption
|
||||
if len(g.unaryInterceptors) > 0 {
|
||||
srvOpts = append(srvOpts, grpc.ChainUnaryInterceptor(g.unaryInterceptors...))
|
||||
if len(grpcServer.unaryInterceptors) > 0 {
|
||||
srvOpts = append(srvOpts, grpc.ChainUnaryInterceptor(grpcServer.unaryInterceptors...))
|
||||
}
|
||||
if len(g.streamInterceptors) > 0 {
|
||||
srvOpts = append(srvOpts, grpc.ChainStreamInterceptor(g.streamInterceptors...))
|
||||
if len(grpcServer.streamInterceptors) > 0 {
|
||||
srvOpts = append(srvOpts, grpc.ChainStreamInterceptor(grpcServer.streamInterceptors...))
|
||||
}
|
||||
srvOpts = append(srvOpts, g.options...)
|
||||
srvOpts = append(srvOpts, grpcServer.options...)
|
||||
|
||||
g.Server = grpc.NewServer(srvOpts...)
|
||||
grpcServer.Server = grpc.NewServer(srvOpts...)
|
||||
|
||||
// optional reflection and health
|
||||
if g.config.EnableReflection != nil && *g.config.EnableReflection {
|
||||
reflection.Register(g.Server)
|
||||
if grpcServer.config.EnableReflection != nil && *grpcServer.config.EnableReflection {
|
||||
reflection.Register(grpcServer.Server)
|
||||
}
|
||||
if g.config.EnableHealth != nil && *g.config.EnableHealth {
|
||||
if grpcServer.config.EnableHealth != nil && *grpcServer.config.EnableHealth {
|
||||
hs := health.NewServer()
|
||||
grpc_health_v1.RegisterHealthServer(g.Server, hs)
|
||||
grpc_health_v1.RegisterHealthServer(grpcServer.Server, hs)
|
||||
}
|
||||
|
||||
// graceful stop with timeout fallback to Stop()
|
||||
container.AddCloseAble(func() {
|
||||
timeout := g.config.ShutdownTimeoutSeconds
|
||||
timeout := grpcServer.config.ShutdownTimeoutSeconds
|
||||
if timeout == 0 {
|
||||
timeout = 10
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
g.Server.GracefulStop()
|
||||
grpcServer.Server.GracefulStop()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
// graceful stop finished
|
||||
case <-time.After(time.Duration(timeout) * time.Second):
|
||||
case <-time.After(time.Duration(int64(timeout)) * time.Second):
|
||||
// timeout, force stop
|
||||
g.Server.Stop()
|
||||
grpcServer.Server.Stop()
|
||||
}
|
||||
})
|
||||
|
||||
@@ -102,44 +103,52 @@ func (g *Grpc) Init() error {
|
||||
}
|
||||
|
||||
// Serve
|
||||
func (g *Grpc) Serve() error {
|
||||
if g.Server == nil {
|
||||
if err := g.Init(); err != nil {
|
||||
func (grpcServer *Grpc) Serve() error {
|
||||
if grpcServer.Server == nil {
|
||||
if err := grpcServer.Init(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", g.config.Address())
|
||||
listener, err := net.Listen("tcp", grpcServer.config.Address())
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("listen grpc address: %w", err)
|
||||
}
|
||||
|
||||
return g.Server.Serve(l)
|
||||
if err := grpcServer.Server.Serve(listener); err != nil {
|
||||
return fmt.Errorf("serve grpc listener: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Grpc) ServeWithListener(ln net.Listener) error {
|
||||
return g.Server.Serve(ln)
|
||||
func (grpcServer *Grpc) ServeWithListener(listener net.Listener) error {
|
||||
if err := grpcServer.Server.Serve(listener); err != nil {
|
||||
return fmt.Errorf("serve grpc with listener: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseOptions appends gRPC ServerOptions to be applied when constructing the server.
|
||||
func (g *Grpc) UseOptions(opts ...grpc.ServerOption) {
|
||||
g.options = append(g.options, opts...)
|
||||
func (grpcServer *Grpc) UseOptions(opts ...grpc.ServerOption) {
|
||||
grpcServer.options = append(grpcServer.options, opts...)
|
||||
}
|
||||
|
||||
// UseUnaryInterceptors appends unary interceptors to be chained.
|
||||
func (g *Grpc) UseUnaryInterceptors(inters ...grpc.UnaryServerInterceptor) {
|
||||
g.unaryInterceptors = append(g.unaryInterceptors, inters...)
|
||||
func (grpcServer *Grpc) UseUnaryInterceptors(inters ...grpc.UnaryServerInterceptor) {
|
||||
grpcServer.unaryInterceptors = append(grpcServer.unaryInterceptors, inters...)
|
||||
}
|
||||
|
||||
// UseStreamInterceptors appends stream interceptors to be chained.
|
||||
func (g *Grpc) UseStreamInterceptors(inters ...grpc.StreamServerInterceptor) {
|
||||
g.streamInterceptors = append(g.streamInterceptors, inters...)
|
||||
func (grpcServer *Grpc) UseStreamInterceptors(inters ...grpc.StreamServerInterceptor) {
|
||||
grpcServer.streamInterceptors = append(grpcServer.streamInterceptors, inters...)
|
||||
}
|
||||
|
||||
// Reset clears all configured options and interceptors.
|
||||
// Useful in tests to ensure isolation.
|
||||
func (g *Grpc) Reset() {
|
||||
g.options = nil
|
||||
g.unaryInterceptors = nil
|
||||
g.streamInterceptors = nil
|
||||
func (grpcServer *Grpc) Reset() {
|
||||
grpcServer.options = nil
|
||||
grpcServer.unaryInterceptors = nil
|
||||
grpcServer.streamInterceptors = nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.ipao.vip/atom/container"
|
||||
"go.ipao.vip/atom/opt"
|
||||
)
|
||||
|
||||
func Provide(opts ...opt.Option) error {
|
||||
o := opt.New(opts...)
|
||||
options := opt.New(opts...)
|
||||
var config Config
|
||||
if err := o.UnmarshalConfig(&config); err != nil {
|
||||
return err
|
||||
if err := options.UnmarshalConfig(&config); err != nil {
|
||||
return fmt.Errorf("unmarshal grpc config: %w", err)
|
||||
}
|
||||
|
||||
return container.Container.Provide(func() (*Grpc, error) {
|
||||
if err := container.Container.Provide(func() (*Grpc, error) {
|
||||
return &Grpc{config: &config}, nil
|
||||
}, o.DiOptions()...)
|
||||
}, options.DiOptions()...); err != nil {
|
||||
return fmt.Errorf("provide grpc: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,12 +13,12 @@ type Config struct {
|
||||
StaticPath *string
|
||||
StaticRoute *string
|
||||
BaseURI *string
|
||||
Tls *Tls
|
||||
TLS *TLS
|
||||
Cors *Cors
|
||||
RateLimit *RateLimit
|
||||
}
|
||||
|
||||
type Tls struct {
|
||||
type TLS struct {
|
||||
Cert string
|
||||
Key string
|
||||
}
|
||||
@@ -57,5 +57,6 @@ func (h *Config) Address() string {
|
||||
if h.Host == "" {
|
||||
return fmt.Sprintf("0.0.0.0:%d", h.Port)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%d", h.Host, h.Port)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
"go.ipao.vip/atom/container"
|
||||
"go.ipao.vip/atom/opt"
|
||||
|
||||
@@ -40,22 +40,25 @@ type Service struct {
|
||||
Engine *fiber.App
|
||||
}
|
||||
|
||||
var errTLSCertKeyRequired = errors.New("tls cert and key must be set")
|
||||
|
||||
func (svc *Service) listenerConfig() fiber.ListenConfig {
|
||||
listenConfig := fiber.ListenConfig{
|
||||
EnablePrintRoutes: true,
|
||||
// DisableStartupMessage: true,
|
||||
}
|
||||
|
||||
if svc.conf.Tls != nil {
|
||||
if svc.conf.Tls.Cert == "" || svc.conf.Tls.Key == "" {
|
||||
panic(errors.New("tls cert and key must be set"))
|
||||
if svc.conf.TLS != nil {
|
||||
if svc.conf.TLS.Cert == "" || svc.conf.TLS.Key == "" {
|
||||
panic(errTLSCertKeyRequired)
|
||||
}
|
||||
listenConfig.CertFile = svc.conf.Tls.Cert
|
||||
listenConfig.CertKeyFile = svc.conf.Tls.Key
|
||||
listenConfig.CertFile = svc.conf.TLS.Cert
|
||||
listenConfig.CertKeyFile = svc.conf.TLS.Key
|
||||
}
|
||||
container.AddCloseAble(func() {
|
||||
svc.Engine.ShutdownWithTimeout(time.Second * 10)
|
||||
})
|
||||
|
||||
return listenConfig
|
||||
}
|
||||
|
||||
@@ -64,8 +67,6 @@ func (svc *Service) Listener(ln net.Listener) error {
|
||||
}
|
||||
|
||||
func (svc *Service) Serve(ctx context.Context) error {
|
||||
// log.WithField("http_address", svc.conf.Address()).Info("http config address")
|
||||
|
||||
ln, err := net.Listen("tcp4", svc.conf.Address())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -109,15 +110,13 @@ func Provide(opts ...opt.Option) error {
|
||||
EnableIPValidation: true,
|
||||
})
|
||||
|
||||
// request id first for correlation
|
||||
engine.Use(requestid.New())
|
||||
|
||||
// recover with stack + request id
|
||||
engine.Use(recover.New(recover.Config{
|
||||
EnableStackTrace: true,
|
||||
StackTraceHandler: func(c fiber.Ctx, e any) {
|
||||
rid := c.Get(fiber.HeaderXRequestID)
|
||||
log.WithField("request_id", rid).Error(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack()))
|
||||
logrus.WithField("request_id", rid).Error(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack()))
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -133,9 +132,7 @@ func Provide(opts ...opt.Option) error {
|
||||
}
|
||||
}
|
||||
|
||||
// logging with request id and latency
|
||||
engine.Use(logger.New(logger.Config{
|
||||
// requestid middleware stores ctx.Locals("requestid")
|
||||
Format: `${time} [${ip}] ${method} ${status} ${path} ${latency} rid=${locals:requestid} "${ua}"\n`,
|
||||
TimeFormat: time.RFC3339,
|
||||
TimeZone: "Asia/Shanghai",
|
||||
@@ -143,9 +140,9 @@ func Provide(opts ...opt.Option) error {
|
||||
|
||||
// rate limit (by tenant code or IP)
|
||||
if config.RateLimit != nil && config.RateLimit.Enabled {
|
||||
max := config.RateLimit.Max
|
||||
if max <= 0 {
|
||||
max = 120
|
||||
limitMax := config.RateLimit.Max
|
||||
if limitMax <= 0 {
|
||||
limitMax = 120
|
||||
}
|
||||
windowSeconds := config.RateLimit.WindowSeconds
|
||||
if windowSeconds <= 0 {
|
||||
@@ -165,7 +162,7 @@ func Provide(opts ...opt.Option) error {
|
||||
|
||||
skipPrefixes := append([]string{"/healthz", "/readyz"}, config.RateLimit.SkipPaths...)
|
||||
engine.Use(limiter.New(limiter.Config{
|
||||
Max: max,
|
||||
Max: limitMax,
|
||||
Expiration: time.Duration(windowSeconds) * time.Second,
|
||||
Storage: limiterStorage,
|
||||
LimitReached: func(c fiber.Ctx) error {
|
||||
@@ -173,10 +170,11 @@ func Provide(opts ...opt.Option) error {
|
||||
if message != "" {
|
||||
appErr = appErr.WithMsg(message)
|
||||
}
|
||||
|
||||
return errorx.SendError(c, appErr)
|
||||
},
|
||||
Next: func(c fiber.Ctx) bool {
|
||||
path := c.Path()
|
||||
Next: func(requestCtx fiber.Ctx) bool {
|
||||
path := requestCtx.Path()
|
||||
for _, prefix := range skipPrefixes {
|
||||
if prefix == "" {
|
||||
continue
|
||||
@@ -185,31 +183,30 @@ func Provide(opts ...opt.Option) error {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
},
|
||||
KeyGenerator: func(c fiber.Ctx) string {
|
||||
if strings.HasPrefix(c.Path(), "/t/") {
|
||||
if tenantCode := strings.TrimSpace(c.Params("tenantCode")); tenantCode != "" {
|
||||
KeyGenerator: func(requestCtx fiber.Ctx) string {
|
||||
if strings.HasPrefix(requestCtx.Path(), "/t/") {
|
||||
if tenantCode := strings.TrimSpace(requestCtx.Params("tenantCode")); tenantCode != "" {
|
||||
return "tenant:" + tenantCode
|
||||
}
|
||||
}
|
||||
return c.IP()
|
||||
|
||||
return requestCtx.IP()
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// static files (Fiber v3 Static helper moved; enable via filesystem middleware later)
|
||||
// if config.StaticRoute != nil && config.StaticPath != nil { ... }
|
||||
|
||||
// health endpoints
|
||||
engine.Get("/healthz", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusNoContent) })
|
||||
engine.Get("/readyz", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusNoContent) })
|
||||
|
||||
engine.Hooks().OnPostShutdown(func(err error) error {
|
||||
if err != nil {
|
||||
log.Error("http server shutdown error: ", err)
|
||||
logrus.Error("http server shutdown error: ", err)
|
||||
}
|
||||
log.Info("http server has shutdown success")
|
||||
logrus.Info("http server has shutdown success")
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -235,20 +232,20 @@ func buildCORSConfig(c *Cors) *cors.Config {
|
||||
exposes []string
|
||||
allowCreds bool
|
||||
)
|
||||
for _, w := range c.Whitelist {
|
||||
if w.AllowOrigin != "" {
|
||||
origins = append(origins, w.AllowOrigin)
|
||||
for _, whitelistItem := range c.Whitelist {
|
||||
if whitelistItem.AllowOrigin != "" {
|
||||
origins = append(origins, whitelistItem.AllowOrigin)
|
||||
}
|
||||
if w.AllowHeaders != "" {
|
||||
headers = append(headers, w.AllowHeaders)
|
||||
if whitelistItem.AllowHeaders != "" {
|
||||
headers = append(headers, whitelistItem.AllowHeaders)
|
||||
}
|
||||
if w.AllowMethods != "" {
|
||||
methods = append(methods, w.AllowMethods)
|
||||
if whitelistItem.AllowMethods != "" {
|
||||
methods = append(methods, whitelistItem.AllowMethods)
|
||||
}
|
||||
if w.ExposeHeaders != "" {
|
||||
exposes = append(exposes, w.ExposeHeaders)
|
||||
if whitelistItem.ExposeHeaders != "" {
|
||||
exposes = append(exposes, whitelistItem.ExposeHeaders)
|
||||
}
|
||||
allowCreds = allowCreds || w.AllowCredentials
|
||||
allowCreds = allowCreds || whitelistItem.AllowCredentials
|
||||
}
|
||||
|
||||
cfg := cors.Config{
|
||||
@@ -258,5 +255,6 @@ func buildCORSConfig(c *Cors) *cors.Config {
|
||||
ExposeHeaders: lo.Uniq(exposes),
|
||||
AllowCredentials: allowCreds,
|
||||
}
|
||||
|
||||
return &cfg
|
||||
}
|
||||
|
||||
@@ -15,12 +15,17 @@ type redisLimiterStorage struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
var (
|
||||
errRateLimitRedisConfigNil = errors.New("rate limit redis config is nil")
|
||||
errRateLimitRedisAddrsEmpty = errors.New("rate limit redis addrs is empty")
|
||||
)
|
||||
|
||||
func newRedisLimiterStorage(config *RateLimitRedis) (fiber.Storage, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("rate limit redis config is nil")
|
||||
return nil, errRateLimitRedisConfigNil
|
||||
}
|
||||
if len(config.Addrs) == 0 {
|
||||
return nil, errors.New("rate limit redis addrs is empty")
|
||||
return nil, errRateLimitRedisAddrsEmpty
|
||||
}
|
||||
|
||||
client := redis.NewUniversalClient(&redis.UniversalOptions{
|
||||
@@ -34,10 +39,12 @@ func newRedisLimiterStorage(config *RateLimitRedis) (fiber.Storage, error) {
|
||||
defer cancel()
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
_ = client.Close()
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefix := strings.TrimSpace(config.Prefix)
|
||||
|
||||
return &redisLimiterStorage{
|
||||
client: client,
|
||||
prefix: prefix,
|
||||
@@ -52,6 +59,7 @@ func (s *redisLimiterStorage) GetWithContext(ctx context.Context, key string) ([
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
@@ -63,6 +71,7 @@ func (s *redisLimiterStorage) SetWithContext(ctx context.Context, key string, va
|
||||
if s == nil || key == "" || len(val) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Set(ctx, s.key(key), val, exp).Err()
|
||||
}
|
||||
|
||||
@@ -74,6 +83,7 @@ func (s *redisLimiterStorage) DeleteWithContext(ctx context.Context, key string)
|
||||
if s == nil || key == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Del(ctx, s.key(key)).Err()
|
||||
}
|
||||
|
||||
@@ -85,6 +95,7 @@ func (s *redisLimiterStorage) ResetWithContext(ctx context.Context) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.FlushDB(ctx).Err()
|
||||
}
|
||||
|
||||
@@ -96,6 +107,7 @@ func (s *redisLimiterStorage) Close() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
@@ -103,5 +115,6 @@ func (s *redisLimiterStorage) key(raw string) string {
|
||||
if s.prefix == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
return s.prefix + raw
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ type Config struct {
|
||||
|
||||
// Controls the display of operationId in operations list.
|
||||
// default: false
|
||||
DisplayOperationId bool `json:"displayOperationId,omitempty"`
|
||||
DisplayOperationID bool `json:"displayOperationId,omitempty"`
|
||||
|
||||
// The default expansion depth for models (set to -1 completely hide the models).
|
||||
// default: 1
|
||||
@@ -109,7 +109,7 @@ type Config struct {
|
||||
|
||||
// OAuth redirect URL.
|
||||
// default: ""
|
||||
OAuth2RedirectUrl string `json:"oauth2RedirectUrl,omitempty"`
|
||||
OAuth2RedirectURL string `json:"oauth2RedirectUrl,omitempty"`
|
||||
|
||||
// MUST be a function. Function to intercept remote definition, "Try it out", and OAuth 2.0 requests.
|
||||
// Accepts one argument requestInterceptor(request) and must return the modified request, or a Promise that resolves to the modified request.
|
||||
@@ -141,7 +141,7 @@ type Config struct {
|
||||
// For example for locally deployed validators (https://github.com/swagger-api/validator-badge).
|
||||
// Setting it to either none, 127.0.0.1 or localhost will disable validation.
|
||||
// default: ""
|
||||
ValidatorUrl string `json:"validatorUrl,omitempty"`
|
||||
ValidatorURL string `json:"validatorUrl,omitempty"`
|
||||
|
||||
// If set to true, enables passing credentials, as defined in the Fetch standard, in CORS requests that are sent by the browser.
|
||||
// Note that Swagger UI cannot currently set cookies cross-domain (see https://github.com/swagger-api/swagger-js/issues/1163).
|
||||
@@ -174,7 +174,7 @@ type Config struct {
|
||||
// Programmatically set values for an API key or Bearer authorization scheme.
|
||||
// In case of OpenAPI 3.0 Bearer scheme, apiKeyValue must contain just the token itself without the Bearer prefix.
|
||||
// default: ""
|
||||
PreauthorizeApiKey template.JS `json:"-"`
|
||||
PreauthorizeAPIKey template.JS `json:"-"`
|
||||
|
||||
// Applies custom CSS styles.
|
||||
// default: ""
|
||||
@@ -194,6 +194,7 @@ func (fc FilterConfig) Value() interface{} {
|
||||
if fc.Expression != "" {
|
||||
return fc.Expression
|
||||
}
|
||||
|
||||
return fc.Enabled
|
||||
}
|
||||
|
||||
@@ -211,13 +212,14 @@ func (shc SyntaxHighlightConfig) Value() interface{} {
|
||||
if shc.Activate {
|
||||
return shc
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuthConfig struct {
|
||||
// ID of the client sent to the OAuth2 provider.
|
||||
// default: ""
|
||||
ClientId string `json:"clientId,omitempty"`
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
|
||||
// Never use this parameter in your production environment.
|
||||
// It exposes cruicial security information. This feature is intended for dev/test environments only.
|
||||
|
||||
@@ -35,13 +35,13 @@ func New(config ...Config) fiber.Handler {
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
return func(c fiber.Ctx) error {
|
||||
return func(ctx fiber.Ctx) error {
|
||||
// Set prefix
|
||||
once.Do(
|
||||
func() {
|
||||
prefix = strings.ReplaceAll(c.Route().Path, "*", "")
|
||||
prefix = strings.ReplaceAll(ctx.Route().Path, "*", "")
|
||||
|
||||
forwardedPrefix := getForwardedPrefix(c)
|
||||
forwardedPrefix := getForwardedPrefix(ctx)
|
||||
if forwardedPrefix != "" {
|
||||
prefix = forwardedPrefix + prefix
|
||||
}
|
||||
@@ -53,26 +53,28 @@ func New(config ...Config) fiber.Handler {
|
||||
},
|
||||
)
|
||||
|
||||
p := c.Path(utils.CopyString(c.Params("*")))
|
||||
p := ctx.Path(utils.CopyString(ctx.Params("*")))
|
||||
|
||||
switch p {
|
||||
case defaultIndex:
|
||||
c.Type("html")
|
||||
return index.Execute(c, cfg)
|
||||
ctx.Type("html")
|
||||
|
||||
return index.Execute(ctx, cfg)
|
||||
case defaultDocURL:
|
||||
var doc string
|
||||
if doc, err = swag.ReadDoc(cfg.InstanceName); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("read swagger doc: %w", err)
|
||||
}
|
||||
return c.Type("json").SendString(doc)
|
||||
|
||||
return ctx.Type("json").SendString(doc)
|
||||
case "", "/":
|
||||
return c.Redirect().To(path.Join(prefix, defaultIndex))
|
||||
return ctx.Redirect().To(path.Join(prefix, defaultIndex))
|
||||
default:
|
||||
// return fs(c)
|
||||
return static.New("/swagger", static.Config{
|
||||
FS: swaggerFiles.FS,
|
||||
Browse: true,
|
||||
})(c)
|
||||
})(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,8 +95,8 @@ const indexTmpl string = `
|
||||
{{if .PreauthorizeBasic}}
|
||||
ui.preauthorizeBasic({{.PreauthorizeBasic}});
|
||||
{{end}}
|
||||
{{if .PreauthorizeApiKey}}
|
||||
ui.preauthorizeApiKey({{.PreauthorizeApiKey}});
|
||||
{{if .PreauthorizeAPIKey}}
|
||||
ui.preauthorizeApiKey({{.PreauthorizeAPIKey}});
|
||||
{{end}}
|
||||
|
||||
window.ui = ui
|
||||
|
||||
@@ -45,15 +45,16 @@ const (
|
||||
)
|
||||
|
||||
// queueConfig returns a river.QueueConfig map built from QueueWorkers or defaults.
|
||||
func (c *Config) queueConfig() map[string]river.QueueConfig {
|
||||
func (config *Config) queueConfig() map[string]river.QueueConfig {
|
||||
cfg := map[string]river.QueueConfig{}
|
||||
if c == nil || len(c.QueueWorkers) == 0 {
|
||||
if config == nil || len(config.QueueWorkers) == 0 {
|
||||
cfg[QueueHigh] = river.QueueConfig{MaxWorkers: 10}
|
||||
cfg[QueueDefault] = river.QueueConfig{MaxWorkers: 10}
|
||||
cfg[QueueLow] = river.QueueConfig{MaxWorkers: 10}
|
||||
|
||||
return cfg
|
||||
}
|
||||
for name, n := range c.QueueWorkers {
|
||||
for name, n := range config.QueueWorkers {
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
@@ -63,5 +64,6 @@ func (c *Config) queueConfig() map[string]river.QueueConfig {
|
||||
if _, ok := cfg[QueueDefault]; !ok {
|
||||
cfg[QueueDefault] = river.QueueConfig{MaxWorkers: 10}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/riverqueue/river/riverdriver/riverpgxv5"
|
||||
"github.com/riverqueue/river/rivertype"
|
||||
"github.com/samber/lo"
|
||||
log "github.com/sirupsen/logrus"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
"go.ipao.vip/atom/container"
|
||||
"go.ipao.vip/atom/contracts"
|
||||
"go.ipao.vip/atom/opt"
|
||||
@@ -27,6 +27,7 @@ func Provide(opts ...opt.Option) error {
|
||||
if err := o.UnmarshalConfig(&config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return container.Container.Provide(func(ctx context.Context, dbConf *postgres.Config) (*Job, error) {
|
||||
workers := river.NewWorkers()
|
||||
|
||||
@@ -79,7 +80,7 @@ func (q *Job) Close() {
|
||||
}
|
||||
|
||||
if err := q.client.StopAndCancel(q.ctx); err != nil {
|
||||
log.Errorf("Failed to stop and cancel client: %s", err)
|
||||
logrus.Errorf("Failed to stop and cancel client: %s", err)
|
||||
}
|
||||
// clear references
|
||||
q.l.Lock()
|
||||
@@ -87,22 +88,22 @@ func (q *Job) Close() {
|
||||
q.l.Unlock()
|
||||
}
|
||||
|
||||
func (q *Job) Client() (*river.Client[pgx.Tx], error) {
|
||||
q.l.Lock()
|
||||
defer q.l.Unlock()
|
||||
func (jobProvider *Job) Client() (*river.Client[pgx.Tx], error) {
|
||||
jobProvider.l.Lock()
|
||||
defer jobProvider.l.Unlock()
|
||||
|
||||
if q.client == nil {
|
||||
if jobProvider.client == nil {
|
||||
var err error
|
||||
q.client, err = river.NewClient(q.driver, &river.Config{
|
||||
Workers: q.Workers,
|
||||
Queues: q.conf.queueConfig(),
|
||||
jobProvider.client, err = river.NewClient(jobProvider.driver, &river.Config{
|
||||
Workers: jobProvider.Workers,
|
||||
Queues: jobProvider.conf.queueConfig(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return q.client, nil
|
||||
return jobProvider.client, nil
|
||||
}
|
||||
|
||||
func (q *Job) Start(ctx context.Context) error {
|
||||
@@ -136,6 +137,7 @@ func (q *Job) AddPeriodicJobs(job contracts.CronJob) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -160,19 +162,20 @@ func (q *Job) AddPeriodicJob(job contracts.CronJobArg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Job) Cancel(id string) error {
|
||||
client, err := q.Client()
|
||||
func (jobProvider *Job) Cancel(id string) error {
|
||||
client, err := jobProvider.Client()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.l.Lock()
|
||||
defer q.l.Unlock()
|
||||
jobProvider.l.Lock()
|
||||
defer jobProvider.l.Unlock()
|
||||
|
||||
if h, ok := q.periodicJobs[id]; ok {
|
||||
client.PeriodicJobs().Remove(h)
|
||||
delete(q.periodicJobs, id)
|
||||
if handle, ok := jobProvider.periodicJobs[id]; ok {
|
||||
client.PeriodicJobs().Remove(handle)
|
||||
delete(jobProvider.periodicJobs, id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -187,21 +190,23 @@ func (q *Job) CancelContext(ctx context.Context, id string) error {
|
||||
if h, ok := q.periodicJobs[id]; ok {
|
||||
client.PeriodicJobs().Remove(h)
|
||||
delete(q.periodicJobs, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Job) Add(job contracts.JobArgs) error {
|
||||
client, err := q.Client()
|
||||
func (jobProvider *Job) Add(job contracts.JobArgs) error {
|
||||
client, err := jobProvider.Client()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.l.Lock()
|
||||
defer q.l.Unlock()
|
||||
jobProvider.l.Lock()
|
||||
defer jobProvider.l.Unlock()
|
||||
|
||||
_, err = client.Insert(jobProvider.ctx, job, lo.ToPtr(job.InsertOpts()))
|
||||
|
||||
_, err = client.Insert(q.ctx, job, lo.ToPtr(job.InsertOpts()))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package jwt
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
|
||||
"go.ipao.vip/atom/container"
|
||||
"go.ipao.vip/atom/opt"
|
||||
@@ -29,7 +29,8 @@ type Config struct {
|
||||
func (c *Config) ExpiresTimeDuration() time.Duration {
|
||||
d, err := time.ParseDuration(c.ExpiresTime)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -14,9 +15,11 @@ import (
|
||||
|
||||
const (
|
||||
CtxKey = "claims"
|
||||
HttpHeader = "Authorization"
|
||||
HTTPHeader = "Authorization"
|
||||
)
|
||||
|
||||
var ErrTokenInvalidType = errors.New("token cache returned non-string value")
|
||||
|
||||
type BaseClaims struct {
|
||||
OpenID string `json:"open_id,omitempty"`
|
||||
Tenant string `json:"tenant,omitempty"`
|
||||
@@ -39,80 +42,104 @@ type JWT struct {
|
||||
}
|
||||
|
||||
var (
|
||||
TokenExpired = errors.New("Token is expired")
|
||||
TokenNotValidYet = errors.New("Token not active yet")
|
||||
TokenMalformed = errors.New("That's not even a token")
|
||||
TokenInvalid = errors.New("Couldn't handle this token:")
|
||||
ErrTokenExpired = errors.New("Token is expired")
|
||||
ErrTokenNotValidYet = errors.New("Token not active yet")
|
||||
ErrTokenMalformed = errors.New("That's not even a token")
|
||||
ErrTokenInvalid = errors.New("Couldn't handle this token")
|
||||
)
|
||||
|
||||
func Provide(opts ...opt.Option) error {
|
||||
o := opt.New(opts...)
|
||||
options := opt.New(opts...)
|
||||
var config Config
|
||||
if err := o.UnmarshalConfig(&config); err != nil {
|
||||
return err
|
||||
if err := options.UnmarshalConfig(&config); err != nil {
|
||||
return fmt.Errorf("unmarshal jwt config: %w", err)
|
||||
}
|
||||
return container.Container.Provide(func() (*JWT, error) {
|
||||
if err := container.Container.Provide(func() (*JWT, error) {
|
||||
return &JWT{
|
||||
singleflight: &singleflight.Group{},
|
||||
config: &config,
|
||||
SigningKey: []byte(config.SigningKey),
|
||||
}, nil
|
||||
}, o.DiOptions()...)
|
||||
}, options.DiOptions()...); err != nil {
|
||||
return fmt.Errorf("provide jwt: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JWT) CreateClaims(baseClaims BaseClaims) *Claims {
|
||||
ep, _ := time.ParseDuration(j.config.ExpiresTime)
|
||||
func (jwtProvider *JWT) CreateClaims(baseClaims BaseClaims) *Claims {
|
||||
expiresDuration, _ := time.ParseDuration(jwtProvider.config.ExpiresTime)
|
||||
|
||||
claims := Claims{
|
||||
BaseClaims: baseClaims,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Second * 10)), // 签名生效时间
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ep)), // 过期时间 7天 配置文件
|
||||
Issuer: j.config.Issuer, // 签名的发行者
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiresDuration)), // 过期时间 7天 配置文件
|
||||
Issuer: jwtProvider.config.Issuer, // 签名的发行者
|
||||
},
|
||||
}
|
||||
|
||||
return &claims
|
||||
}
|
||||
|
||||
// 创建一个token
|
||||
func (j *JWT) CreateToken(claims *Claims) (string, error) {
|
||||
func (jwtProvider *JWT) CreateToken(claims *Claims) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(j.SigningKey)
|
||||
|
||||
return token.SignedString(jwtProvider.SigningKey)
|
||||
}
|
||||
|
||||
// CreateTokenByOldToken 旧token 换新token 使用归并回源避免并发问题
|
||||
func (j *JWT) CreateTokenByOldToken(oldToken string, claims *Claims) (string, error) {
|
||||
v, err, _ := j.singleflight.Do("JWT:"+oldToken, func() (interface{}, error) {
|
||||
return j.CreateToken(claims)
|
||||
func (jwtProvider *JWT) CreateTokenByOldToken(oldToken string, claims *Claims) (string, error) {
|
||||
value, err, _ := jwtProvider.singleflight.Do("JWT:"+oldToken, func() (interface{}, error) {
|
||||
return jwtProvider.CreateToken(claims)
|
||||
})
|
||||
return v.(string), err
|
||||
|
||||
tokenString, ok := value.(string)
|
||||
if !ok {
|
||||
return "", ErrTokenInvalidType
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create token by old token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// 解析 token
|
||||
func (j *JWT) Parse(tokenString string) (*Claims, error) {
|
||||
func (jwtProvider *JWT) Parse(tokenString string) (*Claims, error) {
|
||||
tokenString = strings.TrimPrefix(tokenString, TokenPrefix)
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (i interface{}, e error) {
|
||||
return j.SigningKey, nil
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(_ *jwt.Token) (interface{}, error) {
|
||||
return jwtProvider.SigningKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||
return nil, TokenMalformed
|
||||
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
|
||||
// Token is expired
|
||||
return nil, TokenExpired
|
||||
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
||||
return nil, TokenNotValidYet
|
||||
} else {
|
||||
return nil, TokenInvalid
|
||||
var validationErr *jwt.ValidationError
|
||||
if errors.As(err, &validationErr) {
|
||||
if validationErr.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||
return nil, ErrTokenMalformed
|
||||
}
|
||||
|
||||
if validationErr.Errors&jwt.ValidationErrorExpired != 0 {
|
||||
// Token is expired
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
|
||||
if validationErr.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
||||
return nil, ErrTokenNotValidYet
|
||||
}
|
||||
|
||||
return nil, ErrTokenInvalid
|
||||
}
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
return nil, TokenInvalid
|
||||
} else {
|
||||
return nil, TokenInvalid
|
||||
|
||||
return nil, ErrTokenInvalid
|
||||
}
|
||||
|
||||
return nil, ErrTokenInvalid
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -50,15 +51,19 @@ type Config struct {
|
||||
ApplicationName string // application_name
|
||||
}
|
||||
|
||||
func (m Config) GormSlowThreshold() time.Duration {
|
||||
if m.SlowThresholdMs == 0 {
|
||||
func (config Config) GormSlowThreshold() time.Duration {
|
||||
if config.SlowThresholdMs == 0 {
|
||||
return 200 * time.Millisecond // 默认200ms
|
||||
}
|
||||
return time.Duration(m.SlowThresholdMs) * time.Millisecond
|
||||
if config.SlowThresholdMs > math.MaxInt64/uint(time.Millisecond) {
|
||||
return time.Duration(math.MaxInt64)
|
||||
}
|
||||
|
||||
return time.Duration(config.SlowThresholdMs) * time.Millisecond
|
||||
}
|
||||
|
||||
func (m Config) GormLogLevel() logger.LogLevel {
|
||||
switch m.LogLevel {
|
||||
func (config Config) GormLogLevel() logger.LogLevel {
|
||||
switch config.LogLevel {
|
||||
case "silent":
|
||||
return logger.Silent
|
||||
case "error":
|
||||
@@ -72,65 +77,67 @@ func (m Config) GormLogLevel() logger.LogLevel {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Config) checkDefault() {
|
||||
if m.MaxIdleConns == 0 {
|
||||
m.MaxIdleConns = 10
|
||||
func (config *Config) checkDefault() {
|
||||
if config.MaxIdleConns == 0 {
|
||||
config.MaxIdleConns = 10
|
||||
}
|
||||
|
||||
if m.MaxOpenConns == 0 {
|
||||
m.MaxOpenConns = 100
|
||||
if config.MaxOpenConns == 0 {
|
||||
config.MaxOpenConns = 100
|
||||
}
|
||||
|
||||
if m.Username == "" {
|
||||
m.Username = "postgres"
|
||||
if config.Username == "" {
|
||||
config.Username = "postgres"
|
||||
}
|
||||
|
||||
if m.SslMode == "" {
|
||||
m.SslMode = "disable"
|
||||
if config.SslMode == "" {
|
||||
config.SslMode = "disable"
|
||||
}
|
||||
|
||||
if m.TimeZone == "" {
|
||||
m.TimeZone = "Asia/Shanghai"
|
||||
if config.TimeZone == "" {
|
||||
config.TimeZone = "Asia/Shanghai"
|
||||
}
|
||||
|
||||
if m.Port == 0 {
|
||||
m.Port = 5432
|
||||
if config.Port == 0 {
|
||||
config.Port = 5432
|
||||
}
|
||||
|
||||
if m.Schema == "" {
|
||||
m.Schema = "public"
|
||||
if config.Schema == "" {
|
||||
config.Schema = "public"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Config) EmptyDsn() string {
|
||||
func (config *Config) EmptyDsn() string {
|
||||
// 基本 DSN
|
||||
dsnTpl := "host=%s user=%s password=%s port=%d dbname=%s sslmode=%s TimeZone=%s"
|
||||
m.checkDefault()
|
||||
base := fmt.Sprintf(dsnTpl, m.Host, m.Username, m.Password, m.Port, m.Database, m.SslMode, m.TimeZone)
|
||||
config.checkDefault()
|
||||
base := fmt.Sprintf(dsnTpl, config.Host, config.Username, config.Password, config.Port, config.Database, config.SslMode, config.TimeZone)
|
||||
// 附加可选参数
|
||||
extras := ""
|
||||
if m.UseSearchPath && m.Schema != "" {
|
||||
extras += " search_path=" + m.Schema
|
||||
if config.UseSearchPath && config.Schema != "" {
|
||||
extras += " search_path=" + config.Schema
|
||||
}
|
||||
if m.ApplicationName != "" {
|
||||
extras += " application_name=" + strconv.Quote(m.ApplicationName)
|
||||
if config.ApplicationName != "" {
|
||||
extras += " application_name=" + strconv.Quote(config.ApplicationName)
|
||||
}
|
||||
|
||||
return base + extras
|
||||
}
|
||||
|
||||
// DSN connection dsn
|
||||
func (m *Config) DSN() string {
|
||||
func (config *Config) DSN() string {
|
||||
// 基本 DSN
|
||||
dsnTpl := "host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=%s"
|
||||
m.checkDefault()
|
||||
base := fmt.Sprintf(dsnTpl, m.Host, m.Username, m.Password, m.Database, m.Port, m.SslMode, m.TimeZone)
|
||||
config.checkDefault()
|
||||
base := fmt.Sprintf(dsnTpl, config.Host, config.Username, config.Password, config.Database, config.Port, config.SslMode, config.TimeZone)
|
||||
// 附加可选参数
|
||||
extras := ""
|
||||
if m.UseSearchPath && m.Schema != "" {
|
||||
extras += " search_path=" + m.Schema
|
||||
if config.UseSearchPath && config.Schema != "" {
|
||||
extras += " search_path=" + config.Schema
|
||||
}
|
||||
if m.ApplicationName != "" {
|
||||
extras += " application_name=" + strconv.Quote(m.ApplicationName)
|
||||
if config.ApplicationName != "" {
|
||||
extras += " application_name=" + strconv.Quote(config.ApplicationName)
|
||||
}
|
||||
|
||||
return base + extras
|
||||
}
|
||||
|
||||
@@ -3,9 +3,10 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
"go.ipao.vip/atom/container"
|
||||
"go.ipao.vip/atom/opt"
|
||||
"gorm.io/driver/postgres"
|
||||
@@ -70,10 +71,18 @@ func Provide(opts ...opt.Option) error {
|
||||
sqlDB.SetMaxIdleConns(conf.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(conf.MaxOpenConns)
|
||||
if conf.ConnMaxLifetimeSeconds > 0 {
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(conf.ConnMaxLifetimeSeconds) * time.Second)
|
||||
if conf.ConnMaxLifetimeSeconds > math.MaxInt64/uint(time.Second) {
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(math.MaxInt64))
|
||||
} else {
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(conf.ConnMaxLifetimeSeconds) * time.Second)
|
||||
}
|
||||
}
|
||||
if conf.ConnMaxIdleTimeSeconds > 0 {
|
||||
sqlDB.SetConnMaxIdleTime(time.Duration(conf.ConnMaxIdleTimeSeconds) * time.Second)
|
||||
if conf.ConnMaxIdleTimeSeconds > math.MaxInt64/uint(time.Second) {
|
||||
sqlDB.SetConnMaxIdleTime(time.Duration(math.MaxInt64))
|
||||
} else {
|
||||
sqlDB.SetConnMaxIdleTime(time.Duration(conf.ConnMaxIdleTimeSeconds) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// Ping 校验
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
@@ -22,6 +23,20 @@ import (
|
||||
|
||||
const DefaultPrefix = "Storage"
|
||||
|
||||
var (
|
||||
errStorageBucketNotFound = errors.New("storage bucket not found")
|
||||
errStorageBucketCheckFailed = errors.New("storage bucket check failed")
|
||||
errStorageEndpointRequired = errors.New("storage endpoint is required")
|
||||
errStorageAccessKeyRequired = errors.New("storage access key or secret key is required")
|
||||
errStorageBucketRequired = errors.New("storage bucket is required")
|
||||
errStorageInvalidEndpoint = errors.New("storage endpoint is invalid")
|
||||
errStorageUnsupportedMethod = errors.New("unsupported method")
|
||||
errStorageSignedURLUnsupported = errors.New("s3 storage does not use signed local urls")
|
||||
errStorageInvalidExpiry = errors.New("invalid expiry")
|
||||
errStorageExpired = errors.New("expired")
|
||||
errStorageInvalidSignature = errors.New("invalid signature")
|
||||
)
|
||||
|
||||
func DefaultProvider() container.ProviderContainer {
|
||||
return container.ProviderContainer{
|
||||
Provider: Provide,
|
||||
@@ -37,6 +52,7 @@ func Provide(opts ...opt.Option) error {
|
||||
if err := o.UnmarshalConfig(&config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return container.Container.Provide(func() (*Storage, error) {
|
||||
store := &Storage{Config: &config}
|
||||
if store.storageType() == "s3" {
|
||||
@@ -48,13 +64,14 @@ func Provide(opts ...opt.Option) error {
|
||||
// 启动时可选检查 bucket 是否可用,便于尽早暴露配置问题。
|
||||
exists, err := client.BucketExists(context.Background(), store.Config.Bucket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage bucket check failed: %w", err)
|
||||
return nil, fmt.Errorf("%w: %w", errStorageBucketCheckFailed, err)
|
||||
}
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("storage bucket not found: %s", store.Config.Bucket)
|
||||
return nil, fmt.Errorf("%w: %s", errStorageBucketNotFound, store.Config.Bucket)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}, o.DiOptions()...)
|
||||
}
|
||||
@@ -72,23 +89,24 @@ func (s *Storage) Download(ctx context.Context, key, filePath string) error {
|
||||
}
|
||||
srcPath := filepath.Join(localPath, key)
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create download dir: %w", err)
|
||||
}
|
||||
src, err := os.Open(srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("open source file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
dst, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create destination file: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("copy file content: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -97,9 +115,13 @@ func (s *Storage) Download(ctx context.Context, key, filePath string) error {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create download dir: %w", err)
|
||||
}
|
||||
return client.FGetObject(ctx, s.Config.Bucket, key, filePath, minio.GetObjectOptions{})
|
||||
if err := client.FGetObject(ctx, s.Config.Bucket, key, filePath, minio.GetObjectOptions{}); err != nil {
|
||||
return fmt.Errorf("download object: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Storage) Delete(key string) error {
|
||||
@@ -108,14 +130,22 @@ func (s *Storage) Delete(key string) error {
|
||||
if localPath == "" {
|
||||
localPath = "./storage"
|
||||
}
|
||||
path := filepath.Join(localPath, key)
|
||||
return os.Remove(path)
|
||||
filePath := filepath.Join(localPath, key)
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
return fmt.Errorf("remove local object: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
client, err := s.s3ClientForUse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.RemoveObject(context.Background(), s.Config.Bucket, key, minio.RemoveObjectOptions{})
|
||||
if err := client.RemoveObject(context.Background(), s.Config.Bucket, key, minio.RemoveObjectOptions{}); err != nil {
|
||||
return fmt.Errorf("remove s3 object: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Storage) SignURL(method, key string, expires time.Duration) (string, error) {
|
||||
@@ -128,17 +158,19 @@ func (s *Storage) SignURL(method, key string, expires time.Duration) (string, er
|
||||
case "GET":
|
||||
u, err := client.PresignedGetObject(context.Background(), s.Config.Bucket, key, expires, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("presign get object: %w", err)
|
||||
}
|
||||
|
||||
return u.String(), nil
|
||||
case "PUT":
|
||||
u, err := client.PresignedPutObject(context.Background(), s.Config.Bucket, key, expires)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("presign put object: %w", err)
|
||||
}
|
||||
|
||||
return u.String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported method")
|
||||
return "", errStorageUnsupportedMethod
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,13 +178,10 @@ func (s *Storage) SignURL(method, key string, expires time.Duration) (string, er
|
||||
sign := s.signature(method, key, exp)
|
||||
|
||||
baseURL := strings.TrimRight(s.Config.BaseURL, "/")
|
||||
// Ensure BaseURL doesn't end with slash if we add one
|
||||
// Simplified: assume standard /v1/storage prefix in BaseURL or append it
|
||||
// We'll append /<key>
|
||||
|
||||
u, err := url.Parse(baseURL + "/" + key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("parse base url: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
@@ -165,20 +194,21 @@ func (s *Storage) SignURL(method, key string, expires time.Duration) (string, er
|
||||
|
||||
func (s *Storage) Verify(method, key, expStr, sign string) error {
|
||||
if s.storageType() == "s3" {
|
||||
return fmt.Errorf("s3 storage does not use signed local urls")
|
||||
return errStorageSignedURLUnsupported
|
||||
}
|
||||
exp, err := strconv.ParseInt(expStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid expiry")
|
||||
return errStorageInvalidExpiry
|
||||
}
|
||||
if time.Now().Unix() > exp {
|
||||
return fmt.Errorf("expired")
|
||||
return errStorageExpired
|
||||
}
|
||||
|
||||
expected := s.signature(method, key, exp)
|
||||
if !hmac.Equal([]byte(expected), []byte(sign)) {
|
||||
return fmt.Errorf("invalid signature")
|
||||
return errStorageInvalidSignature
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -186,6 +216,7 @@ func (s *Storage) signature(method, key string, exp int64) string {
|
||||
str := fmt.Sprintf("%s\n%s\n%d", method, key, exp)
|
||||
h := hmac.New(sha256.New, []byte(s.Config.Secret))
|
||||
h.Write([]byte(str))
|
||||
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
@@ -197,9 +228,13 @@ func (s *Storage) PutObject(ctx context.Context, key, filePath, contentType stri
|
||||
}
|
||||
dstPath := filepath.Join(localPath, key)
|
||||
if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create object dir: %w", err)
|
||||
}
|
||||
return os.Rename(filePath, dstPath)
|
||||
if err := os.Rename(filePath, dstPath); err != nil {
|
||||
return fmt.Errorf("move object file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := s.s3ClientForUse()
|
||||
@@ -210,14 +245,18 @@ func (s *Storage) PutObject(ctx context.Context, key, filePath, contentType stri
|
||||
if contentType != "" {
|
||||
opts.ContentType = contentType
|
||||
}
|
||||
_, err = client.FPutObject(ctx, s.Config.Bucket, key, filePath, opts)
|
||||
return err
|
||||
if _, err := client.FPutObject(ctx, s.Config.Bucket, key, filePath, opts); err != nil {
|
||||
return fmt.Errorf("upload object: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Storage) Provider() string {
|
||||
if s.storageType() == "s3" {
|
||||
return "s3"
|
||||
}
|
||||
|
||||
return "local"
|
||||
}
|
||||
|
||||
@@ -225,6 +264,7 @@ func (s *Storage) Bucket() string {
|
||||
if s.storageType() == "s3" && s.Config.Bucket != "" {
|
||||
return s.Config.Bucket
|
||||
}
|
||||
|
||||
return "default"
|
||||
}
|
||||
|
||||
@@ -236,6 +276,7 @@ func (s *Storage) storageType() string {
|
||||
if typ == "" {
|
||||
return "local"
|
||||
}
|
||||
|
||||
return typ
|
||||
}
|
||||
|
||||
@@ -244,13 +285,13 @@ func (s *Storage) s3ClientForUse() (*minio.Client, error) {
|
||||
return s.s3Client, nil
|
||||
}
|
||||
if strings.TrimSpace(s.Config.Endpoint) == "" {
|
||||
return nil, fmt.Errorf("storage endpoint is required")
|
||||
return nil, errStorageEndpointRequired
|
||||
}
|
||||
if strings.TrimSpace(s.Config.AccessKey) == "" || strings.TrimSpace(s.Config.SecretKey) == "" {
|
||||
return nil, fmt.Errorf("storage access key or secret key is required")
|
||||
return nil, errStorageAccessKeyRequired
|
||||
}
|
||||
if strings.TrimSpace(s.Config.Bucket) == "" {
|
||||
return nil, fmt.Errorf("storage bucket is required")
|
||||
return nil, errStorageBucketRequired
|
||||
}
|
||||
|
||||
endpoint, secure, err := parseEndpoint(s.Config.Endpoint)
|
||||
@@ -274,6 +315,7 @@ func (s *Storage) s3ClientForUse() (*minio.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
s.s3Client = client
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -281,12 +323,14 @@ func parseEndpoint(endpoint string) (string, bool, error) {
|
||||
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
return "", false, fmt.Errorf("parse endpoint: %w", err)
|
||||
}
|
||||
if u.Host == "" {
|
||||
return "", false, fmt.Errorf("invalid endpoint")
|
||||
return "", false, errStorageInvalidEndpoint
|
||||
}
|
||||
|
||||
return u.Host, u.Scheme == "https", nil
|
||||
}
|
||||
|
||||
return endpoint, false, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user