feat: add features
This commit is contained in:
179
backend/common/consts/ctx.gen.go
Normal file
179
backend/common/consts/ctx.gen.go
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
// Code generated by go-enum DO NOT EDIT.
|
||||||
|
// Version: -
|
||||||
|
// Revision: -
|
||||||
|
// Build Date: -
|
||||||
|
// Built By: -
|
||||||
|
|
||||||
|
package consts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CtxKeyTx is a CtxKey of type Tx.
|
||||||
|
CtxKeyTx CtxKey = "__ctx_db:"
|
||||||
|
// CtxKeyJwt is a CtxKey of type Jwt.
|
||||||
|
CtxKeyJwt CtxKey = "__jwt_token:"
|
||||||
|
// CtxKeySession is a CtxKey of type Session.
|
||||||
|
CtxKeySession CtxKey = "__session_user:"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrInvalidCtxKey = fmt.Errorf("not a valid CtxKey, try [%s]", strings.Join(_CtxKeyNames, ", "))
|
||||||
|
|
||||||
|
var _CtxKeyNames = []string{
|
||||||
|
string(CtxKeyTx),
|
||||||
|
string(CtxKeyJwt),
|
||||||
|
string(CtxKeySession),
|
||||||
|
}
|
||||||
|
|
||||||
|
// CtxKeyNames returns a list of possible string values of CtxKey.
|
||||||
|
func CtxKeyNames() []string {
|
||||||
|
tmp := make([]string, len(_CtxKeyNames))
|
||||||
|
copy(tmp, _CtxKeyNames)
|
||||||
|
return tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
// CtxKeyValues returns a list of the values for CtxKey
|
||||||
|
func CtxKeyValues() []CtxKey {
|
||||||
|
return []CtxKey{
|
||||||
|
CtxKeyTx,
|
||||||
|
CtxKeyJwt,
|
||||||
|
CtxKeySession,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the Stringer interface.
|
||||||
|
func (x CtxKey) String() string {
|
||||||
|
return string(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid provides a quick way to determine if the typed value is
|
||||||
|
// part of the allowed enumerated values
|
||||||
|
func (x CtxKey) IsValid() bool {
|
||||||
|
_, err := ParseCtxKey(string(x))
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _CtxKeyValue = map[string]CtxKey{
|
||||||
|
"__ctx_db:": CtxKeyTx,
|
||||||
|
"__jwt_token:": CtxKeyJwt,
|
||||||
|
"__session_user:": CtxKeySession,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCtxKey attempts to convert a string to a CtxKey.
|
||||||
|
func ParseCtxKey(name string) (CtxKey, error) {
|
||||||
|
if x, ok := _CtxKeyValue[name]; ok {
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
return CtxKey(""), fmt.Errorf("%s is %w", name, ErrInvalidCtxKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errCtxKeyNilPtr = errors.New("value pointer is nil") // one per type for package clashes
|
||||||
|
|
||||||
|
// Scan implements the Scanner interface.
|
||||||
|
func (x *CtxKey) Scan(value interface{}) (err error) {
|
||||||
|
if value == nil {
|
||||||
|
*x = CtxKey("")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// A wider range of scannable types.
|
||||||
|
// driver.Value values at the top of the list for expediency
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
*x, err = ParseCtxKey(v)
|
||||||
|
case []byte:
|
||||||
|
*x, err = ParseCtxKey(string(v))
|
||||||
|
case CtxKey:
|
||||||
|
*x = v
|
||||||
|
case *CtxKey:
|
||||||
|
if v == nil {
|
||||||
|
return errCtxKeyNilPtr
|
||||||
|
}
|
||||||
|
*x = *v
|
||||||
|
case *string:
|
||||||
|
if v == nil {
|
||||||
|
return errCtxKeyNilPtr
|
||||||
|
}
|
||||||
|
*x, err = ParseCtxKey(*v)
|
||||||
|
default:
|
||||||
|
return errors.New("invalid type for CtxKey")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (x CtxKey) Value() (driver.Value, error) {
|
||||||
|
return x.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set implements the Golang flag.Value interface func.
|
||||||
|
func (x *CtxKey) Set(val string) error {
|
||||||
|
v, err := ParseCtxKey(val)
|
||||||
|
*x = v
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get implements the Golang flag.Getter interface func.
|
||||||
|
func (x *CtxKey) Get() interface{} {
|
||||||
|
return *x
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type implements the github.com/spf13/pFlag Value interface.
|
||||||
|
func (x *CtxKey) Type() string {
|
||||||
|
return "CtxKey"
|
||||||
|
}
|
||||||
|
|
||||||
|
type NullCtxKey struct {
|
||||||
|
CtxKey CtxKey
|
||||||
|
Valid bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNullCtxKey(val interface{}) (x NullCtxKey) {
|
||||||
|
err := x.Scan(val) // yes, we ignore this error, it will just be an invalid value.
|
||||||
|
_ = err // make any errcheck linters happy
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the Scanner interface.
|
||||||
|
func (x *NullCtxKey) Scan(value interface{}) (err error) {
|
||||||
|
if value == nil {
|
||||||
|
x.CtxKey, x.Valid = CtxKey(""), false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = x.CtxKey.Scan(value)
|
||||||
|
x.Valid = (err == nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (x NullCtxKey) Value() (driver.Value, error) {
|
||||||
|
if !x.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// driver.Value accepts int64 for int values.
|
||||||
|
return string(x.CtxKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type NullCtxKeyStr struct {
|
||||||
|
NullCtxKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNullCtxKeyStr(val interface{}) (x NullCtxKeyStr) {
|
||||||
|
x.Scan(val) // yes, we ignore this error, it will just be an invalid value.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (x NullCtxKeyStr) Value() (driver.Value, error) {
|
||||||
|
if !x.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return x.CtxKey.String(), nil
|
||||||
|
}
|
||||||
9
backend/common/consts/ctx.go
Normal file
9
backend/common/consts/ctx.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package consts
|
||||||
|
|
||||||
|
// swagger:enum CacheKey
|
||||||
|
// ENUM(
|
||||||
|
// Tx = "__ctx_db:",
|
||||||
|
// Jwt = "__jwt_token:",
|
||||||
|
// Session = "__session_user:",
|
||||||
|
// )
|
||||||
|
type CtxKey string
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
package consts
|
|
||||||
|
|
||||||
const (
|
|
||||||
JwtToken = "__jwt_token:"
|
|
||||||
SessionUser = "__session_user:"
|
|
||||||
)
|
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/go-jet/jet/v2/generator/postgres"
|
"github.com/go-jet/jet/v2/generator/postgres"
|
||||||
"github.com/go-jet/jet/v2/generator/template"
|
"github.com/go-jet/jet/v2/generator/template"
|
||||||
pg "github.com/go-jet/jet/v2/postgres"
|
pg "github.com/go-jet/jet/v2/postgres"
|
||||||
|
"github.com/gofiber/fiber/v3/log"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -117,6 +118,7 @@ func Serve(cmd *cobra.Command, args []string) error {
|
|||||||
ImportPath: splits[0],
|
ImportPath: splits[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
log.Infof("Convert table %s field %s type to : %s", table.Name, column.Name, toType)
|
||||||
return defaultTableModelField
|
return defaultTableModelField
|
||||||
})
|
})
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"backend/common/service/model"
|
"backend/common/service/model"
|
||||||
|
"backend/pkg/pg"
|
||||||
|
"backend/providers/wechat"
|
||||||
|
|
||||||
"git.ipao.vip/rogeecn/atom"
|
"git.ipao.vip/rogeecn/atom"
|
||||||
|
"github.com/jinzhu/copier"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_GenModel(t *testing.T) {
|
func Test_GenModel(t *testing.T) {
|
||||||
@@ -14,3 +17,18 @@ func Test_GenModel(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_GenModel2(t *testing.T) {
|
||||||
|
token := &wechat.AuthorizeAccessToken{
|
||||||
|
AccessToken: "123",
|
||||||
|
RefreshToken: "123",
|
||||||
|
ExpiresIn: 123,
|
||||||
|
Openid: "123",
|
||||||
|
Scope: "123",
|
||||||
|
}
|
||||||
|
|
||||||
|
var oauthInfo pg.UserOAuth
|
||||||
|
copier.Copy(&oauthInfo, token)
|
||||||
|
|
||||||
|
t.Logf("%+v", oauthInfo)
|
||||||
|
}
|
||||||
|
|||||||
@@ -69,8 +69,8 @@ func (f *Middlewares) SilentAuth(c fiber.Ctx) error {
|
|||||||
return errors.Wrap(err, "failed to get user")
|
return errors.Wrap(err, "failed to get user")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetUserContext(context.WithValue(c.UserContext(), consts.JwtToken, tokenCookie))
|
c.SetUserContext(context.WithValue(c.UserContext(), consts.CtxKeyJwt, tokenCookie))
|
||||||
c.SetUserContext(context.WithValue(c.UserContext(), consts.SessionUser, user))
|
c.SetUserContext(context.WithValue(c.UserContext(), consts.CtxKeySession, user))
|
||||||
|
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func (f *Middlewares) AuthUserInfo(c fiber.Ctx) error {
|
|||||||
|
|
||||||
var oauthInfo pg.UserOAuth
|
var oauthInfo pg.UserOAuth
|
||||||
copier.Copy(&oauthInfo, token)
|
copier.Copy(&oauthInfo, token)
|
||||||
user, err := f.userSvc.GetOrNew(c.Context(), token.Openid, oauthInfo)
|
user, err := f.userSvc.GetOrNew(c.Context(), 1, token.Openid, oauthInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to get user")
|
return errors.Wrap(err, "failed to get user")
|
||||||
}
|
}
|
||||||
@@ -127,7 +127,7 @@ func (f *Middlewares) AuthUserInfo(c fiber.Ctx) error {
|
|||||||
|
|
||||||
// set the openid to the cookie
|
// set the openid to the cookie
|
||||||
c.Cookie(&fiber.Cookie{
|
c.Cookie(&fiber.Cookie{
|
||||||
Name: "sid",
|
Name: "token",
|
||||||
Value: jwtToken,
|
Value: jwtToken,
|
||||||
HTTPOnly: true,
|
HTTPOnly: true,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -3,9 +3,12 @@ package users
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/common/consts"
|
||||||
"backend/database/models/qvyun/public/model"
|
"backend/database/models/qvyun/public/model"
|
||||||
"backend/database/models/qvyun/public/table"
|
"backend/database/models/qvyun/public/table"
|
||||||
|
"backend/pkg/db"
|
||||||
"backend/pkg/pg"
|
"backend/pkg/pg"
|
||||||
|
|
||||||
. "github.com/go-jet/jet/v2/postgres"
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
@@ -42,25 +45,118 @@ func (svc *Service) GetByOpenID(ctx context.Context, openid string) (*model.User
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetOrNew
|
// GetOrNew
|
||||||
func (svc *Service) GetOrNew(ctx context.Context, openid string, authInfo pg.UserOAuth) (*model.Users, error) {
|
func (svc *Service) GetOrNew(ctx context.Context, tenantID int64, openid string, authInfo pg.UserOAuth) (*model.Users, error) {
|
||||||
|
log := svc.log.WithField("method", "GetOrNew")
|
||||||
|
|
||||||
user, err := svc.GetByOpenID(ctx, openid)
|
user, err := svc.GetByOpenID(ctx, openid)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
// check: if tenant has user
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
user = &model.Users{
|
||||||
// user = &model.Users{
|
OpenID: openid,
|
||||||
// OpenID: openid,
|
OAuth: authInfo,
|
||||||
// OAuth:,authInfo
|
ExpireIn: time.Now().Add(time.Minute * time.Duration(authInfo.ExpiresIn)),
|
||||||
// }
|
|
||||||
// if err := user.Insert(ctx, svc.db, table.Users); err != nil {
|
|
||||||
// return nil, errors.Wrap(err, "failed to insert user")
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "failed to get user by openid")
|
|
||||||
|
|
||||||
|
tx, err := svc.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to begin transaction")
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
ctx = context.WithValue(ctx, consts.CtxKeyTx, tx)
|
||||||
|
|
||||||
|
user, err := svc.CreateFromModel(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// create user-tenant relation
|
||||||
|
if err := svc.CreateTenantUser(ctx, user.ID, tenantID); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to create user-tenant relation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to commit transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("create new user for tenant: %d success, openID: %s", tenantID, openid)
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("unknown error")
|
return nil, errors.Wrap(err, "failed to get user by openid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFromModel create user from model
|
||||||
|
func (svc *Service) CreateFromModel(ctx context.Context, user *model.Users) (*model.Users, error) {
|
||||||
|
log := svc.log.WithField("method", "CreateFromModel")
|
||||||
|
|
||||||
|
stmt := table.Users.INSERT().MODEL(user).RETURNING(table.Users.AllColumns)
|
||||||
|
log.Debug(stmt.DebugSql())
|
||||||
|
|
||||||
|
// get tx from context
|
||||||
|
|
||||||
|
var item model.Users
|
||||||
|
if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to create user")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &item, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTenantByID
|
||||||
|
func (svc *Service) GetTenantByID(ctx context.Context, id int64) (*model.Tenants, error) {
|
||||||
|
log := svc.log.WithField("method", "GetTenantByID")
|
||||||
|
|
||||||
|
stmt := table.Tenants.SELECT(table.Tenants.AllColumns).WHERE(table.Tenants.ID.EQ(Int64(id)))
|
||||||
|
log.Debug(stmt.DebugSql())
|
||||||
|
|
||||||
|
var item model.Tenants
|
||||||
|
if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "failed to query tenant by id %d", id)
|
||||||
|
}
|
||||||
|
return &item, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TenantHasUser
|
||||||
|
func (svc *Service) TenantHasUser(ctx context.Context, userID, tenantID int64) (bool, error) {
|
||||||
|
log := svc.log.WithField("method", "TenantHasUser")
|
||||||
|
|
||||||
|
tbl := table.UsersTenants
|
||||||
|
stmt := tbl.
|
||||||
|
SELECT(COUNT(tbl.ID)).
|
||||||
|
WHERE(
|
||||||
|
tbl.UserID.EQ(Int64(userID)).AND(
|
||||||
|
tbl.TenantID.EQ(Int64(tenantID)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
log.Debug(stmt.DebugSql())
|
||||||
|
|
||||||
|
var cnt int
|
||||||
|
if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &cnt); err != nil {
|
||||||
|
return false, errors.Wrap(err, "failed to query user-tenant relation")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cnt > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTenantUser
|
||||||
|
func (svc *Service) CreateTenantUser(ctx context.Context, userID, tenantID int64) error {
|
||||||
|
log := svc.log.WithField("method", "CreateTenantUser")
|
||||||
|
|
||||||
|
stmt := table.UsersTenants.INSERT(
|
||||||
|
table.UsersTenants.UserID,
|
||||||
|
table.UsersTenants.TenantID,
|
||||||
|
).VALUES(
|
||||||
|
Int64(userID),
|
||||||
|
Int64(tenantID),
|
||||||
|
)
|
||||||
|
log.Debug(stmt.DebugSql())
|
||||||
|
|
||||||
|
if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to create user-tenant relation")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
17
backend/pkg/db/db.go
Normal file
17
backend/pkg/db/db.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"backend/common/consts"
|
||||||
|
|
||||||
|
"github.com/go-jet/jet/v2/qrm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func FromContext(ctx context.Context, db *sql.DB) qrm.DB {
|
||||||
|
if tx, ok := ctx.Value(consts.CtxKeyTx).(*sql.Tx); ok {
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user