Files
quyun-v2/backend/app/services/user.go

496 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"context"
"strings"
"quyun/v2/app/http/super/dto"
"quyun/v2/app/requests"
"quyun/v2/database"
"quyun/v2/database/models"
"quyun/v2/pkg/consts"
"github.com/pkg/errors"
"github.com/samber/lo"
"github.com/sirupsen/logrus"
"go.ipao.vip/gen"
"go.ipao.vip/gen/field"
"go.ipao.vip/gen/types"
"golang.org/x/crypto/bcrypt"
)
// user groups the user-related service operations: lookup by ID/username,
// creation, password reset, status/role management, and the super-admin
// user paging, detail and statistics queries.
//
// NOTE(review): the @provider line below is a marker annotation, presumably
// read by the project's provider/DI code generator — keep it verbatim.
// @provider
type user struct{}
// FindByID loads the user with the given primary key, preloading both the
// tenant the user owns (OwnedTenant) and the tenants the user belongs to
// (Tenants). Any lookup error is wrapped with the user ID for context.
func (t *user) FindByID(ctx context.Context, userID int64) (*models.User, error) {
	u, q := models.UserQuery.QueryContext(ctx)
	byID := u.ID.Eq(userID)

	found, err := q.
		Preload(u.OwnedTenant, u.Tenants).
		Where(byID).
		First()
	if err != nil {
		return nil, errors.Wrapf(err, "FindByID failed, %d", userID)
	}
	return found, nil
}
// FindByUsername fetches the single user whose username equals the given
// value exactly (no trimming, no associations preloaded). Any lookup error
// is wrapped with the username for context.
func (t *user) FindByUsername(ctx context.Context, username string) (*models.User, error) {
	u, q := models.UserQuery.QueryContext(ctx)
	byName := u.Username.Eq(username)

	found, err := q.Where(byName).First()
	if err != nil {
		return nil, errors.Wrapf(err, "FindByUsername failed, %s", username)
	}
	return found, nil
}
// Detail returns the super-admin view of a single user as a DTO: identity,
// roles, status (with description), balances, timestamps, and the number of
// tenants the user owns and has joined. The DTO carries no password field.
//
// It fails when userID is not positive, when the user cannot be loaded, or
// when either tenant-count query fails.
func (t *user) Detail(ctx context.Context, userID int64) (*dto.UserItem, error) {
	if userID <= 0 {
		return nil, errors.New("user_id must be > 0")
	}

	u, err := t.FindByID(ctx, userID)
	if err != nil {
		return nil, err
	}

	// Reuse the batch count helpers with a one-element ID list.
	ids := []int64{u.ID}
	owned, err := t.UserOwnedTenantCountMapping(ctx, ids)
	if err != nil {
		return nil, err
	}
	joined, err := t.UserJoinedTenantCountMapping(ctx, ids)
	if err != nil {
		return nil, err
	}

	item := &dto.UserItem{
		ID:                u.ID,
		Username:          u.Username,
		Roles:             u.Roles,
		Status:            u.Status,
		StatusDescription: u.Status.Description(),
		Balance:           u.Balance,
		BalanceFrozen:     u.BalanceFrozen,
		VerifiedAt:        u.VerifiedAt,
		CreatedAt:         u.CreatedAt,
		UpdatedAt:         u.UpdatedAt,
		OwnedTenantCount:  owned[u.ID],
		JoinedTenantCount: joined[u.ID],
	}
	return item, nil
}
// Create inserts the given user record and, on success, returns the same
// pointer it was given. Insert errors are wrapped with the username.
func (t *user) Create(ctx context.Context, user *models.User) (*models.User, error) {
	err := user.Create(ctx)
	if err != nil {
		return nil, errors.Wrapf(err, "Create user failed, %s", user.Username)
	}
	return user, nil
}
// ResetPasswordByUsername replaces the password of the user identified by
// username (the user's phone number) with a bcrypt hash of newPassword.
//
// Surrounding whitespace is trimmed from username before the lookup;
// newPassword is hashed verbatim. An error is returned for an empty username
// or password, a failed lookup, a hashing failure, or a failed save.
//
// NOTE(review): the user is persisted with Save after a full read; if Save
// writes every column, concurrent changes to other fields (e.g. balance)
// made between the read and the save could be overwritten — confirm and
// consider a single-column update.
func (t *user) ResetPasswordByUsername(ctx context.Context, username, newPassword string) error {
	name := strings.TrimSpace(username)
	switch {
	case name == "":
		return errors.New("username is required")
	case newPassword == "":
		return errors.New("new_password is required")
	}

	account, err := t.FindByUsername(ctx, name)
	if err != nil {
		return err
	}

	// Only a bcrypt hash is stored, never the plaintext password.
	hashed, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return errors.Wrap(err, "generate password hash failed")
	}
	account.Password = string(hashed)

	return account.Save(ctx)
}
// SetStatus changes a user's status (super-admin operation). The user is
// loaded by ID, its Status is replaced, and the record is persisted with
// Save. The lookup or save error is returned as-is.
func (t *user) SetStatus(ctx context.Context, userID int64, status consts.UserStatus) error {
	target, err := t.FindByID(ctx, userID)
	if err != nil {
		return err
	}
	target.Status = status
	return target.Save(ctx)
}
// Page returns one page of users for the super-admin user list.
//
// A nil filter is treated as an empty filter. Supported optional filters:
//   - ID: exact match; ignored when not positive.
//   - Username: LIKE match on the trimmed username (pattern from database.WrapLike).
//   - Role: the user's roles array contains the given role.
//   - TenantID: the user has a tenant_users membership row for that tenant;
//     ignored when not positive (same rule as ID).
//   - Status: exact match.
//   - CreatedAtFrom/To, VerifiedAtFrom/To: inclusive range bounds.
//
// Sorting honours only the whitelisted column names returned by
// filter.AscFields() and filter.DescFields(); any other name is dropped.
// "id DESC" is always appended as the final tie-breaker so paging is
// deterministic. Each item carries the number of tenants the user owns and
// has joined.
func (t *user) Page(ctx context.Context, filter *dto.UserPageFilter) (*requests.Pager, error) {
	tbl, query := models.UserQuery.QueryContext(ctx)
	if filter == nil {
		filter = &dto.UserPageFilter{}
	}
	filter.Pagination.Format()

	conds := []gen.Condition{}
	if filter.ID != nil && *filter.ID > 0 {
		conds = append(conds, tbl.ID.Eq(*filter.ID))
	}
	if username := filter.UsernameTrimmed(); username != "" {
		conds = append(conds, tbl.Username.Like(database.WrapLike(username)))
	}
	if filter.Role != nil && *filter.Role != "" {
		conds = append(conds, tbl.Roles.Contains(types.NewArray([]consts.Role{*filter.Role})))
	}
	if filter.TenantID != nil && *filter.TenantID > 0 {
		// Inner join: keep only users that actually have a membership row in
		// the tenant. A RIGHT JOIN would additionally emit membership rows that
		// have no matching user row, which would come back as users without
		// data and be counted in Total.
		tuTbl, _ := models.TenantUserQuery.QueryContext(ctx)
		query = query.Join(tuTbl, tuTbl.UserID.EqCol(tbl.ID))
		conds = append(conds, tuTbl.TenantID.Eq(*filter.TenantID))
	}
	if filter.Status != nil {
		conds = append(conds, tbl.Status.Eq(*filter.Status))
	}
	if filter.CreatedAtFrom != nil {
		conds = append(conds, tbl.CreatedAt.Gte(*filter.CreatedAtFrom))
	}
	if filter.CreatedAtTo != nil {
		conds = append(conds, tbl.CreatedAt.Lte(*filter.CreatedAtTo))
	}
	if filter.VerifiedAtFrom != nil {
		conds = append(conds, tbl.VerifiedAt.Gte(*filter.VerifiedAtFrom))
	}
	if filter.VerifiedAtTo != nil {
		conds = append(conds, tbl.VerifiedAt.Lte(*filter.VerifiedAtTo))
	}

	// Sort whitelist: only these names may reach ORDER BY, which keeps
	// arbitrary input out of the SQL and avoids sorting on unexpected columns.
	allowedAsc := map[string]field.Expr{
		"id":          tbl.ID.Asc(),
		"username":    tbl.Username.Asc(),
		"status":      tbl.Status.Asc(),
		"balance":     tbl.Balance.Asc(),
		"verified_at": tbl.VerifiedAt.Asc(),
		"created_at":  tbl.CreatedAt.Asc(),
		"updated_at":  tbl.UpdatedAt.Asc(),
	}
	allowedDesc := map[string]field.Expr{
		"id":          tbl.ID.Desc(),
		"username":    tbl.Username.Desc(),
		"status":      tbl.Status.Desc(),
		"balance":     tbl.Balance.Desc(),
		"verified_at": tbl.VerifiedAt.Desc(),
		"created_at":  tbl.CreatedAt.Desc(),
		"updated_at":  tbl.UpdatedAt.Desc(),
	}
	orderBys := make([]field.Expr, 0, 8)
	for _, f := range filter.AscFields() {
		// Blank names are never whitelist keys, so they are dropped here too.
		if ob, ok := allowedAsc[strings.TrimSpace(f)]; ok {
			orderBys = append(orderBys, ob)
		}
	}
	for _, f := range filter.DescFields() {
		if ob, ok := allowedDesc[strings.TrimSpace(f)]; ok {
			orderBys = append(orderBys, ob)
		}
	}
	// Always end with id DESC so rows with equal sort keys keep a stable order.
	orderBys = append(orderBys, tbl.ID.Desc())

	users, total, err := query.Where(conds...).Order(orderBys...).FindByPage(int(filter.Offset()), int(filter.Limit))
	if err != nil {
		return nil, err
	}

	userIDs := make([]int64, 0, len(users))
	for _, u := range users {
		if u == nil {
			continue
		}
		userIDs = append(userIDs, u.ID)
	}
	ownedTenantCounts, err := t.UserOwnedTenantCountMapping(ctx, userIDs)
	if err != nil {
		return nil, err
	}
	joinedTenantCounts, err := t.UserJoinedTenantCountMapping(ctx, userIDs)
	if err != nil {
		return nil, err
	}

	items := lo.Map(users, func(model *models.User, _ int) *dto.UserItem {
		if model == nil {
			return &dto.UserItem{}
		}
		return &dto.UserItem{
			ID:                model.ID,
			Username:          model.Username,
			Roles:             model.Roles,
			Status:            model.Status,
			StatusDescription: model.Status.Description(),
			Balance:           model.Balance,
			BalanceFrozen:     model.BalanceFrozen,
			VerifiedAt:        model.VerifiedAt,
			CreatedAt:         model.CreatedAt,
			UpdatedAt:         model.UpdatedAt,
			OwnedTenantCount:  ownedTenantCounts[model.ID],
			JoinedTenantCount: joinedTenantCounts[model.ID],
		}
	})

	return &requests.Pager{
		Pagination: filter.Pagination,
		Total:      total,
		Items:      items,
	}, nil
}
// UserOwnedTenantCountMapping counts, for each user ID in userIDs, how many
// tenant records have that user as their UserID (i.e. tenants the user owns).
//
// The result has an entry for every distinct positive ID, defaulting to 0
// when the user owns no tenant. Non-positive IDs are ignored and duplicates
// are collapsed. The database is not queried when no valid ID remains.
// Query errors are wrapped for context.
func (t *user) UserOwnedTenantCountMapping(ctx context.Context, userIDs []int64) (map[int64]int64, error) {
	result := make(map[int64]int64, len(userIDs))
	ids := make([]int64, 0, len(userIDs))
	for _, id := range userIDs {
		if id <= 0 {
			continue
		}
		if _, seen := result[id]; seen {
			continue
		}
		result[id] = 0
		ids = append(ids, id)
	}
	if len(ids) == 0 {
		return result, nil
	}

	ttbl, tquery := models.TenantQuery.QueryContext(ctx)
	var rows []struct {
		UserID int64
		Count  int64
	}
	// Query only the validated, de-duplicated IDs so the IN clause matches
	// exactly the keys pre-seeded in result.
	err := tquery.
		Select(ttbl.UserID, ttbl.ID.Count().As("count")).
		Where(ttbl.UserID.In(ids...)).
		Group(ttbl.UserID).
		Scan(&rows)
	if err != nil {
		return nil, errors.Wrap(err, "UserOwnedTenantCountMapping failed")
	}
	for _, row := range rows {
		result[row.UserID] = row.Count
	}
	return result, nil
}
// UserJoinedTenantCountMapping counts, for each user ID in userIDs, how many
// tenant-membership (TenantUser) rows exist for that user, i.e. how many
// tenants the user has joined.
//
// The result has an entry for every distinct positive ID, defaulting to 0
// when the user has no membership. Non-positive IDs are ignored and
// duplicates are collapsed. The database is not queried when no valid ID
// remains. Query errors are wrapped for context.
func (t *user) UserJoinedTenantCountMapping(ctx context.Context, userIDs []int64) (map[int64]int64, error) {
	result := make(map[int64]int64, len(userIDs))
	ids := make([]int64, 0, len(userIDs))
	for _, id := range userIDs {
		if id <= 0 {
			continue
		}
		if _, seen := result[id]; seen {
			continue
		}
		result[id] = 0
		ids = append(ids, id)
	}
	if len(ids) == 0 {
		return result, nil
	}

	tutbl, tuquery := models.TenantUserQuery.QueryContext(ctx)
	var rows []struct {
		UserID int64
		Count  int64
	}
	// Query only the validated, de-duplicated IDs so the IN clause matches
	// exactly the keys pre-seeded in result.
	err := tuquery.
		Select(tutbl.UserID, tutbl.TenantID.Count().As("count")).
		Where(tutbl.UserID.In(ids...)).
		Group(tutbl.UserID).
		Scan(&rows)
	if err != nil {
		return nil, errors.Wrap(err, "UserJoinedTenantCountMapping failed")
	}
	for _, row := range rows {
		result[row.UserID] = row.Count
	}
	return result, nil
}
// UpdateStatus changes a user's status (super-admin operation). The request
// is logged at info level with the user ID and target status. The user is
// then loaded by ID and persisted via Update. The lookup or update error is
// returned as-is; nil on success.
func (t *user) UpdateStatus(ctx context.Context, userID int64, status consts.UserStatus) error {
	entry := logrus.WithField("user_id", userID).WithField("status", status)
	entry.Info("update user status")

	target, err := t.FindByID(ctx, userID)
	if err != nil {
		return err
	}
	target.Status = status

	if _, err := target.Update(ctx); err != nil {
		return err
	}
	return nil
}
// UpdateRoles replaces a user's role list (super-admin operation).
//
// Empty role values are dropped and duplicates removed, keeping the first
// occurrence order. If nothing remains, an error is returned. By convention
// every user keeps consts.RoleUser, so it is appended when missing. The user
// is then loaded by ID and persisted via Update.
func (t *user) UpdateRoles(ctx context.Context, userID int64, roles []consts.Role) error {
	if userID <= 0 {
		return errors.New("user_id must be > 0")
	}

	// Single pass: skip blanks, de-duplicate, preserve first-seen order.
	seen := make(map[consts.Role]struct{}, len(roles))
	cleaned := make([]consts.Role, 0, len(roles)+1)
	for _, r := range roles {
		if r == "" {
			continue
		}
		if _, dup := seen[r]; dup {
			continue
		}
		seen[r] = struct{}{}
		cleaned = append(cleaned, r)
	}
	if len(cleaned) == 0 {
		return errors.New("roles is empty")
	}
	if _, hasUser := seen[consts.RoleUser]; !hasUser {
		cleaned = append(cleaned, consts.RoleUser)
	}

	target, err := t.FindByID(ctx, userID)
	if err != nil {
		return err
	}
	target.Roles = types.NewArray(cleaned)
	_, err = target.Update(ctx)
	return err
}
// Statistics groups users by status (super-admin view). It returns one entry
// per status with its user count and a human-readable status description.
// The returned slice is always non-nil, and is empty when there are no rows.
func (t *user) Statistics(ctx context.Context) ([]*dto.UserStatistics, error) {
	tbl, query := models.UserQuery.QueryContext(ctx)

	var grouped []*dto.UserStatistics
	if err := query.
		Select(tbl.Status, tbl.ID.Count().As("count")).
		Group(tbl.Status).
		Scan(&grouped); err != nil {
		return nil, err
	}

	out := make([]*dto.UserStatistics, len(grouped))
	for i, s := range grouped {
		s.StatusDescription = s.Status.Description()
		out[i] = s
	}
	return out, nil
}
// TenantsPage returns one page of the tenants a user has joined. It is based
// on the user's tenant-membership (TenantUser) rows, ordered by membership
// ID descending.
//
// Optional filters (a nil filter means none):
//   - TenantID: ignored when not positive.
//   - member Role: array containment.
//   - member Status.
//   - membership CreatedAt range: inclusive bounds.
//   - tenant Code / Name: LIKE matches; these add a join on the tenants table.
//
// Each item combines tenant fields, the tenant owner, and the member's role,
// status and join time. Membership rows whose tenant could not be loaded are
// left out of Items but are still included in Total.
func (t *user) TenantsPage(ctx context.Context, userID int64, filter *dto.UserTenantPageFilter) (*requests.Pager, error) {
	if userID <= 0 {
		return nil, errors.New("user_id must be > 0")
	}
	if filter == nil {
		filter = &dto.UserTenantPageFilter{}
	}
	filter.Pagination.Format()

	tuTbl, query := models.TenantUserQuery.QueryContext(ctx)
	conds := []gen.Condition{tuTbl.UserID.Eq(userID)}
	if id := filter.TenantID; id != nil && *id > 0 {
		conds = append(conds, tuTbl.TenantID.Eq(*id))
	}
	if role := filter.Role; role != nil && *role != "" {
		conds = append(conds, tuTbl.Role.Contains(types.NewArray([]consts.TenantUserRole{*role})))
	}
	if st := filter.Status; st != nil && *st != "" {
		conds = append(conds, tuTbl.Status.Eq(*st))
	}
	if from := filter.CreatedAtFrom; from != nil {
		conds = append(conds, tuTbl.CreatedAt.Gte(*from))
	}
	if to := filter.CreatedAtTo; to != nil {
		conds = append(conds, tuTbl.CreatedAt.Lte(*to))
	}

	// Tenant code/name live on the tenants table, so join it only when needed.
	if code, name := filter.CodeTrimmed(), filter.NameTrimmed(); code != "" || name != "" {
		tTbl, _ := models.TenantQuery.QueryContext(ctx)
		query = query.LeftJoin(tTbl, tTbl.ID.EqCol(tuTbl.TenantID))
		if code != "" {
			conds = append(conds, tTbl.Code.Like(database.WrapLike(code)))
		}
		if name != "" {
			conds = append(conds, tTbl.Name.Like(database.WrapLike(name)))
		}
	}

	memberships, total, err := query.Where(conds...).Order(tuTbl.ID.Desc()).FindByPage(int(filter.Offset()), int(filter.Limit))
	if err != nil {
		return nil, err
	}

	// Collect the distinct tenant IDs referenced by this page.
	tenantIDs := make([]int64, 0, len(memberships))
	for _, m := range memberships {
		if m != nil {
			tenantIDs = append(tenantIDs, m.TenantID)
		}
	}
	tenantIDs = lo.Uniq(tenantIDs)

	// Load those tenants in a single query.
	tenantByID := make(map[int64]*models.Tenant, len(tenantIDs))
	loaded := make([]*models.Tenant, 0, len(tenantIDs))
	if len(tenantIDs) > 0 {
		tTbl, tQuery := models.TenantQuery.QueryContext(ctx)
		found, err := tQuery.Where(tTbl.ID.In(tenantIDs...)).Find()
		if err != nil {
			return nil, err
		}
		for _, te := range found {
			if te == nil {
				continue
			}
			tenantByID[te.ID] = te
			loaded = append(loaded, te)
		}
	}

	ownerMap, err := Tenant.TenantOwnerUserMapping(ctx, loaded)
	if err != nil {
		return nil, err
	}

	items := make([]*dto.UserTenantItem, 0, len(memberships))
	for _, m := range memberships {
		if m == nil {
			continue
		}
		te := tenantByID[m.TenantID]
		if te == nil {
			// The tenant row was not loaded; skip this membership.
			continue
		}
		items = append(items, &dto.UserTenantItem{
			TenantID:                te.ID,
			Code:                    te.Code,
			Name:                    te.Name,
			TenantStatus:            te.Status,
			TenantStatusDescription: te.Status.Description(),
			ExpiredAt:               te.ExpiredAt,
			Owner:                   ownerMap[te.ID],
			Role:                    m.Role,
			MemberStatus:            m.Status,
			MemberStatusDescription: m.Status.Description(),
			JoinedAt:                m.CreatedAt,
		})
	}

	return &requests.Pager{
		Pagination: filter.Pagination,
		Total:      total,
		Items:      items,
	}, nil
}