diff --git a/backend/common/consts/jwt.go b/backend/common/consts/jwt.go new file mode 100644 index 0000000..0d4cb1a --- /dev/null +++ b/backend/common/consts/jwt.go @@ -0,0 +1,6 @@ +package consts + +const ( + JwtToken = "__jwt_token:" + SessionUser = "__session_user:" +) diff --git a/backend/common/service/http/http.go b/backend/common/service/http/http.go index 8e14e81..48d83ee 100644 --- a/backend/common/service/http/http.go +++ b/backend/common/service/http/http.go @@ -1,6 +1,7 @@ package http import ( + "backend/modules/middlewares" "backend/modules/users" "backend/providers/app" "backend/providers/http" @@ -31,16 +32,19 @@ func Command() atom.Option { atom.Name("serve"), atom.Short("run http server"), atom.RunE(Serve), - atom.Providers(providers), + atom.Providers(providers.With( + middlewares.Provide, + )), ) } type Http struct { dig.In - Service *http.Service - Initials []contracts.Initial `group:"initials"` - Routes []contracts.HttpRoute `group:"routes"` + Service *http.Service + Initials []contracts.Initial `group:"initials"` + Routes []contracts.HttpRoute `group:"routes"` + Middlewares *middlewares.Middlewares } func Serve(cmd *cobra.Command, args []string) error { @@ -51,6 +55,11 @@ func Serve(cmd *cobra.Command, args []string) error { } } + mid := http.Middlewares + http.Service.Engine.Use(mid.Verify) + http.Service.Engine.Use(mid.AuthUserInfo) + http.Service.Engine.Use(mid.SilentAuth) + return http.Service.Serve() }) } diff --git a/backend/database/.transform.yaml b/backend/database/.transform.yaml index 6ec2842..f16420c 100755 --- a/backend/database/.transform.yaml +++ b/backend/database/.transform.yaml @@ -1,4 +1,4 @@ ignores: [] # ignore tables types: - kube_pods: # table name - labels: backend/pkg/pg.JsonMap # column type + users: # table name + oauth: backend/pkg/pg.UserOAuth diff --git a/backend/database/models/qvyun/public/model/users.go b/backend/database/models/qvyun/public/model/users.go index 81ade60..66f2f5c 100644 --- a/backend/database/models/qvyun/public/model/users.go +++ b/backend/database/models/qvyun/public/model/users.go @@ -8,15 +8,16 @@ package model import ( + "backend/pkg/pg" "time" ) type Users struct { - ID int64 `sql:"primary_key" json:"id"` - OpenID string `json:"open_id"` - UnionID *string `json:"union_id"` - OAuth *string `json:"oauth"` - ExpireIn time.Time `json:"expire_in"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `sql:"primary_key" json:"id"` + OpenID string `json:"open_id"` + UnionID *string `json:"union_id"` + OAuth pg.UserOAuth `json:"oauth"` + ExpireIn time.Time `json:"expire_in"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } diff --git a/backend/go.mod b/backend/go.mod index d3e0791..3682028 100755 --- a/backend/go.mod +++ b/backend/go.mod @@ -9,6 +9,7 @@ require ( github.com/gofrs/uuid v4.4.0+incompatible github.com/golang-jwt/jwt/v4 v4.5.1 github.com/imroc/req/v3 v3.48.0 + github.com/jinzhu/copier v0.4.0 github.com/juju/go4 v0.0.0-20160222163258-40d72ab9641a github.com/lib/pq v1.10.9 github.com/pkg/errors v0.9.1 @@ -18,6 +19,7 @@ require ( github.com/smartystreets/goconvey v1.6.4 github.com/speps/go-hashids/v2 v2.0.1 github.com/spf13/cobra v1.8.1 + github.com/spf13/viper v1.17.0 go.uber.org/dig v1.18.0 golang.org/x/net v0.31.0 golang.org/x/sync v0.9.0 @@ -68,7 +70,6 @@ require ( github.com/spf13/afero v1.10.0 // indirect github.com/spf13/cast v1.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/spf13/viper v1.17.0 // indirect github.com/stretchr/testify v1.9.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index a1fa37c..59d1e39 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -228,6 +228,8 @@ github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= diff --git a/backend/common/service/http/middlewares.go b/backend/modules/middlewares/middlewares.go similarity index 55% rename from backend/common/service/http/middlewares.go rename to backend/modules/middlewares/middlewares.go index 39a5f0d..88b2177 100644 --- a/backend/common/service/http/middlewares.go +++ b/backend/modules/middlewares/middlewares.go @@ -1,21 +1,31 @@ -package http +package middlewares import ( + "context" + + "backend/common/consts" + "backend/modules/users" + "backend/pkg/pg" + "backend/providers/jwt" "backend/providers/wechat" "github.com/gofiber/fiber/v3" + "github.com/jinzhu/copier" "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) +// @provider type Middlewares struct { - client *wechat.Client + client *wechat.Client + userSvc *users.Service + jwt *jwt.JWT + log *log.Entry `inject:"false"` } -func Init(client *wechat.Client) *Middlewares { - return &Middlewares{ - client: client, - } +func (f *Middlewares) Prepare() error { + f.log = log.WithField("module", "middleware") + return nil } func (f *Middlewares) Verify(c fiber.Ctx) error { @@ -29,13 +39,12 @@ func (f *Middlewares) Verify(c fiber.Ctx) error { return c.Next() } - log.Infof( - "begin verify signature, signature: %s, timestamp: %s, nonce: %s, echostr: %s", - signature, - timestamp, - nonce, - echostr, - ) + log.WithField("method", "Verify").WithFields(log.Fields{ + "signature": signature, + "timestamp": timestamp, + "nonce": nonce, + "echostr": echostr, + }).Debug("begin verify signature") // verify the signature if err := f.client.Verify(signature, timestamp, nonce); err != nil { @@ -47,9 +56,22 @@ func (f *Middlewares) Verify(c fiber.Ctx) error { func (f *Middlewares) SilentAuth(c fiber.Ctx) error { // if cookie not exists key "openid", then redirect to the wechat auth page - sid := c.Cookies("sid", "") - if sid != "" { - // TODO: verify sid + tokenCookie := c.Cookies("token", "") + if tokenCookie != "" { + claim, err := f.jwt.Parse(tokenCookie) + if err != nil { + return errors.Wrap(err, "failed to parse token") + } + + // query user + user, err := f.userSvc.GetByOpenID(c.Context(), claim.ID) + if err != nil { + return errors.Wrap(err, "failed to get user") + } + + c.SetUserContext(context.WithValue(c.UserContext(), consts.JwtToken, tokenCookie)) + c.SetUserContext(context.WithValue(c.UserContext(), consts.SessionUser, user)) + return c.Next() } @@ -88,12 +110,25 @@ func (f *Middlewares) AuthUserInfo(c fiber.Ctx) error { if err != nil { return errors.Wrap(err, "failed to get openid") } - // TODO: store the openid to the session + + var oauthInfo pg.UserOAuth + copier.Copy(&oauthInfo, token) + user, err := f.userSvc.GetOrNew(c.Context(), token.Openid, oauthInfo) + if err != nil { + return errors.Wrap(err, "failed to get user") + } + + claim := f.jwt.CreateClaims(jwt.BaseClaims{UID: uint64(user.ID)}) + claim.ID = user.OpenID + jwtToken, err := f.jwt.CreateToken(claim) + if err != nil { + return errors.Wrap(err, "failed to create token") + } // set the openid to the cookie c.Cookie(&fiber.Cookie{ Name: "sid", - Value: token.Openid, + Value: jwtToken, HTTPOnly: true, }) diff --git a/backend/modules/middlewares/provider.gen.go b/backend/modules/middlewares/provider.gen.go new file mode 100755 index 0000000..6fd349f --- /dev/null +++ b/backend/modules/middlewares/provider.gen.go @@ -0,0 +1,26 @@ +package middlewares + +import ( + "backend/providers/wechat" + + "git.ipao.vip/rogeecn/atom/container" + "git.ipao.vip/rogeecn/atom/utils/opt" +) + +func Provide(opts ...opt.Option) error { + if err := container.Container.Provide(func( + client *wechat.Client, + ) (*Middlewares, error) { + obj := &Middlewares{ + client: client, + } + if err := obj.Prepare(); err != nil { + return nil, err + } + return obj, nil + }); err != nil { + return err + } + + return nil +} diff --git a/backend/modules/users/controller.go b/backend/modules/users/controller.go index b59d0af..3dc85d5 100644 --- a/backend/modules/users/controller.go +++ b/backend/modules/users/controller.go @@ -9,10 +9,5 @@ type Controller struct { // List func (c *Controller) List(ctx fiber.Ctx) error { - resp, err := c.svc.List(ctx.Context()) - if err != nil { - return err - } - - return ctx.JSON(resp) + return ctx.JSON(nil) } diff --git a/backend/modules/users/service.go b/backend/modules/users/service.go index 4c4f32f..a593be3 100644 --- a/backend/modules/users/service.go +++ b/backend/modules/users/service.go @@ -6,7 +6,9 @@ import ( "backend/database/models/qvyun/public/model" "backend/database/models/qvyun/public/table" + "backend/pkg/pg" + . "github.com/go-jet/jet/v2/postgres" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -22,15 +24,43 @@ func (svc *Service) Prepare() error { return nil } -// List -func (svc *Service) List(ctx context.Context) ([]model.Users, error) { +// GetByOpenID +func (svc *Service) GetByOpenID(ctx context.Context, openid string) (*model.Users, error) { tbl := table.Users - stmt := tbl.SELECT(tbl.AllColumns) - svc.log.WithField("method", "List").Debug(stmt.DebugSql()) + stmt := tbl. + SELECT(tbl.AllColumns). + WHERE( + tbl.OpenID.EQ(String(openid)), + ) + svc.log.WithField("method", "GetByOpenID").Debug(stmt.DebugSql()) - var items []model.Users - if err := stmt.QueryContext(ctx, svc.db, &items); err != nil { - return nil, errors.Wrap(err, "failed to query users") + var item model.Users + if err := stmt.QueryContext(ctx, svc.db, &item); err != nil { + return nil, errors.Wrap(err, "failed to query user by openid") } - return items, nil + return &item, nil +} + +// GetOrNew +func (svc *Service) GetOrNew(ctx context.Context, openid string, authInfo pg.UserOAuth) (*model.Users, error) { + user, err := svc.GetByOpenID(ctx, openid) + if err == nil { + return user, nil + } + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + // user = &model.Users{ + // OpenID: openid, + // OAuth:,authInfo + // } + // if err := user.Insert(ctx, svc.db, table.Users); err != nil { + // return nil, errors.Wrap(err, "failed to insert user") + // } + } + return nil, errors.Wrap(err, "failed to get user by openid") + + } + + return nil, errors.New("unknown error") } diff --git a/backend/pkg/pg/users.go b/backend/pkg/pg/users.go new file mode 100644 index 0000000..69aa98a --- /dev/null +++ b/backend/pkg/pg/users.go @@ -0,0 +1,11 @@ +package pg + +type UserOAuth struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + IsSnapshotuser int64 `json:"is_snapshotuser"` + Openid string `json:"openid"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + Unionid string `json:"unionid"` +} diff --git a/backend/providers/jwt/jwt.go b/backend/providers/jwt/jwt.go index 0b425d7..201ce3e 100644 --- a/backend/providers/jwt/jwt.go +++ b/backend/providers/jwt/jwt.go @@ -51,8 +51,9 @@ func Provide(opts ...opt.Option) error { } return container.Container.Provide(func() (*JWT, error) { return &JWT{ - config: &config, - SigningKey: []byte(config.SigningKey), + singleflight: &singleflight.Group{}, + config: &config, + SigningKey: []byte(config.SigningKey), }, nil }, o.DiOptions()...) } @@ -85,7 +86,7 @@ func (j *JWT) CreateTokenByOldToken(oldToken string, claims *Claims) (string, er } // 解析 token -func (j *JWT) ParseToken(tokenString string) (*Claims, error) { +func (j *JWT) Parse(tokenString string) (*Claims, error) { tokenString = strings.TrimPrefix(tokenString, TokenPrefix) token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (i interface{}, e error) { return j.SigningKey, nil