// Package jwt provides a small HMAC-signed JWT provider: it builds claims,
// signs tokens, and parses/validates incoming tokens.
package jwt

import (
	"errors"
	"fmt"
	"strings"
	"time"

	"go.ipao.vip/atom/container"
	"go.ipao.vip/atom/opt"

	jwt "github.com/golang-jwt/jwt/v4"
	"golang.org/x/sync/singleflight"
)

const (
	// CtxKey is presumably the context key under which parsed claims are
	// stored by callers — it is not used inside this file; confirm at the call site.
	CtxKey = "claims"
	// HTTPHeader is the name of the HTTP header expected to carry the token.
	HTTPHeader = "Authorization"
)

// ErrTokenInvalidType is returned by CreateTokenByOldToken when the shared
// singleflight result is not a string.
var ErrTokenInvalidType = errors.New("token cache returned non-string value")

// BaseClaims holds the application-specific fields embedded in every token.
type BaseClaims struct {
	OpenID   string `json:"open_id,omitempty"`
	Tenant   string `json:"tenant,omitempty"`
	UserID   int64  `json:"user_id,omitempty"`
	TenantID int64  `json:"tenant_id,omitempty"`
}

// Claims is the full claim set of a token: the application fields plus the
// registered JWT claims (issuer, expiry, not-before, ...).
type Claims struct {
	BaseClaims
	jwt.RegisteredClaims
}

// TokenPrefix is the scheme prefix stripped from a token string by Parse.
const TokenPrefix = "Bearer "

// JWT creates and parses tokens signed with SigningKey.
type JWT struct {
	singleflight *singleflight.Group // collapses concurrent token refreshes, see CreateTokenByOldToken
	config       *Config
	SigningKey   []byte // HMAC secret used for both signing and verification
}

// Sentinel errors returned by Parse.
var (
	ErrTokenExpired     = errors.New("Token is expired")
	ErrTokenNotValidYet = errors.New("Token not active yet")
	ErrTokenMalformed   = errors.New("That's not even a token")
	ErrTokenInvalid     = errors.New("Couldn't handle this token")
)

// Provide unmarshals the JWT configuration from opts and registers a *JWT
// constructor in the DI container.
//
// It returns an error if the configuration cannot be unmarshalled or if the
// container rejects the provider.
func Provide(opts ...opt.Option) error {
	options := opt.New(opts...)

	var config Config
	if err := options.UnmarshalConfig(&config); err != nil {
		return fmt.Errorf("unmarshal jwt config: %w", err)
	}

	constructor := func() (*JWT, error) {
		return &JWT{
			singleflight: &singleflight.Group{},
			config:       &config,
			SigningKey:   []byte(config.SigningKey),
		}, nil
	}
	if err := container.Container.Provide(constructor, options.DiOptions()...); err != nil {
		return fmt.Errorf("provide jwt: %w", err)
	}
	return nil
}

// CreateClaims wraps baseClaims in a Claims value whose registered claims are
// derived from the provider configuration: NotBefore is 10 seconds in the
// past, ExpiresAt is now plus the configured ExpiresTime, and Issuer is the
// configured issuer.
func (jwtProvider *JWT) CreateClaims(baseClaims BaseClaims) *Claims {
	// NOTE(review): the parse error is discarded because this method cannot
	// return one. An unparsable ExpiresTime yields a zero duration, i.e. a
	// token that is already expired when issued; validate the config upstream.
	expiresDuration, _ := time.ParseDuration(jwtProvider.config.ExpiresTime)

	// Read the clock once so NotBefore and ExpiresAt share the same instant.
	now := time.Now()

	return &Claims{
		BaseClaims: baseClaims,
		RegisteredClaims: jwt.RegisteredClaims{
			// Backdated by 10 seconds before the issue instant.
			NotBefore: jwt.NewNumericDate(now.Add(-10 * time.Second)),
			ExpiresAt: jwt.NewNumericDate(now.Add(expiresDuration)),
			Issuer:    jwtProvider.config.Issuer,
		},
	}
}

// CreateToken signs claims with HS256 using the provider's signing key and
// returns the compact token string.
func (jwtProvider *JWT) CreateToken(claims *Claims) (string, error) {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString(jwtProvider.SigningKey)
}

// CreateTokenByOldToken exchanges an old token for a new one. Concurrent calls
// with the same oldToken are collapsed through singleflight, so they share the
// result of a single CreateToken invocation (built from the claims of
// whichever call executes it).
//
// It returns a wrapped error if signing fails, or ErrTokenInvalidType if the
// shared result is not a string.
func (jwtProvider *JWT) CreateTokenByOldToken(oldToken string, claims *Claims) (string, error) {
	value, err, _ := jwtProvider.singleflight.Do("JWT:"+oldToken, func() (interface{}, error) {
		return jwtProvider.CreateToken(claims)
	})
	// Check the error first so a signing failure is never masked by the
	// type assertion below.
	if err != nil {
		return "", fmt.Errorf("create token by old token: %w", err)
	}

	tokenString, ok := value.(string)
	if !ok {
		return "", ErrTokenInvalidType
	}
	return tokenString, nil
}

// Parse validates tokenString (with an optional "Bearer " prefix) against the
// provider's signing key and returns its claims.
//
// Errors:
//   - ErrTokenMalformed   the string is not a well-formed JWT
//   - ErrTokenExpired     the token's expiry has passed
//   - ErrTokenNotValidYet the token's not-before time is in the future
//   - ErrTokenInvalid     any other failure (bad signature, unexpected
//     signing method, wrong claims type, ...)
func (jwtProvider *JWT) Parse(tokenString string) (*Claims, error) {
	tokenString = strings.TrimPrefix(tokenString, TokenPrefix)

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Only hand out the HMAC secret for HMAC-signed tokens; the "alg"
		// header is attacker-controlled and must not select the verifier.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return jwtProvider.SigningKey, nil
	})
	if err != nil {
		var validationErr *jwt.ValidationError
		if !errors.As(err, &validationErr) {
			return nil, ErrTokenInvalid
		}
		// Errors is a bitmask; the first matching case decides the
		// sentinel, in the same priority order as before.
		switch {
		case validationErr.Errors&jwt.ValidationErrorMalformed != 0:
			return nil, ErrTokenMalformed
		case validationErr.Errors&jwt.ValidationErrorExpired != 0:
			return nil, ErrTokenExpired
		case validationErr.Errors&jwt.ValidationErrorNotValidYet != 0:
			return nil, ErrTokenNotValidYet
		default:
			return nil, ErrTokenInvalid
		}
	}

	if token == nil || !token.Valid {
		return nil, ErrTokenInvalid
	}
	claims, ok := token.Claims.(*Claims)
	if !ok {
		return nil, ErrTokenInvalid
	}
	return claims, nil
}