382 lines
10 KiB
Go
382 lines
10 KiB
Go
package users
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"time"
|
|
|
|
"backend/database/models/qvyun/public/model"
|
|
"backend/database/models/qvyun/public/table"
|
|
"backend/pkg/consts"
|
|
"backend/pkg/db"
|
|
"backend/pkg/errorx"
|
|
"backend/pkg/pg"
|
|
"backend/providers/jwt"
|
|
|
|
. "github.com/go-jet/jet/v2/postgres"
|
|
"github.com/go-jet/jet/v2/qrm"
|
|
"github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
hashids "github.com/speps/go-hashids/v2"
|
|
)
|
|
|
|
// @provider:except
|
|
type Service struct {
|
|
db *sql.DB
|
|
hashIds *hashids.HashID
|
|
log *logrus.Entry `inject:"false"`
|
|
}
|
|
|
|
func (svc *Service) Prepare() error {
|
|
svc.log = logrus.WithField("module", "users.service")
|
|
return nil
|
|
}
|
|
|
|
// GetByID returns the user whose primary key equals id.
//
// The executor is chosen by db.FromContext(ctx, svc.db), like the other
// methods of this service. As a result, the lookup presumably joins a
// transaction stored in ctx instead of always bypassing it.
// When no row matches, the returned error wraps qrm.ErrNoRows.
func (svc *Service) GetByID(ctx context.Context, id int64) (*model.Users, error) {
	log := svc.log.WithField("method", "GetByID")

	tbl := table.Users
	stmt := tbl.
		SELECT(tbl.AllColumns).
		WHERE(tbl.ID.EQ(Int64(id)))
	log.Debug(stmt.DebugSql())

	var item model.Users
	if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil {
		return nil, errors.Wrap(err, "failed to query user by id")
	}
	return &item, nil
}
|
|
|
|
// GetByOpenID returns the user whose open_id equals openid.
//
// The executor is chosen by db.FromContext(ctx, svc.db), consistent with the
// rest of the service. When no row matches, the returned error wraps
// qrm.ErrNoRows; GetOrNew relies on this to decide whether to create the user.
func (svc *Service) GetByOpenID(ctx context.Context, openid string) (*model.Users, error) {
	log := svc.log.WithField("method", "GetByOpenID")

	tbl := table.Users
	stmt := tbl.
		SELECT(tbl.AllColumns).
		WHERE(tbl.OpenID.EQ(String(openid)))
	log.Debug(stmt.DebugSql())

	var item model.Users
	if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil {
		return nil, errors.Wrap(err, "failed to query user by openid")
	}
	return &item, nil
}
|
|
|
|
// GetOrNew returns the user identified by openid and makes sure the user is
// linked to tenantID.
//
//   - If a user with this openid exists, a users_tenants relation to tenantID
//     is created when it is missing.
//   - If no such user exists (qrm.ErrNoRows), a new user is built from
//     authInfo. It is inserted together with its tenant relation inside one
//     transaction.
//
// An empty openid is rejected. Any other lookup error is returned wrapped.
func (svc *Service) GetOrNew(ctx context.Context, tenantID int64, openid string, authInfo pg.UserOAuth) (*model.Users, error) {
	log := svc.log.WithField("method", "GetOrNew")
	log.Infof("get or new user for tenant: %d, openid: %s", tenantID, openid)

	if openid == "" {
		return nil, errors.New("openid is empty")
	}

	user, err := svc.GetByOpenID(ctx, openid)
	switch {
	case err == nil:
		hasUser, err := svc.TenantHasUser(ctx, user.ID, tenantID)
		if err != nil {
			return nil, errors.Wrap(err, "failed to check user-tenant relation")
		}
		if !hasUser {
			if err := svc.CreateTenantUser(ctx, user.ID, tenantID); err != nil {
				return nil, errors.Wrap(err, "failed to create user-tenant relation")
			}
		}
		return user, nil

	case !errors.Is(err, qrm.ErrNoRows):
		return nil, errors.Wrap(err, "failed to get user by openid")
	}

	// No user with this openid yet. Create it and its tenant relation atomically.
	// A single timestamp keeps CreatedAt and UpdatedAt identical.
	now := time.Now()
	newUser := &model.Users{
		OpenID: openid,
		OAuth:  authInfo,
		// NOTE(review): ExpiresIn is interpreted as minutes here. OAuth
		// providers commonly report expires_in in seconds, so confirm the
		// unit of pg.UserOAuth.ExpiresIn.
		ExpireIn:  now.Add(time.Minute * time.Duration(authInfo.ExpiresIn)),
		CreatedAt: now,
		UpdatedAt: now,
	}

	tx, err := svc.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, errors.Wrap(err, "failed to begin transaction")
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone. Deferring
	// it unconditionally therefore undoes work only on the early-return paths.
	defer tx.Rollback()

	// Both helpers resolve their executor via db.FromContext. Storing tx under
	// consts.CtxKeyTx makes them run inside this transaction.
	txCtx := context.WithValue(ctx, consts.CtxKeyTx, tx)

	created, err := svc.CreateFromModel(txCtx, newUser)
	if err != nil {
		return nil, err
	}
	if err := svc.CreateTenantUser(txCtx, created.ID, tenantID); err != nil {
		return nil, errors.Wrap(err, "failed to create user-tenant relation")
	}
	if err := tx.Commit(); err != nil {
		return nil, errors.Wrap(err, "failed to commit transaction")
	}

	log.Infof("create new user for tenant: %d success, openID: %s", tenantID, openid)
	return created, nil
}
|
|
|
|
// CreateFromModel create user from model
|
|
func (svc *Service) CreateFromModel(ctx context.Context, user *model.Users) (*model.Users, error) {
|
|
log := svc.log.WithField("method", "CreateFromModel")
|
|
|
|
tbl := table.Users
|
|
stmt := tbl.INSERT(tbl.AllColumns.Except(tbl.ID)).MODEL(user).RETURNING(tbl.AllColumns)
|
|
log.Debug(stmt.DebugSql())
|
|
|
|
var userModel model.Users
|
|
err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &userModel)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to create user")
|
|
}
|
|
return &userModel, nil
|
|
}
|
|
|
|
// GetTenantByID
|
|
func (svc *Service) GetTenantByID(ctx context.Context, id int64) (*model.Tenants, error) {
|
|
log := svc.log.WithField("method", "GetTenantByID")
|
|
|
|
stmt := table.Tenants.SELECT(table.Tenants.AllColumns).WHERE(table.Tenants.ID.EQ(Int64(id)))
|
|
log.Debug(stmt.DebugSql())
|
|
|
|
var item model.Tenants
|
|
if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil {
|
|
return nil, errors.Wrapf(err, "failed to query tenant by id %d", id)
|
|
}
|
|
return &item, nil
|
|
}
|
|
|
|
// TenantHasUser reports whether at least one users_tenants row links userID
// to tenantID. The executor is db.FromContext(ctx, svc.db).
func (svc *Service) TenantHasUser(ctx context.Context, userID, tenantID int64) (bool, error) {
	log := svc.log.WithField("method", "TenantHasUser")

	tbl := table.UsersTenants
	stmt := tbl.
		SELECT(COUNT(tbl.ID).AS("cnt")).
		WHERE(
			tbl.UserID.EQ(Int64(userID)).AND(
				tbl.TenantID.EQ(Int64(tenantID)),
			),
		)
	log.Debug(stmt.DebugSql())

	// The destination field must be exported. The result mapper fills it via
	// reflection and cannot set unexported fields, so a field named `cnt`
	// would stay 0 and this method would always report false. "Cnt" is
	// matched to the "cnt" alias.
	var result struct {
		Cnt int64
	}
	if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &result); err != nil {
		return false, errors.Wrap(err, "failed to query user-tenant relation")
	}
	return result.Cnt > 0, nil
}
|
|
|
|
// CreateTenantUser links userID to tenantID in users_tenants.
//
// A conflicting (UserID, TenantID) row is left untouched (ON CONFLICT ...
// DO NOTHING), so calling this for an existing link is not an error. The
// executor is db.FromContext(ctx, svc.db), which lets GetOrNew run it inside
// its transaction.
func (svc *Service) CreateTenantUser(ctx context.Context, userID, tenantID int64) error {
	ut := table.UsersTenants
	insert := ut.
		INSERT(ut.UserID, ut.TenantID).
		VALUES(Int64(userID), Int64(tenantID)).
		ON_CONFLICT(ut.UserID, ut.TenantID).
		DO_NOTHING()

	svc.log.WithField("method", "CreateTenantUser").Debug(insert.DebugSql())

	_, err := insert.ExecContext(ctx, db.FromContext(ctx, svc.db))
	if err != nil {
		return errors.Wrap(err, "failed to create user-tenant relation")
	}
	return nil
}
|
|
|
|
// GetTenantBySlug loads the full tenant record whose slug equals slug.
//
// The executor is db.FromContext(ctx, svc.db). On failure the wrapped error
// names the slug that was looked up. When nothing matches, that error wraps
// qrm.ErrNoRows.
func (svc *Service) GetTenantBySlug(ctx context.Context, slug string) (*model.Tenants, error) {
	// The method field matches the real method name, so log filtering by
	// method works.
	log := svc.log.WithField("method", "GetTenantBySlug")

	stmt := table.Tenants.
		SELECT(table.Tenants.AllColumns).
		WHERE(table.Tenants.Slug.EQ(String(slug)))
	log.Debug(stmt.DebugSql())

	var item model.Tenants
	if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil {
		return nil, errors.Wrapf(err, "failed to query tenant by slug %s", slug)
	}
	return &item, nil
}
|
|
|
|
// CreateTenant inserts a tenant with the given name and slug.
//
// The tenant's ExpireAt is midnight, in local time, of the day that lies
// 366×24h from now. If a tenant with the same slug already exists, nothing is
// changed and no error is returned (ON CONFLICT (slug) DO NOTHING). The
// executor is db.FromContext(ctx, svc.db), consistent with the rest of the
// service.
func (svc *Service) CreateTenant(ctx context.Context, name, slug string) error {
	log := svc.log.WithField("method", "CreateTenant")

	expireAt := time.Now().Add(time.Hour * 24 * 366)
	// Keep only the date part: drop the time of day so expiry falls on a day
	// boundary.
	expireAt = time.Date(expireAt.Year(), expireAt.Month(), expireAt.Day(), 0, 0, 0, 0, expireAt.Location())

	tbl := table.Tenants
	stmt := tbl.
		INSERT(tbl.Name, tbl.Slug, tbl.ExpireAt).
		VALUES(String(name), String(slug), TimestampT(expireAt)).
		ON_CONFLICT(tbl.Slug).
		DO_NOTHING()
	log.Debug(stmt.DebugSql())

	if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil {
		return errors.Wrapf(err, "create tenant: %s(%s)", name, slug)
	}
	return nil
}
|
|
|
|
// SetTenantExpireAtBySlug sets ExpireAt of the tenant identified by slug to
// expire. The executor is db.FromContext(ctx, svc.db), consistent with the
// rest of the service.
//
// NOTE(review): the number of affected rows is not checked, so an unknown
// slug is accepted silently.
func (svc *Service) SetTenantExpireAtBySlug(ctx context.Context, slug string, expire time.Time) error {
	log := svc.log.WithField("method", "SetTenantExpireAtBySlug")

	tbl := table.Tenants
	stmt := tbl.
		UPDATE(tbl.ExpireAt).
		SET(TimestampT(expire)).
		WHERE(tbl.Slug.EQ(String(slug)))
	log.Debug(stmt.DebugSql())

	if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil {
		return errors.Wrapf(err, "renew tenant: %s expire at %s", slug, expire)
	}
	return nil
}
|
|
|
|
// GenerateChargeCode returns an opaque charge code worth chargeAmount for
// tenantID.
//
// The code is the hashid encoding of [tenantID, chargeAmount, current Unix
// time in seconds]. Charge decodes these same three values. ctx is not used.
func (svc *Service) GenerateChargeCode(ctx context.Context, tenantID, chargeAmount int64) (string, error) {
	payload := []int64{tenantID, chargeAmount, time.Now().Unix()}

	code, err := svc.hashIds.EncodeInt64(payload)
	if err != nil {
		return "", errors.Wrap(err, "failed to encode charge code")
	}

	svc.log.WithField("method", "GenerateChargeCode").Infof("generate charge code: %s", code)
	return code, nil
}
|
|
|
|
// Charge redeems a code produced by GenerateChargeCode for the caller in claim.
// It adds the encoded amount to the caller's balance in users_tenants and
// records the redemption in user_balance_histories.
//
// errorx.InvalidChargeCode is returned when any of these hold:
//   - the code cannot be decoded, or does not hold exactly three values;
//   - it was issued for a tenant other than claim.TenantID;
//   - its amount is not positive;
//   - it has already been redeemed;
//   - the caller is not a member of the tenant.
//
// The "already redeemed" check, the balance update and the history insert all
// run inside one SERIALIZABLE transaction. Two concurrent redemptions of the
// same code therefore cannot both commit; PostgreSQL aborts one with a
// serialization failure, which is returned wrapped.
func (svc *Service) Charge(ctx context.Context, claim *jwt.Claims, code string) error {
	log := svc.log.WithField("method", "Charge")

	raw, err := svc.hashIds.DecodeInt64WithError(code)
	if err != nil || len(raw) != 3 {
		return errorx.InvalidChargeCode
	}

	tenantID, chargeAmount, timestamp := raw[0], raw[1], raw[2]
	if tenantID != claim.TenantID {
		return errorx.InvalidChargeCode
	}

	// NOTE(review): the generation time is only logged. Codes never expire.
	generatedAt := time.Unix(timestamp, 0)
	log.Infof("charge code %s generated at: %s", code, generatedAt)

	if chargeAmount <= 0 {
		return errorx.InvalidChargeCode
	}

	// SERIALIZABLE turns the read-then-insert on user_balance_histories.code
	// into a safe "redeem once" operation. NOTE(review): a UNIQUE index on
	// user_balance_histories.code would enforce this at the schema level too.
	tx, err := svc.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return errors.Wrap(err, "failed to begin transaction")
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone.
	defer tx.Rollback()

	// Every statement below resolves its executor via db.FromContext. Without
	// storing tx in ctx they would run on the pool, outside the transaction,
	// and a failed history insert would leave the balance already credited.
	ctx = context.WithValue(ctx, consts.CtxKeyTx, tx)

	// Reject codes that have already been redeemed.
	t := table.UserBalanceHistories
	st := t.SELECT(COUNT(t.ID).AS("cnt")).WHERE(t.Code.EQ(String(code)))
	log.Debug(st.DebugSql())

	var result struct {
		Cnt int64
	}
	if err := st.QueryContext(ctx, db.FromContext(ctx, svc.db), &result); err != nil {
		return errors.Wrap(err, "failed to query charge code")
	}
	if result.Cnt > 0 {
		return errorx.InvalidChargeCode
	}

	has, err := svc.TenantHasUser(ctx, claim.UserID, tenantID)
	if err != nil {
		return errors.Wrap(err, "failed to check user-tenant relation")
	}
	if !has {
		return errorx.InvalidChargeCode
	}

	log.Infof("charge tenant: %d, user: %d, amount: %d", claim.TenantID, claim.UserID, chargeAmount)

	// Increase the user's balance in users_tenants.
	tbl := table.UsersTenants
	stmt := tbl.
		UPDATE().
		SET(
			tbl.Balance.SET(
				tbl.Balance.ADD(Int64(chargeAmount)),
			),
		).
		WHERE(
			tbl.UserID.EQ(Int64(claim.UserID)).AND(
				tbl.TenantID.EQ(Int64(claim.TenantID)),
			),
		)
	log.Debug(stmt.DebugSql())

	if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil {
		return errors.Wrap(err, "failed to charge user balance")
	}

	// Record the redemption. This row is what marks the code as used.
	chargeTbl := table.UserBalanceHistories
	chargeStmt := chargeTbl.
		INSERT(
			chargeTbl.UserID,
			chargeTbl.TenantID,
			chargeTbl.Balance,
			chargeTbl.Target,
			chargeTbl.Type,
			chargeTbl.Code,
		).
		VALUES(
			Int64(claim.UserID),
			Int64(claim.TenantID),
			Int64(chargeAmount),
			Json(pg.BalanceTarget{}.MustValue()),
			String(pg.BalanceTypeCharge.String()),
			String(code),
		)
	log.Debug(chargeStmt.DebugSql())

	if _, err := chargeStmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil {
		return errors.Wrap(err, "failed to insert charge record")
	}

	if err := tx.Commit(); err != nil {
		return errors.Wrap(err, "failed to commit transaction")
	}
	return nil
}
|