add logics
This commit is contained in:
6
backend/common/consts/jwt.go
Normal file
6
backend/common/consts/jwt.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package consts
|
||||
|
||||
const (
|
||||
JwtToken = "__jwt_token:"
|
||||
SessionUser = "__session_user:"
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"backend/modules/middlewares"
|
||||
"backend/modules/users"
|
||||
"backend/providers/app"
|
||||
"backend/providers/http"
|
||||
@@ -31,7 +32,9 @@ func Command() atom.Option {
|
||||
atom.Name("serve"),
|
||||
atom.Short("run http server"),
|
||||
atom.RunE(Serve),
|
||||
atom.Providers(providers),
|
||||
atom.Providers(providers.With(
|
||||
middlewares.Provide,
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -41,6 +44,7 @@ type Http struct {
|
||||
Service *http.Service
|
||||
Initials []contracts.Initial `group:"initials"`
|
||||
Routes []contracts.HttpRoute `group:"routes"`
|
||||
Middlewares *middlewares.Middlewares
|
||||
}
|
||||
|
||||
func Serve(cmd *cobra.Command, args []string) error {
|
||||
@@ -51,6 +55,11 @@ func Serve(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
mid := http.Middlewares
|
||||
http.Service.Engine.Use(mid.Verify)
|
||||
http.Service.Engine.Use(mid.AuthUserInfo)
|
||||
http.Service.Engine.Use(mid.SilentAuth)
|
||||
|
||||
return http.Service.Serve()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ignores: [] # ignore tables
|
||||
types:
|
||||
kube_pods: # table name
|
||||
labels: backend/pkg/pg.JsonMap # column type
|
||||
users: # table name
|
||||
oauth: backend/pkg/pg.UserOAuth
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"backend/pkg/pg"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -15,7 +16,7 @@ type Users struct {
|
||||
ID int64 `sql:"primary_key" json:"id"`
|
||||
OpenID string `json:"open_id"`
|
||||
UnionID *string `json:"union_id"`
|
||||
OAuth *string `json:"oauth"`
|
||||
OAuth pg.UserOAuth `json:"oauth"`
|
||||
ExpireIn time.Time `json:"expire_in"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/gofrs/uuid v4.4.0+incompatible
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1
|
||||
github.com/imroc/req/v3 v3.48.0
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/juju/go4 v0.0.0-20160222163258-40d72ab9641a
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/pkg/errors v0.9.1
|
||||
@@ -18,6 +19,7 @@ require (
|
||||
github.com/smartystreets/goconvey v1.6.4
|
||||
github.com/speps/go-hashids/v2 v2.0.1
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/viper v1.17.0
|
||||
go.uber.org/dig v1.18.0
|
||||
golang.org/x/net v0.31.0
|
||||
golang.org/x/sync v0.9.0
|
||||
@@ -68,7 +70,6 @@ require (
|
||||
github.com/spf13/afero v1.10.0 // indirect
|
||||
github.com/spf13/cast v1.5.1 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/spf13/viper v1.17.0 // indirect
|
||||
github.com/stretchr/testify v1.9.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
|
||||
@@ -228,6 +228,8 @@ github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0f
|
||||
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
||||
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
|
||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||
|
||||
@@ -1,21 +1,31 @@
|
||||
package http
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"backend/common/consts"
|
||||
"backend/modules/users"
|
||||
"backend/pkg/pg"
|
||||
"backend/providers/jwt"
|
||||
"backend/providers/wechat"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// @provider
|
||||
type Middlewares struct {
|
||||
client *wechat.Client
|
||||
userSvc *users.Service
|
||||
jwt *jwt.JWT
|
||||
log *log.Entry `inject:"false"`
|
||||
}
|
||||
|
||||
func Init(client *wechat.Client) *Middlewares {
|
||||
return &Middlewares{
|
||||
client: client,
|
||||
}
|
||||
func (f *Middlewares) Prepare() error {
|
||||
f.log = log.WithField("module", "middleware")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Middlewares) Verify(c fiber.Ctx) error {
|
||||
@@ -29,13 +39,12 @@ func (f *Middlewares) Verify(c fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
log.Infof(
|
||||
"begin verify signature, signature: %s, timestamp: %s, nonce: %s, echostr: %s",
|
||||
signature,
|
||||
timestamp,
|
||||
nonce,
|
||||
echostr,
|
||||
)
|
||||
log.WithField("method", "Verify").WithFields(log.Fields{
|
||||
"signature": signature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
"echostr": echostr,
|
||||
}).Debug("begin verify signature")
|
||||
|
||||
// verify the signature
|
||||
if err := f.client.Verify(signature, timestamp, nonce); err != nil {
|
||||
@@ -47,9 +56,22 @@ func (f *Middlewares) Verify(c fiber.Ctx) error {
|
||||
|
||||
func (f *Middlewares) SilentAuth(c fiber.Ctx) error {
|
||||
// if cookie not exists key "openid", then redirect to the wechat auth page
|
||||
sid := c.Cookies("sid", "")
|
||||
if sid != "" {
|
||||
// TODO: verify sid
|
||||
tokenCookie := c.Cookies("token", "")
|
||||
if tokenCookie != "" {
|
||||
claim, err := f.jwt.Parse(tokenCookie)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to parse token")
|
||||
}
|
||||
|
||||
// query user
|
||||
user, err := f.userSvc.GetByOpenID(c.Context(), claim.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
|
||||
c.SetUserContext(context.WithValue(c.UserContext(), consts.JwtToken, tokenCookie))
|
||||
c.SetUserContext(context.WithValue(c.UserContext(), consts.SessionUser, user))
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
@@ -88,12 +110,25 @@ func (f *Middlewares) AuthUserInfo(c fiber.Ctx) error {
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get openid")
|
||||
}
|
||||
// TODO: store the openid to the session
|
||||
|
||||
var oauthInfo pg.UserOAuth
|
||||
copier.Copy(&oauthInfo, token)
|
||||
user, err := f.userSvc.GetOrNew(c.Context(), token.Openid, oauthInfo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
|
||||
claim := f.jwt.CreateClaims(jwt.BaseClaims{UID: uint64(user.ID)})
|
||||
claim.ID = user.OpenID
|
||||
jwtToken, err := f.jwt.CreateToken(claim)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create token")
|
||||
}
|
||||
|
||||
// set the openid to the cookie
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "sid",
|
||||
Value: token.Openid,
|
||||
Value: jwtToken,
|
||||
HTTPOnly: true,
|
||||
})
|
||||
|
||||
26
backend/modules/middlewares/provider.gen.go
Executable file
26
backend/modules/middlewares/provider.gen.go
Executable file
@@ -0,0 +1,26 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"backend/providers/wechat"
|
||||
|
||||
"git.ipao.vip/rogeecn/atom/container"
|
||||
"git.ipao.vip/rogeecn/atom/utils/opt"
|
||||
)
|
||||
|
||||
func Provide(opts ...opt.Option) error {
|
||||
if err := container.Container.Provide(func(
|
||||
client *wechat.Client,
|
||||
) (*Middlewares, error) {
|
||||
obj := &Middlewares{
|
||||
client: client,
|
||||
}
|
||||
if err := obj.Prepare(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return obj, nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -9,10 +9,5 @@ type Controller struct {
|
||||
|
||||
// List
|
||||
func (c *Controller) List(ctx fiber.Ctx) error {
|
||||
resp, err := c.svc.List(ctx.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ctx.JSON(resp)
|
||||
return ctx.JSON(nil)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
|
||||
"backend/database/models/qvyun/public/model"
|
||||
"backend/database/models/qvyun/public/table"
|
||||
"backend/pkg/pg"
|
||||
|
||||
. "github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -22,15 +24,43 @@ func (svc *Service) Prepare() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// List
|
||||
func (svc *Service) List(ctx context.Context) ([]model.Users, error) {
|
||||
// GetByOpenID
|
||||
func (svc *Service) GetByOpenID(ctx context.Context, openid string) (*model.Users, error) {
|
||||
tbl := table.Users
|
||||
stmt := tbl.SELECT(tbl.AllColumns)
|
||||
svc.log.WithField("method", "List").Debug(stmt.DebugSql())
|
||||
stmt := tbl.
|
||||
SELECT(tbl.AllColumns).
|
||||
WHERE(
|
||||
tbl.OpenID.EQ(String(openid)),
|
||||
)
|
||||
svc.log.WithField("method", "GetByOpenID").Debug(stmt.DebugSql())
|
||||
|
||||
var items []model.Users
|
||||
if err := stmt.QueryContext(ctx, svc.db, &items); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to query users")
|
||||
var item model.Users
|
||||
if err := stmt.QueryContext(ctx, svc.db, &item); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to query user by openid")
|
||||
}
|
||||
return items, nil
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
// GetOrNew
|
||||
func (svc *Service) GetOrNew(ctx context.Context, openid string, authInfo pg.UserOAuth) (*model.Users, error) {
|
||||
user, err := svc.GetByOpenID(ctx, openid)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// user = &model.Users{
|
||||
// OpenID: openid,
|
||||
// OAuth:,authInfo
|
||||
// }
|
||||
// if err := user.Insert(ctx, svc.db, table.Users); err != nil {
|
||||
// return nil, errors.Wrap(err, "failed to insert user")
|
||||
// }
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to get user by openid")
|
||||
|
||||
}
|
||||
|
||||
return nil, errors.New("unknown error")
|
||||
}
|
||||
|
||||
11
backend/pkg/pg/users.go
Normal file
11
backend/pkg/pg/users.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package pg
|
||||
|
||||
type UserOAuth struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IsSnapshotuser int64 `json:"is_snapshotuser"`
|
||||
Openid string `json:"openid"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
Unionid string `json:"unionid"`
|
||||
}
|
||||
@@ -51,6 +51,7 @@ func Provide(opts ...opt.Option) error {
|
||||
}
|
||||
return container.Container.Provide(func() (*JWT, error) {
|
||||
return &JWT{
|
||||
singleflight: &singleflight.Group{},
|
||||
config: &config,
|
||||
SigningKey: []byte(config.SigningKey),
|
||||
}, nil
|
||||
@@ -85,7 +86,7 @@ func (j *JWT) CreateTokenByOldToken(oldToken string, claims *Claims) (string, er
|
||||
}
|
||||
|
||||
// 解析 token
|
||||
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
|
||||
func (j *JWT) Parse(tokenString string) (*Claims, error) {
|
||||
tokenString = strings.TrimPrefix(tokenString, TokenPrefix)
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (i interface{}, e error) {
|
||||
return j.SigningKey, nil
|
||||
|
||||
Reference in New Issue
Block a user