feat: add backend_v1 migration
Some checks failed
build quyun / Build (push) Has been cancelled

This commit is contained in:
2025-12-19 14:46:58 +08:00
parent 218eb4689c
commit 24bd161df9
119 changed files with 12259 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
package app
import (
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
func Provide(opts ...opt.Option) error {
o := opt.New(opts...)
var config Config
if err := o.UnmarshalConfig(&config); err != nil {
return err
}
return container.Container.Provide(func() (*Config, error) {
return &config, nil
}, o.DiOptions()...)
}

View File

@@ -0,0 +1,179 @@
// Code generated by go-enum DO NOT EDIT.
// Version: -
// Revision: -
// Build Date: -
// Built By: -
package app
import (
"database/sql/driver"
"errors"
"fmt"
"strings"
)
const (
// AppModeDevelopment is a AppMode of type development.
AppModeDevelopment AppMode = "development"
// AppModeRelease is a AppMode of type release.
AppModeRelease AppMode = "release"
// AppModeTest is a AppMode of type test.
AppModeTest AppMode = "test"
)
var ErrInvalidAppMode = fmt.Errorf("not a valid AppMode, try [%s]", strings.Join(_AppModeNames, ", "))
var _AppModeNames = []string{
string(AppModeDevelopment),
string(AppModeRelease),
string(AppModeTest),
}
// AppModeNames returns a list of possible string values of AppMode.
func AppModeNames() []string {
tmp := make([]string, len(_AppModeNames))
copy(tmp, _AppModeNames)
return tmp
}
// AppModeValues returns a list of the values for AppMode
func AppModeValues() []AppMode {
return []AppMode{
AppModeDevelopment,
AppModeRelease,
AppModeTest,
}
}
// String implements the Stringer interface.
func (x AppMode) String() string {
return string(x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x AppMode) IsValid() bool {
_, err := ParseAppMode(string(x))
return err == nil
}
var _AppModeValue = map[string]AppMode{
"development": AppModeDevelopment,
"release": AppModeRelease,
"test": AppModeTest,
}
// ParseAppMode attempts to convert a string to a AppMode.
func ParseAppMode(name string) (AppMode, error) {
if x, ok := _AppModeValue[name]; ok {
return x, nil
}
return AppMode(""), fmt.Errorf("%s is %w", name, ErrInvalidAppMode)
}
var errAppModeNilPtr = errors.New("value pointer is nil") // one per type for package clashes
// Scan implements the Scanner interface.
func (x *AppMode) Scan(value interface{}) (err error) {
if value == nil {
*x = AppMode("")
return
}
// A wider range of scannable types.
// driver.Value values at the top of the list for expediency
switch v := value.(type) {
case string:
*x, err = ParseAppMode(v)
case []byte:
*x, err = ParseAppMode(string(v))
case AppMode:
*x = v
case *AppMode:
if v == nil {
return errAppModeNilPtr
}
*x = *v
case *string:
if v == nil {
return errAppModeNilPtr
}
*x, err = ParseAppMode(*v)
default:
return errors.New("invalid type for AppMode")
}
return
}
// Value implements the driver Valuer interface.
func (x AppMode) Value() (driver.Value, error) {
return x.String(), nil
}
// Set implements the Golang flag.Value interface func.
func (x *AppMode) Set(val string) error {
v, err := ParseAppMode(val)
*x = v
return err
}
// Get implements the Golang flag.Getter interface func.
func (x *AppMode) Get() interface{} {
return *x
}
// Type implements the github.com/spf13/pFlag Value interface.
func (x *AppMode) Type() string {
return "AppMode"
}
type NullAppMode struct {
AppMode AppMode
Valid bool
}
func NewNullAppMode(val interface{}) (x NullAppMode) {
err := x.Scan(val) // yes, we ignore this error, it will just be an invalid value.
_ = err // make any errcheck linters happy
return
}
// Scan implements the Scanner interface.
func (x *NullAppMode) Scan(value interface{}) (err error) {
if value == nil {
x.AppMode, x.Valid = AppMode(""), false
return
}
err = x.AppMode.Scan(value)
x.Valid = (err == nil)
return
}
// Value implements the driver Valuer interface.
func (x NullAppMode) Value() (driver.Value, error) {
if !x.Valid {
return nil, nil
}
// driver.Value accepts int64 for int values.
return string(x.AppMode), nil
}
type NullAppModeStr struct {
NullAppMode
}
func NewNullAppModeStr(val interface{}) (x NullAppModeStr) {
x.Scan(val) // yes, we ignore this error, it will just be an invalid value.
return
}
// Value implements the driver Valuer interface.
func (x NullAppModeStr) Value() (driver.Value, error) {
if !x.Valid {
return nil, nil
}
return x.AppMode.String(), nil
}

View File

@@ -0,0 +1,45 @@
package app
import (
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
const DefaultPrefix = "App"
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: Provide,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
// swagger:enum AppMode
// ENUM(development, release, test)
type AppMode string
type Config struct {
Mode AppMode
Cert *Cert
BaseURI *string
}
func (c *Config) IsDevMode() bool {
return c.Mode == AppModeDevelopment
}
func (c *Config) IsReleaseMode() bool {
return c.Mode == AppModeRelease
}
func (c *Config) IsTestMode() bool {
return c.Mode == AppModeTest
}
type Cert struct {
CA string
Cert string
Key string
}

View File

@@ -0,0 +1,109 @@
package cmux
import (
"fmt"
"net"
"time"
"quyun/v2/providers/grpc"
"quyun/v2/providers/http"
log "github.com/sirupsen/logrus"
"github.com/soheilhy/cmux"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
"golang.org/x/sync/errgroup"
)
const DefaultPrefix = "Cmux"
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: Provide,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
type Config struct {
Host *string
Port uint
}
func (h *Config) Address() string {
if h.Host == nil {
return fmt.Sprintf(":%d", h.Port)
}
return fmt.Sprintf("%s:%d", *h.Host, h.Port)
}
type CMux struct {
Http *http.Service
Grpc *grpc.Grpc
Mux cmux.CMux
Base net.Listener
}
func (c *CMux) Serve() error {
// Protect against slowloris connections when sniffing protocol
// Safe even if SetReadTimeout is a no-op in the cmux version in use
c.Mux.SetReadTimeout(1 * time.Second)
addr := ""
if c.Base != nil && c.Base.Addr() != nil {
addr = c.Base.Addr().String()
}
log.WithFields(log.Fields{
"addr": addr,
}).Info("cmux starting")
// Route classic HTTP/1.x traffic to the HTTP service
httpL := c.Mux.Match(cmux.HTTP1Fast())
// Route gRPC (HTTP/2 with content-type application/grpc) to the gRPC service.
// Additionally, send other HTTP/2 traffic to gRPC since Fiber (HTTP) does not serve HTTP/2.
grpcL := c.Mux.Match(
cmux.HTTP2HeaderField("content-type", "application/grpc"),
cmux.HTTP2(),
)
var eg errgroup.Group
eg.Go(func() error {
log.WithField("addr", addr).Info("grpc serving via cmux")
err := c.Grpc.ServeWithListener(grpcL)
if err != nil {
log.WithError(err).Error("grpc server exited with error")
} else {
log.Info("grpc server exited")
}
return err
})
eg.Go(func() error {
log.WithField("addr", addr).Info("http serving via cmux")
err := c.Http.Listener(httpL)
if err != nil {
log.WithError(err).Error("http server exited with error")
} else {
log.Info("http server exited")
}
return err
})
// Run cmux dispatcher; wait for the first error from any goroutine
eg.Go(func() error {
err := c.Mux.Serve()
if err != nil {
log.WithError(err).Error("cmux exited with error")
} else {
log.Info("cmux exited")
}
return err
})
err := eg.Wait()
if err == nil {
log.Info("cmux and sub-servers exited cleanly")
}
return err
}

View File

@@ -0,0 +1,37 @@
package cmux
import (
"net"
"quyun/v2/providers/grpc"
"quyun/v2/providers/http"
"github.com/soheilhy/cmux"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
func Provide(opts ...opt.Option) error {
o := opt.New(opts...)
var config Config
if err := o.UnmarshalConfig(&config); err != nil {
return err
}
return container.Container.Provide(func(http *http.Service, grpc *grpc.Grpc) (*CMux, error) {
l, err := net.Listen("tcp", config.Address())
if err != nil {
return nil, err
}
mux := &CMux{
Http: http,
Grpc: grpc,
Mux: cmux.New(l),
Base: l,
}
// Ensure cmux stops accepting new connections on shutdown
container.AddCloseAble(func() { _ = l.Close() })
return mux, nil
}, o.DiOptions()...)
}

View File

@@ -0,0 +1,30 @@
package event
import "go.ipao.vip/atom/contracts"
const (
Go contracts.Channel = "go"
Kafka contracts.Channel = "kafka"
Redis contracts.Channel = "redis"
Sql contracts.Channel = "sql"
)
type DefaultPublishTo struct{}
func (d *DefaultPublishTo) PublishTo() (contracts.Channel, string) {
return Go, "event:processed"
}
type DefaultChannel struct{}
func (d *DefaultChannel) Channel() contracts.Channel { return Go }
// kafka
type KafkaChannel struct{}
func (k *KafkaChannel) Channel() contracts.Channel { return Kafka }
// kafka
type RedisChannel struct{}
func (k *RedisChannel) Channel() contracts.Channel { return Redis }

View File

@@ -0,0 +1,99 @@
package event
import (
"context"
"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/contracts"
"go.ipao.vip/atom/opt"
)
const DefaultPrefix = "Events"
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: ProvideChannel,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
type Config struct {
Sql *ConfigSql
Kafka *ConfigKafka
Redis *ConfigRedis
}
type ConfigSql struct {
ConsumerGroup string
}
type ConfigRedis struct {
ConsumerGroup string
Streams []string
}
type ConfigKafka struct {
ConsumerGroup string
Brokers []string
}
type PubSub struct {
Router *message.Router
publishers map[contracts.Channel]message.Publisher
subscribers map[contracts.Channel]message.Subscriber
}
func (ps *PubSub) Serve(ctx context.Context) error {
if err := ps.Router.Run(ctx); err != nil {
return err
}
return nil
}
// publish
func (ps *PubSub) Publish(e contracts.EventPublisher) error {
if e == nil {
return nil
}
payload, err := e.Marshal()
if err != nil {
return err
}
msg := message.NewMessage(watermill.NewUUID(), payload)
return ps.getPublisher(e.Channel()).Publish(e.Topic(), msg)
}
// getPublisher returns the publisher for the specified channel.
func (ps *PubSub) getPublisher(channel contracts.Channel) message.Publisher {
if pub, ok := ps.publishers[channel]; ok {
return pub
}
return ps.publishers[Go]
}
func (ps *PubSub) getSubscriber(channel contracts.Channel) message.Subscriber {
if sub, ok := ps.subscribers[channel]; ok {
return sub
}
return ps.subscribers[Go]
}
func (ps *PubSub) Handle(handlerName string, sub contracts.EventHandler) {
publishToCh, publishToTopic := sub.PublishTo()
ps.Router.AddHandler(
handlerName,
sub.Topic(),
ps.getSubscriber(sub.Channel()),
publishToTopic,
ps.getPublisher(publishToCh),
sub.Handler,
)
}

View File

@@ -0,0 +1,60 @@
package event
import (
"github.com/ThreeDotsLabs/watermill"
"github.com/sirupsen/logrus"
)
// LogrusLoggerAdapter is a watermill logger adapter for logrus.
type LogrusLoggerAdapter struct {
log *logrus.Logger
fields watermill.LogFields
}
// NewLogrusLogger returns a LogrusLoggerAdapter that sends all logs to
// the passed logrus instance.
func LogrusAdapter() watermill.LoggerAdapter {
return &LogrusLoggerAdapter{log: logrus.StandardLogger()}
}
// Error logs on level error with err as field and optional fields.
func (l *LogrusLoggerAdapter) Error(msg string, err error, fields watermill.LogFields) {
l.createEntry(fields.Add(watermill.LogFields{"err": err})).Error(msg)
}
// Info logs on level info with optional fields.
func (l *LogrusLoggerAdapter) Info(msg string, fields watermill.LogFields) {
l.createEntry(fields).Info(msg)
}
// Debug logs on level debug with optional fields.
func (l *LogrusLoggerAdapter) Debug(msg string, fields watermill.LogFields) {
l.createEntry(fields).Debug(msg)
}
// Trace logs on level trace with optional fields.
func (l *LogrusLoggerAdapter) Trace(msg string, fields watermill.LogFields) {
l.createEntry(fields).Trace(msg)
}
// With returns a new LogrusLoggerAdapter that includes fields
// to be re-used between logging statements.
func (l *LogrusLoggerAdapter) With(fields watermill.LogFields) watermill.LoggerAdapter {
return &LogrusLoggerAdapter{
log: l.log,
fields: l.fields.Add(fields),
}
}
// createEntry is a helper to add fields to a logrus entry if necessary.
func (l *LogrusLoggerAdapter) createEntry(fields watermill.LogFields) *logrus.Entry {
entry := logrus.NewEntry(l.log)
allFields := fields.Add(l.fields)
if len(allFields) > 0 {
entry = entry.WithFields(logrus.Fields(allFields))
}
return entry
}

View File

@@ -0,0 +1,109 @@
package event
import (
sqlDB "database/sql"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/contracts"
"go.ipao.vip/atom/opt"
"github.com/ThreeDotsLabs/watermill-kafka/v3/pkg/kafka"
"github.com/ThreeDotsLabs/watermill-redisstream/pkg/redisstream"
"github.com/ThreeDotsLabs/watermill-sql/v3/pkg/sql"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"
"github.com/redis/go-redis/v9"
)
func ProvideChannel(opts ...opt.Option) error {
o := opt.New(opts...)
var config Config
if err := o.UnmarshalConfig(&config); err != nil {
return err
}
return container.Container.Provide(func() (*PubSub, error) {
logger := LogrusAdapter()
publishers := make(map[contracts.Channel]message.Publisher)
subscribers := make(map[contracts.Channel]message.Subscriber)
// gochannel
client := gochannel.NewGoChannel(gochannel.Config{}, logger)
publishers[Go] = client
subscribers[Go] = client
// kafka
if config.Kafka != nil {
kafkaPublisher, err := kafka.NewPublisher(kafka.PublisherConfig{
Brokers: config.Kafka.Brokers,
Marshaler: kafka.DefaultMarshaler{},
}, logger)
if err != nil {
return nil, err
}
publishers[Kafka] = kafkaPublisher
kafkaSubscriber, err := kafka.NewSubscriber(kafka.SubscriberConfig{
Brokers: config.Kafka.Brokers,
Unmarshaler: kafka.DefaultMarshaler{},
ConsumerGroup: config.Kafka.ConsumerGroup,
}, logger)
if err != nil {
return nil, err
}
subscribers[Kafka] = kafkaSubscriber
}
// redis
if config.Redis != nil {
var rdb redis.UniversalClient
redisSubscriber, err := redisstream.NewSubscriber(redisstream.SubscriberConfig{
Client: rdb,
Unmarshaller: redisstream.DefaultMarshallerUnmarshaller{},
ConsumerGroup: config.Redis.ConsumerGroup,
}, logger)
if err != nil {
return nil, err
}
subscribers[Redis] = redisSubscriber
redisPublisher, err := redisstream.NewPublisher(redisstream.PublisherConfig{
Client: rdb,
Marshaller: redisstream.DefaultMarshallerUnmarshaller{},
}, logger)
if err != nil {
return nil, err
}
publishers[Redis] = redisPublisher
}
if config.Sql == nil {
var db *sqlDB.DB
sqlPublisher, err := sql.NewPublisher(db, sql.PublisherConfig{
SchemaAdapter: sql.DefaultPostgreSQLSchema{},
AutoInitializeSchema: false,
}, logger)
if err != nil {
return nil, err
}
publishers[Sql] = sqlPublisher
sqlSubscriber, err := sql.NewSubscriber(db, sql.SubscriberConfig{
SchemaAdapter: sql.DefaultPostgreSQLSchema{},
ConsumerGroup: config.Sql.ConsumerGroup,
}, logger)
if err != nil {
return nil, err
}
subscribers[Sql] = sqlSubscriber
}
router, err := message.NewRouter(message.RouterConfig{}, logger)
if err != nil {
return nil, err
}
return &PubSub{Router: router, publishers: publishers, subscribers: subscribers}, nil
}, o.DiOptions()...)
}

View File

@@ -0,0 +1,145 @@
package grpc
import (
"fmt"
"net"
"time"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
)
const DefaultPrefix = "Grpc"
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: Provide,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
type Config struct {
Host *string
Port uint
// EnableReflection enables grpc/reflection registration when true
EnableReflection *bool
// EnableHealth enables gRPC health service registration when true
EnableHealth *bool
// ShutdownTimeoutSeconds controls graceful stop timeout; 0 uses default
ShutdownTimeoutSeconds uint
}
func (h *Config) Address() string {
if h.Port == 0 {
h.Port = 8081
}
if h.Host == nil {
return fmt.Sprintf(":%d", h.Port)
}
return fmt.Sprintf("%s:%d", *h.Host, h.Port)
}
type Grpc struct {
Server *grpc.Server
config *Config
options []grpc.ServerOption
unaryInterceptors []grpc.UnaryServerInterceptor
streamInterceptors []grpc.StreamServerInterceptor
}
func (g *Grpc) Init() error {
// merge options and build interceptor chains if provided
var srvOpts []grpc.ServerOption
if len(g.unaryInterceptors) > 0 {
srvOpts = append(srvOpts, grpc.ChainUnaryInterceptor(g.unaryInterceptors...))
}
if len(g.streamInterceptors) > 0 {
srvOpts = append(srvOpts, grpc.ChainStreamInterceptor(g.streamInterceptors...))
}
srvOpts = append(srvOpts, g.options...)
g.Server = grpc.NewServer(srvOpts...)
// optional reflection and health
if g.config.EnableReflection != nil && *g.config.EnableReflection {
reflection.Register(g.Server)
}
if g.config.EnableHealth != nil && *g.config.EnableHealth {
hs := health.NewServer()
grpc_health_v1.RegisterHealthServer(g.Server, hs)
}
// graceful stop with timeout fallback to Stop()
container.AddCloseAble(func() {
timeout := g.config.ShutdownTimeoutSeconds
if timeout == 0 {
timeout = 10
}
done := make(chan struct{})
go func() {
g.Server.GracefulStop()
close(done)
}()
select {
case <-done:
// graceful stop finished
case <-time.After(time.Duration(timeout) * time.Second):
// timeout, force stop
g.Server.Stop()
}
})
return nil
}
// Serve
func (g *Grpc) Serve() error {
if g.Server == nil {
if err := g.Init(); err != nil {
return err
}
}
l, err := net.Listen("tcp", g.config.Address())
if err != nil {
return err
}
return g.Server.Serve(l)
}
func (g *Grpc) ServeWithListener(ln net.Listener) error {
return g.Server.Serve(ln)
}
// UseOptions appends gRPC ServerOptions to be applied when constructing the server.
func (g *Grpc) UseOptions(opts ...grpc.ServerOption) {
g.options = append(g.options, opts...)
}
// UseUnaryInterceptors appends unary interceptors to be chained.
func (g *Grpc) UseUnaryInterceptors(inters ...grpc.UnaryServerInterceptor) {
g.unaryInterceptors = append(g.unaryInterceptors, inters...)
}
// UseStreamInterceptors appends stream interceptors to be chained.
func (g *Grpc) UseStreamInterceptors(inters ...grpc.StreamServerInterceptor) {
g.streamInterceptors = append(g.streamInterceptors, inters...)
}
// Reset clears all configured options and interceptors.
// Useful in tests to ensure isolation.
func (g *Grpc) Reset() {
g.options = nil
g.unaryInterceptors = nil
g.streamInterceptors = nil
}

View File

@@ -0,0 +1,513 @@
# gRPC Server Options & Interceptors Examples
本文件给出一些可直接拷贝使用的示例,配合本包提供的注册函数:
- `UseOptions(opts ...grpc.ServerOption)`
- `UseUnaryInterceptors(inters ...grpc.UnaryServerInterceptor)`
- `UseStreamInterceptors(inters ...grpc.StreamServerInterceptor)`
建议在应用启动或 Provider 初始化阶段调用(在 gRPC 服务构造前)。
> 导入建议:
>
> ```go
> import (
> pgrpc "test/providers/grpc" // 本包
> grpc "google.golang.org/grpc" // 避免命名冲突
> )
> ```
## ServerOption 示例
最大消息大小限制:
```go
pgrpc.UseOptions(
grpc.MaxRecvMsgSize(32<<20), // 32 MiB
grpc.MaxSendMsgSize(32<<20), // 32 MiB
)
```
限制最大并发流(对 HTTP/2 流并发施加上限):
```go
pgrpc.UseOptions(
grpc.MaxConcurrentStreams(1024),
)
```
Keepalive 参数(需要 keepalive 包):
```go
import (
"time"
"google.golang.org/grpc/keepalive"
)
pgrpc.UseOptions(
grpc.KeepaliveParams(keepalive.ServerParameters{
MaxConnectionIdle: 5 * time.Minute,
MaxConnectionAge: 30 * time.Minute,
MaxConnectionAgeGrace: 5 * time.Minute,
Time: 2 * time.Minute, // ping 间隔
Timeout: 20 * time.Second,
}),
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: 1 * time.Minute, // 客户端 ping 最小间隔
PermitWithoutStream: true,
}),
)
```
## UnaryServerInterceptor 示例
简单日志拦截器logrus
```go
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
)
func LoggingUnaryInterceptor(
ctx context.Context,
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
start := time.Now()
resp, err := handler(ctx, req)
dur := time.Since(start)
entry := log.WithFields(log.Fields{
"grpc.method": info.FullMethod,
"grpc.duration_ms": dur.Milliseconds(),
})
if err != nil {
entry.WithError(err).Warn("grpc unary request failed")
} else {
entry.Info("grpc unary request finished")
}
return resp, err
}
// 注册
pgrpc.UseUnaryInterceptors(LoggingUnaryInterceptor)
```
恢复拦截器panic 捕获):
```go
import (
"context"
"fmt"
"runtime/debug"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
"google.golang.org/grpc/codes"
)
func RecoveryUnaryInterceptor(
ctx context.Context,
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
defer func() {
if r := recover() ; r != nil {
log.WithField("grpc.method", info.FullMethod).Errorf("panic: %v\n%s", r, debug.Stack())
}
}()
return handler(ctx, req)
}
// 或者向客户端返回内部错误:
func RecoveryUnaryInterceptorWithError(
ctx context.Context,
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
defer func() {
if r := recover() ; r != nil {
log.WithField("grpc.method", info.FullMethod).Errorf("panic: %v\n%s", r, debug.Stack())
}
}()
resp, err := handler(ctx, req)
if rec := recover() ; rec != nil {
return nil, status.Error(codes.Internal, fmt.Sprint(rec))
}
return resp, err
}
pgrpc.UseUnaryInterceptors(RecoveryUnaryInterceptor)
```
链式调用(与其它拦截器共同使用):
```go
pgrpc.UseUnaryInterceptors(LoggingUnaryInterceptor, RecoveryUnaryInterceptor)
```
## StreamServerInterceptor 示例
简单日志拦截器:
```go
import (
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
)
func LoggingStreamInterceptor(
srv any,
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
start := time.Now()
err := handler(srv, ss)
dur := time.Since(start)
entry := log.WithFields(log.Fields{
"grpc.method": info.FullMethod,
"grpc.is_client_stream": info.IsClientStream,
"grpc.is_server_stream": info.IsServerStream,
"grpc.duration_ms": dur.Milliseconds(),
})
if err != nil {
entry.WithError(err).Warn("grpc stream request failed")
} else {
entry.Info("grpc stream request finished")
}
return err
}
pgrpc.UseStreamInterceptors(LoggingStreamInterceptor)
```
恢复拦截器panic 捕获):
```go
import (
"runtime/debug"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
)
func RecoveryStreamInterceptor(
srv any,
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) (err error) {
defer func() {
if r := recover() ; r != nil {
log.WithField("grpc.method", info.FullMethod).Errorf("panic: %v\n%s", r, debug.Stack())
}
}()
return handler(srv, ss)
}
pgrpc.UseStreamInterceptors(RecoveryStreamInterceptor)
```
## 组合与测试小贴士
- 可以多次调用 `UseOptions/UseUnaryInterceptors/UseStreamInterceptors`,最终会在服务构造时链式生效。
- 单元测试中如需隔离,建议使用 `pgrpc.Reset()` 清理已注册的选项和拦截器。
- 若要启用健康检查或反射,请在配置中设置:
- `EnableHealth = true`
- `EnableReflection = true`
## 更多 ServerOption 示例
TLS服务端或 mTLS
```go
import (
"crypto/tls"
grpcCredentials "google.golang.org/grpc/credentials"
)
// 使用自定义 tls.Config可配置 mTLS
var tlsConfig *tls.Config = &tls.Config{ /* ... */ }
pgrpc.UseOptions(
grpc.Creds(grpcCredentials.NewTLS(tlsConfig)),
)
// 或者从证书文件加载(仅服务端 TLS
// pgrpc.UseOptions(grpc.Creds(grpcCredentials.NewServerTLSFromFile(certFile, keyFile)))
```
OpenTelemetry 统计/追踪StatsHandler
```go
import (
otelgrpc "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
)
pgrpc.UseOptions(
grpc.StatsHandler(otelgrpc.NewServerHandler()),
)
```
流控/缓冲区调优:
```go
pgrpc.UseOptions(
grpc.InitialWindowSize(1<<20), // 每个流初始窗口(字节)
grpc.InitialConnWindowSize(1<<21), // 连接级窗口
grpc.ReadBufferSize(64<<10), // 读缓冲 64 KiB
grpc.WriteBufferSize(64<<10), // 写缓冲 64 KiB
)
```
连接超时与 Tap Handle早期拦截
```go
import (
"context"
"time"
"google.golang.org/grpc/tap"
)
pgrpc.UseOptions(
grpc.ConnectionTimeout(5 * time.Second),
grpc.InTapHandle(func(ctx context.Context, info *tap.Info) (context.Context, error) {
// 在真正的 RPC 处理前进行快速拒绝如黑名单、IP 检查等)
return ctx, nil
}),
)
```
未知服务处理与工作池:
```go
pgrpc.UseOptions(
grpc.UnknownServiceHandler(func(srv any, stream grpc.ServerStream) error {
// 统一记录未注册方法,或返回自定义错误
return status.Error(codes.Unimplemented, "unknown service/method")
}),
grpc.NumStreamWorkers(8), // 针对 CPU 密集流处理的工作池
)
```
## 更多 Unary 拦截器示例
基于 Metadata 的鉴权:
```go
import (
"context"
"strings"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/codes"
)
func AuthUnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
md, _ := metadata.FromIncomingContext(ctx)
token := ""
if vals := md.Get("authorization"); len(vals) > 0 {
token = vals[0]
}
if token == "" || !strings.HasPrefix(strings.ToLower(token), "bearer ") {
return nil, status.Error(codes.Unauthenticated, "missing or invalid token")
}
// TODO: 验证 JWT / API-Key
return handler(ctx, req)
}
pgrpc.UseUnaryInterceptors(AuthUnaryInterceptor)
```
方法粒度速率限制x/time/rate
```go
import (
"context"
"sync"
"golang.org/x/time/rate"
"google.golang.org/grpc/status"
"google.golang.org/grpc/codes"
)
var (
rlmu sync.RWMutex
rlm = map[string]*rate.Limiter{}
)
func limitFor(method string) *rate.Limiter {
rlmu.RLock() ; l := rlm[method]; rlmu.RUnlock()
if l != nil { return l }
rlmu.Lock() ; defer rlmu.Unlock()
if rlm[method] == nil { rlm[method] = rate.NewLimiter(100, 200) } // 100 rps, burst 200
return rlm[method]
}
func RateLimitUnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
l := limitFor(info.FullMethod)
if !l.Allow() {
return nil, status.Error(codes.ResourceExhausted, "rate limited")
}
return handler(ctx, req)
}
pgrpc.UseUnaryInterceptors(RateLimitUnaryInterceptor)
```
Request-ID 注入与日志关联:
```go
import (
"context"
"github.com/google/uuid"
"google.golang.org/grpc/metadata"
)
type ctxKey string
const requestIDKey ctxKey = "request_id"
func RequestIDUnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
md, _ := metadata.FromIncomingContext(ctx)
var rid string
if v := md.Get("x-request-id"); len(v) > 0 { rid = v[0] }
if rid == "" { rid = uuid.New().String() }
ctx = context.WithValue(ctx, requestIDKey, rid)
return handler(ctx, req)
}
pgrpc.UseUnaryInterceptors(RequestIDUnaryInterceptor)
```
无超时/超长请求治理(默认超时/拒绝超长):
```go
import (
"context"
"time"
"google.golang.org/grpc/status"
"google.golang.org/grpc/codes"
)
func DeadlineUnaryInterceptor(max time.Duration) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if _, ok := ctx.Deadline() ; !ok { // 未设置超时
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, max)
defer cancel()
}
resp, err := handler(ctx, req)
if err != nil && ctx.Err() == context.DeadlineExceeded {
return nil, status.Error(codes.DeadlineExceeded, "deadline exceeded")
}
return resp, err
}
}
pgrpc.UseUnaryInterceptors(DeadlineUnaryInterceptor(5*time.Second))
```
## 更多 Stream 拦截器示例
基于 Metadata 的鉴权(流):
```go
import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/codes"
)
func AuthStreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
md, _ := metadata.FromIncomingContext(ss.Context())
if len(md.Get("authorization")) == 0 {
return status.Error(codes.Unauthenticated, "missing token")
}
return handler(srv, ss)
}
pgrpc.UseStreamInterceptors(AuthStreamInterceptor)
```
流级限流(示例:简单 Allow 检查):
```go
func RateLimitStreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
l := limitFor(info.FullMethod)
if !l.Allow() {
return status.Error(codes.ResourceExhausted, "rate limited")
}
return handler(srv, ss)
}
pgrpc.UseStreamInterceptors(RateLimitStreamInterceptor)
```
## 压缩与编码
注册 gzip 压缩器后,客户端可按需协商使用(新版本通过 encoding 注册):
```go
import (
_ "google.golang.org/grpc/encoding/gzip" // 注册 gzip 编解码器
)
// 仅需 import 即可,无额外 ServerOption
```
## OpenTelemetry 集成(推荐)
使用 StatsHandler推荐不与拦截器同时使用避免重复埋点
```go
import (
otelgrpc "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
)
// 基本接入:使用全局 Tracer/Meter由 OTEL Provider 初始化)
handler := otelgrpc.NewServerHandler(
otelgrpc.WithTraceEvents(), // 在 span 中记录消息事件
)
pgrpc.UseOptions(grpc.StatsHandler(handler))
// 忽略某些方法(如健康检查),避免噪声:
handler = otelgrpc.NewServerHandler(
otelgrpc.WithFilter(func(ctx context.Context, fullMethod string) bool {
return fullMethod != "/grpc.health.v1.Health/Check"
}),
)
pgrpc.UseOptions(grpc.StatsHandler(handler))
```
使用拦截器版本(如你更偏好 Interceptor 方案;与 StatsHandler 二选一):
```go
import (
otelgrpc "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
)
pgrpc.UseUnaryInterceptors(otelgrpc.UnaryServerInterceptor())
pgrpc.UseStreamInterceptors(otelgrpc.StreamServerInterceptor())
```
> 注意:不要同时启用 StatsHandler 和拦截器,否则会重复生成 span/metrics。
## OpenTracingJaeger集成
当使用 Tracing ProviderJaeger + OpenTracing可使用 opentracing 的 gRPC 拦截器:
```go
import (
opentracing "github.com/opentracing/opentracing-go"
otgrpc "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
)
pgrpc.UseUnaryInterceptors(otgrpc.OpenTracingServerInterceptor(opentracing.GlobalTracer()))
pgrpc.UseStreamInterceptors(otgrpc.OpenTracingStreamServerInterceptor(opentracing.GlobalTracer()))
```
> 与 OTEL 方案互斥:如果已启用 OTEL请不要再开启 OpenTracing 拦截器,以免重复埋点。

View File

@@ -0,0 +1,18 @@
package grpc
import (
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
func Provide(opts ...opt.Option) error {
o := opt.New(opts...)
var config Config
if err := o.UnmarshalConfig(&config); err != nil {
return err
}
return container.Container.Provide(func() (*Grpc, error) {
return &Grpc{config: &config}, nil
}, o.DiOptions()...)
}

View File

@@ -0,0 +1,38 @@
package http
import (
"fmt"
)
const DefaultPrefix = "Http"
type Config struct {
StaticPath *string
StaticRoute *string
BaseURI *string
Port uint
Tls *Tls
Cors *Cors
}
type Tls struct {
Cert string
Key string
}
type Cors struct {
Mode string
Whitelist []Whitelist
}
type Whitelist struct {
AllowOrigin string
AllowHeaders string
AllowMethods string
ExposeHeaders string
AllowCredentials bool
}
func (h *Config) Address() string {
return fmt.Sprintf(":%d", h.Port)
}

View File

@@ -0,0 +1,203 @@
package http
import (
"context"
"errors"
"fmt"
"net"
"runtime/debug"
"time"
log "github.com/sirupsen/logrus"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/compress"
"github.com/gofiber/fiber/v3/middleware/cors"
"github.com/gofiber/fiber/v3/middleware/helmet"
"github.com/gofiber/fiber/v3/middleware/limiter"
"github.com/gofiber/fiber/v3/middleware/logger"
"github.com/gofiber/fiber/v3/middleware/recover"
"github.com/gofiber/fiber/v3/middleware/requestid"
"github.com/samber/lo"
)
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: Provide,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
type Service struct {
conf *Config
Engine *fiber.App
}
func (svc *Service) listenerConfig() fiber.ListenConfig {
listenConfig := fiber.ListenConfig{
EnablePrintRoutes: true,
// DisableStartupMessage: true,
}
if svc.conf.Tls != nil {
if svc.conf.Tls.Cert == "" || svc.conf.Tls.Key == "" {
panic(errors.New("tls cert and key must be set"))
}
listenConfig.CertFile = svc.conf.Tls.Cert
listenConfig.CertKeyFile = svc.conf.Tls.Key
}
container.AddCloseAble(func() {
svc.Engine.ShutdownWithTimeout(time.Second * 10)
})
return listenConfig
}
func (svc *Service) Listener(ln net.Listener) error {
return svc.Engine.Listener(ln, svc.listenerConfig())
}
func (svc *Service) Serve(ctx context.Context) error {
ln, err := net.Listen("tcp", svc.conf.Address())
if err != nil {
return err
}
// Run the server in a goroutine so we can listen for context cancellation
serverErr := make(chan error, 1)
go func() {
serverErr <- svc.Engine.Listener(ln, svc.listenerConfig())
}()
select {
case <-ctx.Done():
// Shutdown the server gracefully
if shutdownErr := svc.Engine.Shutdown(); shutdownErr != nil {
return shutdownErr
}
// treat context cancellation as graceful shutdown
return nil
case err := <-serverErr:
return err
}
}
func Provide(opts ...opt.Option) error {
o := opt.New(opts...)
var config Config
if err := o.UnmarshalConfig(&config); err != nil {
return err
}
return container.Container.Provide(func() (*Service, error) {
engine := fiber.New(fiber.Config{
StrictRouting: true,
CaseSensitive: true,
BodyLimit: 10 * 1024 * 1024, // 10 MiB
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 60 * time.Second,
ProxyHeader: fiber.HeaderXForwardedFor,
EnableIPValidation: true,
})
// request id first for correlation
engine.Use(requestid.New())
// recover with stack + request id
engine.Use(recover.New(recover.Config{
EnableStackTrace: true,
StackTraceHandler: func(c fiber.Ctx, e any) {
rid := c.Get(fiber.HeaderXRequestID)
log.WithField("request_id", rid).Error(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack()))
},
}))
// basic security + compression
engine.Use(helmet.New())
engine.Use(compress.New(compress.Config{Level: compress.LevelDefault}))
// optional CORS based on config
if config.Cors != nil {
corsCfg := buildCORSConfig(config.Cors)
if corsCfg != nil {
engine.Use(cors.New(*corsCfg))
}
}
// logging with request id and latency
engine.Use(logger.New(logger.Config{
// requestid middleware stores ctx.Locals("requestid")
Format: `${time} [${ip}] ${method} ${status} ${path} ${latency} rid=${locals:requestid} "${ua}"\n`,
TimeFormat: time.RFC3339,
TimeZone: "Asia/Shanghai",
}))
// rate limit (enable standard headers; adjust Max via config if needed)
engine.Use(limiter.New(limiter.Config{Max: 0}))
// static files (Fiber v3 Static helper moved; enable via filesystem middleware later)
// if config.StaticRoute != nil && config.StaticPath != nil { ... }
// health endpoints
engine.Get("/healthz", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusNoContent) })
engine.Get("/readyz", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusNoContent) })
engine.Hooks().OnPostShutdown(func(err error) error {
if err != nil {
log.Error("http server shutdown error: ", err)
}
log.Info("http server has shutdown success")
return nil
})
return &Service{
Engine: engine,
conf: &config,
}, nil
}, o.DiOptions()...)
}
// buildCORSConfig converts provider Cors config into fiber cors.Config
func buildCORSConfig(c *Cors) *cors.Config {
if c == nil {
return nil
}
if c.Mode == "disabled" {
return nil
}
var (
origins []string
headers []string
methods []string
exposes []string
allowCreds bool
)
for _, w := range c.Whitelist {
if w.AllowOrigin != "" {
origins = append(origins, w.AllowOrigin)
}
if w.AllowHeaders != "" {
headers = append(headers, w.AllowHeaders)
}
if w.AllowMethods != "" {
methods = append(methods, w.AllowMethods)
}
if w.ExposeHeaders != "" {
exposes = append(exposes, w.ExposeHeaders)
}
allowCreds = allowCreds || w.AllowCredentials
}
cfg := cors.Config{
AllowOrigins: lo.Uniq(origins),
AllowHeaders: lo.Uniq(headers),
AllowMethods: lo.Uniq(methods),
ExposeHeaders: lo.Uniq(exposes),
AllowCredentials: allowCreds,
}
return &cfg
}

View File

@@ -0,0 +1,317 @@
package swagger
import (
"html/template"
)
// Config stores SwaggerUI configuration variables
type Config struct {
// This parameter can be used to name different swagger document instances.
// default: ""
InstanceName string `json:"-"`
// Title pointing to title of HTML page.
// default: "Swagger UI"
Title string `json:"-"`
// URL to fetch external configuration document from.
// default: ""
ConfigURL string `json:"configUrl,omitempty"`
// The URL pointing to API definition (normally swagger.json or swagger.yaml).
// default: "doc.json"
URL string `json:"url,omitempty"`
// Enables overriding configuration parameters via URL search params.
// default: false
QueryConfigEnabled bool `json:"queryConfigEnabled,omitempty"`
// The name of a component available via the plugin system to use as the top-level layout for Swagger UI.
// default: "StandaloneLayout"
Layout string `json:"layout,omitempty"`
// An array of plugin functions to use in Swagger UI.
// default: [SwaggerUIBundle.plugins.DownloadUrl]
Plugins []template.JS `json:"-"`
// An array of presets to use in Swagger UI. Usually, you'll want to include ApisPreset if you use this option.
// default: [SwaggerUIBundle.presets.apis, SwaggerUIStandalonePreset]
Presets []template.JS `json:"-"`
// If set to true, enables deep linking for tags and operations.
// default: true
DeepLinking bool `json:"deepLinking"`
// Controls the display of operationId in operations list.
// default: false
DisplayOperationId bool `json:"displayOperationId,omitempty"`
// The default expansion depth for models (set to -1 completely hide the models).
// default: 1
DefaultModelsExpandDepth int `json:"defaultModelsExpandDepth,omitempty"`
// The default expansion depth for the model on the model-example section.
// default: 1
DefaultModelExpandDepth int `json:"defaultModelExpandDepth,omitempty"`
// Controls how the model is shown when the API is first rendered.
// The user can always switch the rendering for a given model by clicking the 'Model' and 'Example Value' links.
// default: "example"
DefaultModelRendering string `json:"defaultModelRendering,omitempty"`
// Controls the display of the request duration (in milliseconds) for "Try it out" requests.
// default: false
DisplayRequestDuration bool `json:"displayRequestDuration,omitempty"`
// Controls the default expansion setting for the operations and tags.
// 'list' (default, expands only the tags),
// 'full' (expands the tags and operations),
// 'none' (expands nothing)
DocExpansion string `json:"docExpansion,omitempty"`
// If set, enables filtering. The top bar will show an edit box that you can use to filter the tagged operations that are shown.
// Can be Boolean to enable or disable, or a string, in which case filtering will be enabled using that string as the filter expression.
// Filtering is case sensitive matching the filter expression anywhere inside the tag.
// default: false
Filter FilterConfig `json:"-"`
// If set, limits the number of tagged operations displayed to at most this many. The default is to show all operations.
// default: 0
MaxDisplayedTags int `json:"maxDisplayedTags,omitempty"`
// Controls the display of vendor extension (x-) fields and values for Operations, Parameters, Responses, and Schema.
// default: false
ShowExtensions bool `json:"showExtensions,omitempty"`
// Controls the display of extensions (pattern, maxLength, minLength, maximum, minimum) fields and values for Parameters.
// default: false
ShowCommonExtensions bool `json:"showCommonExtensions,omitempty"`
// Apply a sort to the tag list of each API. It can be 'alpha' (sort by paths alphanumerically) or a function (see Array.prototype.sort().
// to learn how to write a sort function). Two tag name strings are passed to the sorter for each pass.
// default: "" -> Default is the order determined by Swagger UI.
TagsSorter template.JS `json:"-"`
// Provides a mechanism to be notified when Swagger UI has finished rendering a newly provided definition.
// default: "" -> Function=NOOP
OnComplete template.JS `json:"-"`
// An object with the activate and theme properties.
SyntaxHighlight *SyntaxHighlightConfig `json:"-"`
// Controls whether the "Try it out" section should be enabled by default.
// default: false
TryItOutEnabled bool `json:"tryItOutEnabled,omitempty"`
// Enables the request snippet section. When disabled, the legacy curl snippet will be used.
// default: false
RequestSnippetsEnabled bool `json:"requestSnippetsEnabled,omitempty"`
// OAuth redirect URL.
// default: ""
OAuth2RedirectUrl string `json:"oauth2RedirectUrl,omitempty"`
// MUST be a function. Function to intercept remote definition, "Try it out", and OAuth 2.0 requests.
// Accepts one argument requestInterceptor(request) and must return the modified request, or a Promise that resolves to the modified request.
// default: ""
RequestInterceptor template.JS `json:"-"`
// If set, MUST be an array of command line options available to the curl command. This can be set on the mutated request in the requestInterceptor function.
// For example request.curlOptions = ["-g", "--limit-rate 20k"]
// default: nil
RequestCurlOptions []string `json:"request.curlOptions,omitempty"`
// MUST be a function. Function to intercept remote definition, "Try it out", and OAuth 2.0 responses.
// Accepts one argument responseInterceptor(response) and must return the modified response, or a Promise that resolves to the modified response.
// default: ""
ResponseInterceptor template.JS `json:"-"`
// If set to true, uses the mutated request returned from a requestInterceptor to produce the curl command in the UI,
// otherwise the request before the requestInterceptor was applied is used.
// default: true
ShowMutatedRequest bool `json:"showMutatedRequest"`
// List of HTTP methods that have the "Try it out" feature enabled. An empty array disables "Try it out" for all operations.
// This does not filter the operations from the display.
// Possible values are ["get", "put", "post", "delete", "options", "head", "patch", "trace"]
// default: nil
SupportedSubmitMethods []string `json:"supportedSubmitMethods,omitempty"`
// By default, Swagger UI attempts to validate specs against swagger.io's online validator. You can use this parameter to set a different validator URL.
// For example for locally deployed validators (https://github.com/swagger-api/validator-badge).
// Setting it to either none, 127.0.0.1 or localhost will disable validation.
// default: ""
ValidatorUrl string `json:"validatorUrl,omitempty"`
// If set to true, enables passing credentials, as defined in the Fetch standard, in CORS requests that are sent by the browser.
// Note that Swagger UI cannot currently set cookies cross-domain (see https://github.com/swagger-api/swagger-js/issues/1163).
// as a result, you will have to rely on browser-supplied cookies (which this setting enables sending) that Swagger UI cannot control.
// default: false
WithCredentials bool `json:"withCredentials,omitempty"`
// Function to set default values to each property in model. Accepts one argument modelPropertyMacro(property), property is immutable.
// default: ""
ModelPropertyMacro template.JS `json:"-"`
// Function to set default value to parameters. Accepts two arguments parameterMacro(operation, parameter).
// Operation and parameter are objects passed for context, both remain immutable.
// default: ""
ParameterMacro template.JS `json:"-"`
// If set to true, it persists authorization data and it would not be lost on browser close/refresh.
// default: false
PersistAuthorization bool `json:"persistAuthorization,omitempty"`
// Configuration information for OAuth2, optional if using OAuth2
OAuth *OAuthConfig `json:"-"`
// (authDefinitionKey, username, password) => action
// Programmatically set values for a Basic authorization scheme.
// default: ""
PreauthorizeBasic template.JS `json:"-"`
// (authDefinitionKey, apiKeyValue) => action
// Programmatically set values for an API key or Bearer authorization scheme.
// In case of OpenAPI 3.0 Bearer scheme, apiKeyValue must contain just the token itself without the Bearer prefix.
// default: ""
PreauthorizeApiKey template.JS `json:"-"`
// Applies custom CSS styles.
// default: ""
CustomStyle template.CSS `json:"-"`
// Applies custom JavaScript scripts.
// default ""
CustomScript template.JS `json:"-"`
}
type FilterConfig struct {
Enabled bool
Expression string
}
func (fc FilterConfig) Value() interface{} {
if fc.Expression != "" {
return fc.Expression
}
return fc.Enabled
}
type SyntaxHighlightConfig struct {
// Whether syntax highlighting should be activated or not.
// default: true
Activate bool `json:"activate"`
// Highlight.js syntax coloring theme to use.
// Possible values are ["agate", "arta", "monokai", "nord", "obsidian", "tomorrow-night"]
// default: "agate"
Theme string `json:"theme,omitempty"`
}
func (shc SyntaxHighlightConfig) Value() interface{} {
if shc.Activate {
return shc
}
return false
}
type OAuthConfig struct {
// ID of the client sent to the OAuth2 provider.
// default: ""
ClientId string `json:"clientId,omitempty"`
// Never use this parameter in your production environment.
// It exposes cruicial security information. This feature is intended for dev/test environments only.
// Secret of the client sent to the OAuth2 provider.
// default: ""
ClientSecret string `json:"clientSecret,omitempty"`
// Application name, displayed in authorization popup.
// default: ""
AppName string `json:"appName,omitempty"`
// Realm query parameter (for oauth1) added to authorizationUrl and tokenUrl.
// default: ""
Realm string `json:"realm,omitempty"`
// String array of initially selected oauth scopes
// default: nil
Scopes []string `json:"scopes,omitempty"`
// Additional query parameters added to authorizationUrl and tokenUrl.
// default: nil
AdditionalQueryStringParams map[string]string `json:"additionalQueryStringParams,omitempty"`
// Unavailable Only activated for the accessCode flow.
// During the authorization_code request to the tokenUrl, pass the Client Password using the HTTP Basic Authentication scheme
// (Authorization header with Basic base64encode(client_id + client_secret)).
// default: false
UseBasicAuthenticationWithAccessCodeGrant bool `json:"useBasicAuthenticationWithAccessCodeGrant,omitempty"`
// Only applies to authorizatonCode flows.
// Proof Key for Code Exchange brings enhanced security for OAuth public clients.
// default: false
UsePkceWithAuthorizationCodeGrant bool `json:"usePkceWithAuthorizationCodeGrant,omitempty"`
}
var ConfigDefault = Config{
Title: "Swagger UI",
Layout: "StandaloneLayout",
Plugins: []template.JS{
template.JS("SwaggerUIBundle.plugins.DownloadUrl"),
},
Presets: []template.JS{
template.JS("SwaggerUIBundle.presets.apis"),
template.JS("SwaggerUIStandalonePreset"),
},
DeepLinking: true,
DefaultModelsExpandDepth: 1,
DefaultModelExpandDepth: 1,
DefaultModelRendering: "example",
DocExpansion: "list",
SyntaxHighlight: &SyntaxHighlightConfig{
Activate: true,
Theme: "agate",
},
ShowMutatedRequest: true,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
if cfg.Title == "" {
cfg.Title = ConfigDefault.Title
}
if cfg.Layout == "" {
cfg.Layout = ConfigDefault.Layout
}
if cfg.DefaultModelRendering == "" {
cfg.DefaultModelRendering = ConfigDefault.DefaultModelRendering
}
if cfg.DocExpansion == "" {
cfg.DocExpansion = ConfigDefault.DocExpansion
}
if cfg.Plugins == nil {
cfg.Plugins = ConfigDefault.Plugins
}
if cfg.Presets == nil {
cfg.Presets = ConfigDefault.Presets
}
if cfg.SyntaxHighlight == nil {
cfg.SyntaxHighlight = ConfigDefault.SyntaxHighlight
}
return cfg
}

View File

@@ -0,0 +1,103 @@
package swagger
import (
"fmt"
"html/template"
"path"
"strings"
"sync"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/static"
"github.com/gofiber/utils/v2"
"github.com/rogeecn/swag"
swaggerFiles "github.com/swaggo/files/v2"
)
const (
defaultDocURL = "doc.json"
defaultIndex = "index.html"
)
var HandlerDefault = New()
// New returns custom handler
func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)
index, err := template.New("swagger_index.html").Parse(indexTmpl)
if err != nil {
panic(fmt.Errorf("fiber: swagger middleware error -> %w", err))
}
var (
prefix string
once sync.Once
)
return func(c fiber.Ctx) error {
// Set prefix
once.Do(
func() {
prefix = strings.ReplaceAll(c.Route().Path, "*", "")
forwardedPrefix := getForwardedPrefix(c)
if forwardedPrefix != "" {
prefix = forwardedPrefix + prefix
}
// Set doc url
if len(cfg.URL) == 0 {
cfg.URL = path.Join(prefix, defaultDocURL)
}
},
)
p := c.Path(utils.CopyString(c.Params("*")))
switch p {
case defaultIndex:
c.Type("html")
return index.Execute(c, cfg)
case defaultDocURL:
var doc string
if doc, err = swag.ReadDoc(cfg.InstanceName); err != nil {
return err
}
return c.Type("json").SendString(doc)
case "", "/":
return c.Redirect().To(path.Join(prefix, defaultIndex))
default:
// return fs(c)
return static.New("/swagger", static.Config{
FS: swaggerFiles.FS,
Browse: true,
})(c)
}
}
}
func getForwardedPrefix(c fiber.Ctx) string {
header := c.GetReqHeaders()["X-Forwarded-Prefix"]
if len(header) == 0 {
return ""
}
prefix := ""
for _, rawPrefix := range header {
endIndex := len(rawPrefix)
for endIndex > 1 && rawPrefix[endIndex-1] == '/' {
endIndex--
}
if endIndex != len(rawPrefix) {
prefix += rawPrefix[:endIndex]
} else {
prefix += rawPrefix
}
}
return prefix
}

View File

@@ -0,0 +1,107 @@
package swagger
const indexTmpl string = `
<!-- HTML for static distribution bundle build -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>{{.Title}}</title>
<link href="https://fonts.googleapis.com/css?family=Open+Sans:400,700|Source+Code+Pro:300,600|Titillium+Web:400,600,700" rel="stylesheet">
<link rel="stylesheet" type="text/css" href="./swagger-ui.css" >
<link rel="icon" type="image/png" href="./favicon-32x32.png" sizes="32x32" />
<link rel="icon" type="image/png" href="./favicon-16x16.png" sizes="16x16" />
{{- if .CustomStyle}}
<style>
body { margin: 0; }
{{.CustomStyle}}
</style>
{{- end}}
{{- if .CustomScript}}
<script>
{{.CustomScript}}
</script>
{{- end}}
</head>
<body>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" style="position:absolute;width:0;height:0">
<defs>
<symbol viewBox="0 0 20 20" id="unlocked">
<path d="M15.8 8H14V5.6C14 2.703 12.665 1 10 1 7.334 1 6 2.703 6 5.6V6h2v-.801C8 3.754 8.797 3 10 3c1.203 0 2 .754 2 2.199V8H4c-.553 0-1 .646-1 1.199V17c0 .549.428 1.139.951 1.307l1.197.387C5.672 18.861 6.55 19 7.1 19h5.8c.549 0 1.428-.139 1.951-.307l1.196-.387c.524-.167.953-.757.953-1.306V9.199C17 8.646 16.352 8 15.8 8z"></path>
</symbol>
<symbol viewBox="0 0 20 20" id="locked">
<path d="M15.8 8H14V5.6C14 2.703 12.665 1 10 1 7.334 1 6 2.703 6 5.6V8H4c-.553 0-1 .646-1 1.199V17c0 .549.428 1.139.951 1.307l1.197.387C5.672 18.861 6.55 19 7.1 19h5.8c.549 0 1.428-.139 1.951-.307l1.196-.387c.524-.167.953-.757.953-1.306V9.199C17 8.646 16.352 8 15.8 8zM12 8H8V5.199C8 3.754 8.797 3 10 3c1.203 0 2 .754 2 2.199V8z"/>
</symbol>
<symbol viewBox="0 0 20 20" id="close">
<path d="M14.348 14.849c-.469.469-1.229.469-1.697 0L10 11.819l-2.651 3.029c-.469.469-1.229.469-1.697 0-.469-.469-.469-1.229 0-1.697l2.758-3.15-2.759-3.152c-.469-.469-.469-1.228 0-1.697.469-.469 1.228-.469 1.697 0L10 8.183l2.651-3.031c.469-.469 1.228-.469 1.697 0 .469.469.469 1.229 0 1.697l-2.758 3.152 2.758 3.15c.469.469.469 1.229 0 1.698z"/>
</symbol>
<symbol viewBox="0 0 20 20" id="large-arrow">
<path d="M13.25 10L6.109 2.58c-.268-.27-.268-.707 0-.979.268-.27.701-.27.969 0l7.83 7.908c.268.271.268.709 0 .979l-7.83 7.908c-.268.271-.701.27-.969 0-.268-.269-.268-.707 0-.979L13.25 10z"/>
</symbol>
<symbol viewBox="0 0 20 20" id="large-arrow-down">
<path d="M17.418 6.109c.272-.268.709-.268.979 0s.271.701 0 .969l-7.908 7.83c-.27.268-.707.268-.979 0l-7.908-7.83c-.27-.268-.27-.701 0-.969.271-.268.709-.268.979 0L10 13.25l7.418-7.141z"/>
</symbol>
<symbol viewBox="0 0 24 24" id="jump-to">
<path d="M19 7v4H5.83l3.58-3.59L8 6l-6 6 6 6 1.41-1.41L5.83 13H21V7z"/>
</symbol>
<symbol viewBox="0 0 24 24" id="expand">
<path d="M10 18h4v-2h-4v2zM3 6v2h18V6H3zm3 7h12v-2H6v2z"/>
</symbol>
</defs>
</svg>
<div id="swagger-ui"></div>
<script src="./swagger-ui-bundle.js"> </script>
<script src="./swagger-ui-standalone-preset.js"> </script>
<script>
window.onload = function() {
config = {{.}};
config.dom_id = '#swagger-ui';
config.plugins = [
{{- range $plugin := .Plugins }}
{{$plugin}},
{{- end}}
];
config.presets = [
{{- range $preset := .Presets }}
{{$preset}},
{{- end}}
];
config.filter = {{.Filter.Value}}
config.syntaxHighlight = {{.SyntaxHighlight.Value}}
{{if .TagsSorter}}
config.tagsSorter = {{.TagsSorter}}
{{end}}
{{if .OnComplete}}
config.onComplete = {{.OnComplete}}
{{end}}
{{if .RequestInterceptor}}
config.requestInterceptor = {{.RequestInterceptor}}
{{end}}
{{if .ResponseInterceptor}}
config.responseInterceptor = {{.ResponseInterceptor}}
{{end}}
{{if .ModelPropertyMacro}}
config.modelPropertyMacro = {{.ModelPropertyMacro}}
{{end}}
{{if .ParameterMacro}}
config.parameterMacro = {{.ParameterMacro}}
{{end}}
const ui = SwaggerUIBundle(config);
{{if .OAuth}}
ui.initOAuth({{.OAuth}});
{{end}}
{{if .PreauthorizeBasic}}
ui.preauthorizeBasic({{.PreauthorizeBasic}});
{{end}}
{{if .PreauthorizeApiKey}}
ui.preauthorizeApiKey({{.PreauthorizeApiKey}});
{{end}}
window.ui = ui
}
</script>
</body>
</html>
`

View File

@@ -0,0 +1,67 @@
package job
import (
"github.com/riverqueue/river"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
const DefaultPrefix = "Job"
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: Provide,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
type Config struct {
// Optional per-queue worker concurrency. If empty, defaults apply.
QueueWorkers QueueWorkersConfig
}
// QueueWorkers allows configuring worker concurrency per queue.
// Key is the queue name, value is MaxWorkers. If empty, defaults are used.
// Example TOML:
//
// [Job]
// # high=20, default=10, low=5
// # QueueWorkers = { high = 20, default = 10, low = 5 }
type QueueWorkersConfig map[string]int
const (
PriorityDefault = river.PriorityDefault
PriorityLow = 2
PriorityMiddle = 3
PriorityHigh = 3
)
const (
QueueHigh = "high"
QueueDefault = river.QueueDefault
QueueLow = "low"
)
// queueConfig returns a river.QueueConfig map built from QueueWorkers or defaults.
func (c *Config) queueConfig() map[string]river.QueueConfig {
cfg := map[string]river.QueueConfig{}
if c == nil || len(c.QueueWorkers) == 0 {
cfg[QueueHigh] = river.QueueConfig{MaxWorkers: 10}
cfg[QueueDefault] = river.QueueConfig{MaxWorkers: 10}
cfg[QueueLow] = river.QueueConfig{MaxWorkers: 10}
return cfg
}
for name, n := range c.QueueWorkers {
if n <= 0 {
n = 1
}
cfg[name] = river.QueueConfig{MaxWorkers: n}
}
if _, ok := cfg[QueueDefault]; !ok {
cfg[QueueDefault] = river.QueueConfig{MaxWorkers: 10}
}
return cfg
}

View File

@@ -0,0 +1,207 @@
package job
import (
"context"
"fmt"
"sync"
"time"
"quyun/v2/providers/postgres"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/pkg/errors"
"github.com/riverqueue/river"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivertype"
"github.com/samber/lo"
log "github.com/sirupsen/logrus"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/contracts"
"go.ipao.vip/atom/opt"
)
func Provide(opts ...opt.Option) error {
o := opt.New(opts...)
var config Config
if err := o.UnmarshalConfig(&config); err != nil {
return err
}
return container.Container.Provide(func(ctx context.Context, dbConf *postgres.Config) (*Job, error) {
workers := river.NewWorkers()
dbPoolConfig, err := pgxpool.ParseConfig(dbConf.DSN())
if err != nil {
return nil, err
}
dbPool, err := pgxpool.NewWithConfig(ctx, dbPoolConfig)
if err != nil {
return nil, err
}
// health check ping with timeout
pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
if err := dbPool.Ping(pingCtx); err != nil {
return nil, fmt.Errorf("job provider: db ping failed: %w", err)
}
container.AddCloseAble(dbPool.Close)
pool := riverpgxv5.New(dbPool)
queue := &Job{
Workers: workers,
driver: pool,
ctx: ctx,
conf: &config,
periodicJobs: make(map[string]rivertype.PeriodicJobHandle),
}
container.AddCloseAble(queue.Close)
return queue, nil
}, o.DiOptions()...)
}
type Job struct {
ctx context.Context
conf *Config
Workers *river.Workers
driver *riverpgxv5.Driver
l sync.Mutex
client *river.Client[pgx.Tx]
periodicJobs map[string]rivertype.PeriodicJobHandle
}
func (q *Job) Close() {
if q.client == nil {
return
}
if err := q.client.StopAndCancel(q.ctx); err != nil {
log.Errorf("Failed to stop and cancel client: %s", err)
}
// clear references
q.l.Lock()
q.periodicJobs = map[string]rivertype.PeriodicJobHandle{}
q.l.Unlock()
}
func (q *Job) Client() (*river.Client[pgx.Tx], error) {
q.l.Lock()
defer q.l.Unlock()
if q.client == nil {
var err error
q.client, err = river.NewClient(q.driver, &river.Config{
Workers: q.Workers,
Queues: q.conf.queueConfig(),
})
if err != nil {
return nil, err
}
}
return q.client, nil
}
func (q *Job) Start(ctx context.Context) error {
client, err := q.Client()
if err != nil {
return errors.Wrap(err, "get client failed")
}
if err := client.Start(ctx); err != nil {
return err
}
defer client.StopAndCancel(ctx)
<-ctx.Done()
return nil
}
func (q *Job) StopAndCancel(ctx context.Context) error {
client, err := q.Client()
if err != nil {
return errors.Wrap(err, "get client failed")
}
return client.StopAndCancel(ctx)
}
func (q *Job) AddPeriodicJobs(job contracts.CronJob) error {
for _, job := range job.Args() {
if err := q.AddPeriodicJob(job); err != nil {
return err
}
}
return nil
}
func (q *Job) AddPeriodicJob(job contracts.CronJobArg) error {
client, err := q.Client()
if err != nil {
return err
}
q.l.Lock()
defer q.l.Unlock()
q.periodicJobs[job.Arg.UniqueID()] = client.PeriodicJobs().Add(river.NewPeriodicJob(
job.PeriodicInterval,
func() (river.JobArgs, *river.InsertOpts) {
return job.Arg, lo.ToPtr(job.Arg.InsertOpts())
},
&river.PeriodicJobOpts{
RunOnStart: job.RunOnStart,
},
))
return nil
}
func (q *Job) Cancel(id string) error {
client, err := q.Client()
if err != nil {
return err
}
q.l.Lock()
defer q.l.Unlock()
if h, ok := q.periodicJobs[id]; ok {
client.PeriodicJobs().Remove(h)
delete(q.periodicJobs, id)
}
return nil
}
// CancelContext is like Cancel but allows passing a context.
func (q *Job) CancelContext(ctx context.Context, id string) error {
client, err := q.Client()
if err != nil {
return err
}
q.l.Lock()
defer q.l.Unlock()
if h, ok := q.periodicJobs[id]; ok {
client.PeriodicJobs().Remove(h)
delete(q.periodicJobs, id)
return nil
}
return nil
}
func (q *Job) Add(job contracts.JobArgs) error {
client, err := q.Client()
if err != nil {
return err
}
q.l.Lock()
defer q.l.Unlock()
_, err = client.Insert(q.ctx, job, lo.ToPtr(job.InsertOpts()))
return err
}

View File

@@ -0,0 +1,35 @@
package jwt
import (
"time"
log "github.com/sirupsen/logrus"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
const DefaultPrefix = "JWT"
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: Provide,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
type Config struct {
SigningKey string // jwt签名
ExpiresTime string // 过期时间
Issuer string // 签发者
}
func (c *Config) ExpiresTimeDuration() time.Duration {
d, err := time.ParseDuration(c.ExpiresTime)
if err != nil {
log.Fatal(err)
}
return d
}

View File

@@ -0,0 +1,118 @@
package jwt
import (
"errors"
"strings"
"time"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
jwt "github.com/golang-jwt/jwt/v4"
"golang.org/x/sync/singleflight"
)
const (
CtxKey = "claims"
HttpHeader = "Authorization"
)
type BaseClaims struct {
OpenID string `json:"open_id,omitempty"`
Tenant string `json:"tenant,omitempty"`
UserID int64 `json:"user_id,omitempty"`
TenantID int64 `json:"tenant_id,omitempty"`
}
// Custom claims structure
type Claims struct {
BaseClaims
jwt.RegisteredClaims
}
const TokenPrefix = "Bearer "
type JWT struct {
singleflight *singleflight.Group
config *Config
SigningKey []byte
}
var (
TokenExpired = errors.New("Token is expired")
TokenNotValidYet = errors.New("Token not active yet")
TokenMalformed = errors.New("That's not even a token")
TokenInvalid = errors.New("Couldn't handle this token:")
)
func Provide(opts ...opt.Option) error {
o := opt.New(opts...)
var config Config
if err := o.UnmarshalConfig(&config); err != nil {
return err
}
return container.Container.Provide(func() (*JWT, error) {
return &JWT{
singleflight: &singleflight.Group{},
config: &config,
SigningKey: []byte(config.SigningKey),
}, nil
}, o.DiOptions()...)
}
func (j *JWT) CreateClaims(baseClaims BaseClaims) *Claims {
ep, _ := time.ParseDuration(j.config.ExpiresTime)
claims := Claims{
BaseClaims: baseClaims,
RegisteredClaims: jwt.RegisteredClaims{
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Second * 10)), // 签名生效时间
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ep)), // 过期时间 7天 配置文件
Issuer: j.config.Issuer, // 签名的发行者
},
}
return &claims
}
// 创建一个token
func (j *JWT) CreateToken(claims *Claims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(j.SigningKey)
}
// CreateTokenByOldToken 旧token 换新token 使用归并回源避免并发问题
func (j *JWT) CreateTokenByOldToken(oldToken string, claims *Claims) (string, error) {
v, err, _ := j.singleflight.Do("JWT:"+oldToken, func() (interface{}, error) {
return j.CreateToken(claims)
})
return v.(string), err
}
// 解析 token
func (j *JWT) Parse(tokenString string) (*Claims, error) {
tokenString = strings.TrimPrefix(tokenString, TokenPrefix)
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (i interface{}, e error) {
return j.SigningKey, nil
})
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
return nil, TokenMalformed
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
// Token is expired
return nil, TokenExpired
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
return nil, TokenNotValidYet
} else {
return nil, TokenInvalid
}
}
}
if token != nil {
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil
}
return nil, TokenInvalid
} else {
return nil, TokenInvalid
}
}

View File

@@ -0,0 +1,136 @@
package postgres
import (
"fmt"
"strconv"
"time"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
"gorm.io/gorm/logger"
)
const DefaultPrefix = "Database"
func DefaultProvider() container.ProviderContainer {
return container.ProviderContainer{
Provider: Provide,
Options: []opt.Option{
opt.Prefix(DefaultPrefix),
},
}
}
type Config struct {
Username string
Password string
Database string
Schema string
Host string
Port uint
SslMode string
TimeZone string
Prefix string // 表前缀
Singular bool // 是否开启全局禁用复数true表示开启
MaxIdleConns int // 空闲中的最大连接数
MaxOpenConns int // 打开到数据库的最大连接数
// 可选连接生命周期配置0 表示不设置)
ConnMaxLifetimeSeconds uint
ConnMaxIdleTimeSeconds uint
// 可选GORM 日志与行为配置
LogLevel string // silent|error|warn|info默认info
SlowThresholdMs uint // 慢查询阈值毫秒默认200
ParameterizedQueries bool // 占位符输出,便于日志安全与查询归并
PrepareStmt bool // 预编译语句缓存
SkipDefaultTransaction bool // 跳过默认事务
// 可选DSN 增强
UseSearchPath bool // 在 DSN 中附带 search_path
ApplicationName string // application_name
}
func (m Config) GormSlowThreshold() time.Duration {
if m.SlowThresholdMs == 0 {
return 200 * time.Millisecond // 默认200ms
}
return time.Duration(m.SlowThresholdMs) * time.Millisecond
}
func (m Config) GormLogLevel() logger.LogLevel {
switch m.LogLevel {
case "silent":
return logger.Silent
case "error":
return logger.Error
case "warn":
return logger.Warn
case "info", "":
return logger.Info
default:
return logger.Info
}
}
func (m *Config) checkDefault() {
if m.MaxIdleConns == 0 {
m.MaxIdleConns = 10
}
if m.MaxOpenConns == 0 {
m.MaxOpenConns = 100
}
if m.Username == "" {
m.Username = "postgres"
}
if m.SslMode == "" {
m.SslMode = "disable"
}
if m.TimeZone == "" {
m.TimeZone = "Asia/Shanghai"
}
if m.Port == 0 {
m.Port = 5432
}
if m.Schema == "" {
m.Schema = "public"
}
}
func (m *Config) EmptyDsn() string {
// 基本 DSN
dsnTpl := "host=%s user=%s password=%s port=%d dbname=%s sslmode=%s TimeZone=%s"
m.checkDefault()
base := fmt.Sprintf(dsnTpl, m.Host, m.Username, m.Password, m.Port, m.Database, m.SslMode, m.TimeZone)
// 附加可选参数
extras := ""
if m.UseSearchPath && m.Schema != "" {
extras += " search_path=" + m.Schema
}
if m.ApplicationName != "" {
extras += " application_name=" + strconv.Quote(m.ApplicationName)
}
return base + extras
}
// DSN connection dsn
func (m *Config) DSN() string {
// 基本 DSN
dsnTpl := "host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=%s"
m.checkDefault()
base := fmt.Sprintf(dsnTpl, m.Host, m.Username, m.Password, m.Database, m.Port, m.SslMode, m.TimeZone)
// 附加可选参数
extras := ""
if m.UseSearchPath && m.Schema != "" {
extras += " search_path=" + m.Schema
}
if m.ApplicationName != "" {
extras += " application_name=" + strconv.Quote(m.ApplicationName)
}
return base + extras
}

View File

@@ -0,0 +1,91 @@
package postgres
import (
"context"
"database/sql"
"time"
"github.com/sirupsen/logrus"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
func Provide(opts ...opt.Option) error {
o := opt.New(opts...)
var conf Config
if err := o.UnmarshalConfig(&conf); err != nil {
return err
}
return container.Container.Provide(func() (*gorm.DB, *sql.DB, *Config, error) {
dbConfig := postgres.Config{DSN: conf.DSN()}
// 安全日志:不打印密码,仅输出关键连接信息
logrus.
WithFields(
logrus.Fields{
"host": conf.Host,
"port": conf.Port,
"db": conf.Database,
"schema": conf.Schema,
"ssl": conf.SslMode,
},
).
Info("opening PostgreSQL connection")
// 映射日志等级
lvl := conf.GormLogLevel()
slow := conf.GormSlowThreshold()
gormConfig := gorm.Config{
NamingStrategy: schema.NamingStrategy{
TablePrefix: conf.Prefix,
SingularTable: conf.Singular,
},
DisableForeignKeyConstraintWhenMigrating: true,
PrepareStmt: conf.PrepareStmt,
SkipDefaultTransaction: conf.SkipDefaultTransaction,
Logger: logger.New(logrus.StandardLogger(), logger.Config{
SlowThreshold: slow,
LogLevel: lvl,
IgnoreRecordNotFoundError: true,
Colorful: false,
ParameterizedQueries: conf.ParameterizedQueries,
}),
}
db, err := gorm.Open(postgres.New(dbConfig), &gormConfig)
if err != nil {
return nil, nil, nil, err
}
sqlDB, err := db.DB()
if err != nil {
return nil, sqlDB, nil, err
}
sqlDB.SetMaxIdleConns(conf.MaxIdleConns)
sqlDB.SetMaxOpenConns(conf.MaxOpenConns)
if conf.ConnMaxLifetimeSeconds > 0 {
sqlDB.SetConnMaxLifetime(time.Duration(conf.ConnMaxLifetimeSeconds) * time.Second)
}
if conf.ConnMaxIdleTimeSeconds > 0 {
sqlDB.SetConnMaxIdleTime(time.Duration(conf.ConnMaxIdleTimeSeconds) * time.Second)
}
// Ping 校验
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sqlDB.PingContext(ctx); err != nil {
return nil, sqlDB, nil, err
}
// 关闭钩子
container.AddCloseAble(func() { _ = sqlDB.Close() })
return db, sqlDB, &conf, nil
}, o.DiOptions()...)
}