From d1a2a808591c3f1b3dc0065d6229fd97b6cec5fd Mon Sep 17 00:00:00 2001 From: Rogee Date: Fri, 29 Nov 2024 20:28:37 +0800 Subject: [PATCH] feat: add features --- backend/common/consts/ctx.gen.go | 179 +++++++++++++++++++++ backend/common/consts/ctx.go | 9 ++ backend/common/consts/jwt.go | 6 - backend/common/service/model/gen.go | 2 + backend/main_test.go | 18 +++ backend/modules/middlewares/middlewares.go | 8 +- backend/modules/users/service.go | 120 ++++++++++++-- backend/pkg/db/db.go | 17 ++ 8 files changed, 337 insertions(+), 22 deletions(-) create mode 100644 backend/common/consts/ctx.gen.go create mode 100644 backend/common/consts/ctx.go delete mode 100644 backend/common/consts/jwt.go create mode 100644 backend/pkg/db/db.go diff --git a/backend/common/consts/ctx.gen.go b/backend/common/consts/ctx.gen.go new file mode 100644 index 0000000..b131d39 --- /dev/null +++ b/backend/common/consts/ctx.gen.go @@ -0,0 +1,179 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: - +// Revision: - +// Build Date: - +// Built By: - + +package consts + +import ( + "database/sql/driver" + "errors" + "fmt" + "strings" +) + +const ( + // CtxKeyTx is a CtxKey of type Tx. + CtxKeyTx CtxKey = "__ctx_db:" + // CtxKeyJwt is a CtxKey of type Jwt. + CtxKeyJwt CtxKey = "__jwt_token:" + // CtxKeySession is a CtxKey of type Session. + CtxKeySession CtxKey = "__session_user:" +) + +var ErrInvalidCtxKey = fmt.Errorf("not a valid CtxKey, try [%s]", strings.Join(_CtxKeyNames, ", ")) + +var _CtxKeyNames = []string{ + string(CtxKeyTx), + string(CtxKeyJwt), + string(CtxKeySession), +} + +// CtxKeyNames returns a list of possible string values of CtxKey. +func CtxKeyNames() []string { + tmp := make([]string, len(_CtxKeyNames)) + copy(tmp, _CtxKeyNames) + return tmp +} + +// CtxKeyValues returns a list of the values for CtxKey +func CtxKeyValues() []CtxKey { + return []CtxKey{ + CtxKeyTx, + CtxKeyJwt, + CtxKeySession, + } +} + +// String implements the Stringer interface. +func (x CtxKey) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x CtxKey) IsValid() bool { + _, err := ParseCtxKey(string(x)) + return err == nil +} + +var _CtxKeyValue = map[string]CtxKey{ + "__ctx_db:": CtxKeyTx, + "__jwt_token:": CtxKeyJwt, + "__session_user:": CtxKeySession, +} + +// ParseCtxKey attempts to convert a string to a CtxKey. +func ParseCtxKey(name string) (CtxKey, error) { + if x, ok := _CtxKeyValue[name]; ok { + return x, nil + } + return CtxKey(""), fmt.Errorf("%s is %w", name, ErrInvalidCtxKey) +} + +var errCtxKeyNilPtr = errors.New("value pointer is nil") // one per type for package clashes + +// Scan implements the Scanner interface. +func (x *CtxKey) Scan(value interface{}) (err error) { + if value == nil { + *x = CtxKey("") + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case string: + *x, err = ParseCtxKey(v) + case []byte: + *x, err = ParseCtxKey(string(v)) + case CtxKey: + *x = v + case *CtxKey: + if v == nil { + return errCtxKeyNilPtr + } + *x = *v + case *string: + if v == nil { + return errCtxKeyNilPtr + } + *x, err = ParseCtxKey(*v) + default: + return errors.New("invalid type for CtxKey") + } + + return +} + +// Value implements the driver Valuer interface. +func (x CtxKey) Value() (driver.Value, error) { + return x.String(), nil +} + +// Set implements the Golang flag.Value interface func. +func (x *CtxKey) Set(val string) error { + v, err := ParseCtxKey(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *CtxKey) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *CtxKey) Type() string { + return "CtxKey" +} + +type NullCtxKey struct { + CtxKey CtxKey + Valid bool +} + +func NewNullCtxKey(val interface{}) (x NullCtxKey) { + err := x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + _ = err // make any errcheck linters happy + return +} + +// Scan implements the Scanner interface. +func (x *NullCtxKey) Scan(value interface{}) (err error) { + if value == nil { + x.CtxKey, x.Valid = CtxKey(""), false + return + } + + err = x.CtxKey.Scan(value) + x.Valid = (err == nil) + return +} + +// Value implements the driver Valuer interface. +func (x NullCtxKey) Value() (driver.Value, error) { + if !x.Valid { + return nil, nil + } + // driver.Value accepts int64 for int values. + return string(x.CtxKey), nil +} + +type NullCtxKeyStr struct { + NullCtxKey +} + +func NewNullCtxKeyStr(val interface{}) (x NullCtxKeyStr) { + x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + return +} + +// Value implements the driver Valuer interface. +func (x NullCtxKeyStr) Value() (driver.Value, error) { + if !x.Valid { + return nil, nil + } + return x.CtxKey.String(), nil +} diff --git a/backend/common/consts/ctx.go b/backend/common/consts/ctx.go new file mode 100644 index 0000000..26cec86 --- /dev/null +++ b/backend/common/consts/ctx.go @@ -0,0 +1,9 @@ +package consts + +// swagger:enum CacheKey +// ENUM( +// Tx = "__ctx_db:", +// Jwt = "__jwt_token:", +// Session = "__session_user:", +// ) +type CtxKey string diff --git a/backend/common/consts/jwt.go b/backend/common/consts/jwt.go deleted file mode 100644 index 0d4cb1a..0000000 --- a/backend/common/consts/jwt.go +++ /dev/null @@ -1,6 +0,0 @@ -package consts - -const ( - JwtToken = "__jwt_token:" - SessionUser = "__session_user:" -) diff --git a/backend/common/service/model/gen.go b/backend/common/service/model/gen.go index 1c95477..0bf449a 100644 --- a/backend/common/service/model/gen.go +++ b/backend/common/service/model/gen.go @@ -13,6 +13,7 @@ import ( "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/template" pg "github.com/go-jet/jet/v2/postgres" + "github.com/gofiber/fiber/v3/log" _ "github.com/lib/pq" "github.com/samber/lo" "github.com/spf13/cobra" @@ -117,6 +118,7 @@ func Serve(cmd *cobra.Command, args []string) error { ImportPath: splits[0], }) + log.Infof("Convert table %s field %s type to : %s", table.Name, column.Name, toType) return defaultTableModelField }) }), diff --git a/backend/main_test.go b/backend/main_test.go index a439a84..929e973 100755 --- a/backend/main_test.go +++ b/backend/main_test.go @@ -4,8 +4,11 @@ import ( "testing" "backend/common/service/model" + "backend/pkg/pg" + "backend/providers/wechat" "git.ipao.vip/rogeecn/atom" + "github.com/jinzhu/copier" ) func Test_GenModel(t *testing.T) { @@ -14,3 +17,18 @@ func Test_GenModel(t *testing.T) { t.Fatal(err) } } + +func Test_GenModel2(t *testing.T) { + token := &wechat.AuthorizeAccessToken{ + AccessToken: "123", + RefreshToken: "123", + ExpiresIn: 123, + Openid: "123", + Scope: "123", + } + + var oauthInfo pg.UserOAuth + copier.Copy(&oauthInfo, token) + + t.Logf("%+v", oauthInfo) +} diff --git a/backend/modules/middlewares/middlewares.go b/backend/modules/middlewares/middlewares.go index 88b2177..17c3487 100644 --- a/backend/modules/middlewares/middlewares.go +++ b/backend/modules/middlewares/middlewares.go @@ -69,8 +69,8 @@ func (f *Middlewares) SilentAuth(c fiber.Ctx) error { return errors.Wrap(err, "failed to get user") } - c.SetUserContext(context.WithValue(c.UserContext(), consts.JwtToken, tokenCookie)) - c.SetUserContext(context.WithValue(c.UserContext(), consts.SessionUser, user)) + c.SetUserContext(context.WithValue(c.UserContext(), consts.CtxKeyJwt, tokenCookie)) + c.SetUserContext(context.WithValue(c.UserContext(), consts.CtxKeySession, user)) return c.Next() } @@ -113,7 +113,7 @@ func (f *Middlewares) AuthUserInfo(c fiber.Ctx) error { var oauthInfo pg.UserOAuth copier.Copy(&oauthInfo, token) - user, err := f.userSvc.GetOrNew(c.Context(), token.Openid, oauthInfo) + user, err := f.userSvc.GetOrNew(c.Context(), 1, token.Openid, oauthInfo) if err != nil { return errors.Wrap(err, "failed to get user") } @@ -127,7 +127,7 @@ func (f *Middlewares) AuthUserInfo(c fiber.Ctx) error { // set the openid to the cookie c.Cookie(&fiber.Cookie{ - Name: "sid", + Name: "token", Value: jwtToken, HTTPOnly: true, }) diff --git a/backend/modules/users/service.go b/backend/modules/users/service.go index a593be3..ca3fa44 100644 --- a/backend/modules/users/service.go +++ b/backend/modules/users/service.go @@ -3,9 +3,12 @@ package users import ( "context" "database/sql" + "time" + "backend/common/consts" "backend/database/models/qvyun/public/model" "backend/database/models/qvyun/public/table" + "backend/pkg/db" "backend/pkg/pg" . "github.com/go-jet/jet/v2/postgres" @@ -42,25 +45,118 @@ func (svc *Service) GetByOpenID(ctx context.Context, openid string) (*model.User } // GetOrNew -func (svc *Service) GetOrNew(ctx context.Context, openid string, authInfo pg.UserOAuth) (*model.Users, error) { +func (svc *Service) GetOrNew(ctx context.Context, tenantID int64, openid string, authInfo pg.UserOAuth) (*model.Users, error) { + log := svc.log.WithField("method", "GetOrNew") + user, err := svc.GetByOpenID(ctx, openid) if err == nil { + // check: if tenant has user return user, nil } - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - // user = &model.Users{ - // OpenID: openid, - // OAuth:,authInfo - // } - // if err := user.Insert(ctx, svc.db, table.Users); err != nil { - // return nil, errors.Wrap(err, "failed to insert user") - // } + if errors.Is(err, sql.ErrNoRows) { + user = &model.Users{ + OpenID: openid, + OAuth: authInfo, + ExpireIn: time.Now().Add(time.Minute * time.Duration(authInfo.ExpiresIn)), } - return nil, errors.Wrap(err, "failed to get user by openid") + tx, err := svc.db.BeginTx(ctx, nil) + if err != nil { + return nil, errors.Wrap(err, "failed to begin transaction") + } + defer tx.Rollback() + ctx = context.WithValue(ctx, consts.CtxKeyTx, tx) + + user, err := svc.CreateFromModel(ctx, user) + if err != nil { + return nil, err + } + + // create user-tenant relation + if err := svc.CreateTenantUser(ctx, user.ID, tenantID); err != nil { + return nil, errors.Wrap(err, "failed to create user-tenant relation") + } + + if err := tx.Commit(); err != nil { + return nil, errors.Wrap(err, "failed to commit transaction") + } + + log.Infof("create new user for tenant: %d success, openID: %s", tenantID, openid) + return user, nil } - return nil, errors.New("unknown error") + return nil, errors.Wrap(err, "failed to get user by openid") +} + +// CreateFromModel create user from model +func (svc *Service) CreateFromModel(ctx context.Context, user *model.Users) (*model.Users, error) { + log := svc.log.WithField("method", "CreateFromModel") + + stmt := table.Users.INSERT().MODEL(user).RETURNING(table.Users.AllColumns) + log.Debug(stmt.DebugSql()) + + // get tx from context + + var item model.Users + if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil { + return nil, errors.Wrap(err, "failed to create user") + } + + return &item, nil +} + +// GetTenantByID +func (svc *Service) GetTenantByID(ctx context.Context, id int64) (*model.Tenants, error) { + log := svc.log.WithField("method", "GetTenantByID") + + stmt := table.Tenants.SELECT(table.Tenants.AllColumns).WHERE(table.Tenants.ID.EQ(Int64(id))) + log.Debug(stmt.DebugSql()) + + var item model.Tenants + if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil { + return nil, errors.Wrapf(err, "failed to query tenant by id %d", id) + } + return &item, nil +} + +// TenantHasUser +func (svc *Service) TenantHasUser(ctx context.Context, userID, tenantID int64) (bool, error) { + log := svc.log.WithField("method", "TenantHasUser") + + tbl := table.UsersTenants + stmt := tbl. + SELECT(COUNT(tbl.ID)). + WHERE( + tbl.UserID.EQ(Int64(userID)).AND( + tbl.TenantID.EQ(Int64(tenantID)), + ), + ) + log.Debug(stmt.DebugSql()) + + var cnt int + if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &cnt); err != nil { + return false, errors.Wrap(err, "failed to query user-tenant relation") + } + + return cnt > 0, nil +} + +// CreateTenantUser +func (svc *Service) CreateTenantUser(ctx context.Context, userID, tenantID int64) error { + log := svc.log.WithField("method", "CreateTenantUser") + + stmt := table.UsersTenants.INSERT( + table.UsersTenants.UserID, + table.UsersTenants.TenantID, + ).VALUES( + Int64(userID), + Int64(tenantID), + ) + log.Debug(stmt.DebugSql()) + + if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil { + return errors.Wrap(err, "failed to create user-tenant relation") + } + return nil } diff --git a/backend/pkg/db/db.go b/backend/pkg/db/db.go new file mode 100644 index 0000000..1565aa2 --- /dev/null +++ b/backend/pkg/db/db.go @@ -0,0 +1,17 @@ +package db + +import ( + "context" + "database/sql" + + "backend/common/consts" + + "github.com/go-jet/jet/v2/qrm" +) + +func FromContext(ctx context.Context, db *sql.DB) qrm.DB { + if tx, ok := ctx.Value(consts.CtxKeyTx).(*sql.Tx); ok { + return tx + } + return db +}