779 lines
21 KiB
Go
779 lines
21 KiB
Go
package services
|
||
|
||
import (
|
||
"context"
|
||
"strings"
|
||
"time"
|
||
|
||
superdto "quyun/v2/app/http/super/dto"
|
||
tenantdto "quyun/v2/app/http/tenant/dto"
|
||
"quyun/v2/app/requests"
|
||
"quyun/v2/database"
|
||
"quyun/v2/database/models"
|
||
"quyun/v2/pkg/consts"
|
||
|
||
"github.com/pkg/errors"
|
||
"github.com/samber/lo"
|
||
"github.com/sirupsen/logrus"
|
||
"go.ipao.vip/gen"
|
||
"go.ipao.vip/gen/field"
|
||
"go.ipao.vip/gen/types"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// tenant is the domain service for tenants: tenant creation, membership management
// (add/remove/role changes), paginated listings for tenant and super admins, and
// per-tenant aggregate mappings used by those listings.
//
// @provider
type tenant struct{}
|
||
|
||
// SuperCreateTenant 超级管理员创建租户,并将指定用户设为租户管理员。
|
||
func (t *tenant) SuperCreateTenant(ctx context.Context, form *superdto.TenantCreateForm) (*models.Tenant, error) {
|
||
if form == nil {
|
||
return nil, errors.New("form is nil")
|
||
}
|
||
|
||
code := strings.ToLower(strings.TrimSpace(form.Code))
|
||
if code == "" {
|
||
return nil, errors.New("code is empty")
|
||
}
|
||
name := strings.TrimSpace(form.Name)
|
||
if name == "" {
|
||
return nil, errors.New("name is empty")
|
||
}
|
||
if form.AdminUserID <= 0 {
|
||
return nil, errors.New("admin_user_id must be > 0")
|
||
}
|
||
duration, err := (&superdto.TenantExpireUpdateForm{Duration: form.Duration}).ParseDuration()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 确保管理员用户存在(同时可提前暴露“用户不存在”的错误,而不是等到外键/逻辑报错)。
|
||
if _, err := User.FindByID(ctx, form.AdminUserID); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
now := time.Now().UTC()
|
||
tenant := &models.Tenant{
|
||
UserID: form.AdminUserID,
|
||
Code: code,
|
||
UUID: types.NewUUIDv4(),
|
||
Name: name,
|
||
Status: consts.TenantStatusVerified,
|
||
Config: types.JSON([]byte(`{}`)),
|
||
ExpiredAt: now.Add(duration),
|
||
}
|
||
|
||
db := _db.WithContext(ctx)
|
||
err = db.Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Create(tenant).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
tenantUser := &models.TenantUser{
|
||
TenantID: tenant.ID,
|
||
UserID: form.AdminUserID,
|
||
Role: types.NewArray([]consts.TenantUserRole{consts.TenantUserRoleTenantAdmin}),
|
||
Status: consts.UserStatusVerified,
|
||
}
|
||
if err := tx.Create(tenantUser).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return tx.First(tenant, tenant.ID).Error
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return tenant, nil
|
||
}
|
||
|
||
// AdminTenantUsersPage 租户管理员分页查询成员列表(包含用户基础信息)。
|
||
func (t *tenant) AdminTenantUsersPage(ctx context.Context, tenantID int64, filter *tenantdto.AdminTenantUserListFilter) (*requests.Pager, error) {
|
||
if tenantID <= 0 {
|
||
return nil, errors.New("tenant_id must be > 0")
|
||
}
|
||
if filter == nil {
|
||
filter = &tenantdto.AdminTenantUserListFilter{}
|
||
}
|
||
|
||
filter.Pagination.Format()
|
||
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
conds := []gen.Condition{tbl.TenantID.Eq(tenantID)}
|
||
if filter.UserID != nil && *filter.UserID > 0 {
|
||
conds = append(conds, tbl.UserID.Eq(*filter.UserID))
|
||
}
|
||
if filter.Role != nil && *filter.Role != "" {
|
||
// role 字段为 PostgreSQL text[]:使用数组参数才能正确生成 `@> '{"tenant_admin"}'` 语义。
|
||
conds = append(conds, tbl.Role.Contains(types.NewArray([]consts.TenantUserRole{*filter.Role})))
|
||
}
|
||
if filter.Status != nil && *filter.Status != "" {
|
||
conds = append(conds, tbl.Status.Eq(*filter.Status))
|
||
}
|
||
if username := filter.UsernameTrimmed(); username != "" {
|
||
uTbl, _ := models.UserQuery.QueryContext(ctx)
|
||
query = query.LeftJoin(uTbl, uTbl.ID.EqCol(tbl.UserID))
|
||
conds = append(conds, uTbl.Username.Like(database.WrapLike(username)))
|
||
}
|
||
|
||
items, total, err := query.Where(conds...).Order(tbl.ID.Desc()).FindByPage(int(filter.Offset()), int(filter.Limit))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
userIDs := make([]int64, 0, len(items))
|
||
for _, tu := range items {
|
||
if tu == nil {
|
||
continue
|
||
}
|
||
userIDs = append(userIDs, tu.UserID)
|
||
}
|
||
|
||
var users []*models.User
|
||
if len(userIDs) > 0 {
|
||
uTbl, uQuery := models.UserQuery.QueryContext(ctx)
|
||
users, err = uQuery.Where(uTbl.ID.In(userIDs...)).Find()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
userMap := make(map[int64]*models.User, len(users))
|
||
for _, u := range users {
|
||
if u == nil {
|
||
continue
|
||
}
|
||
userMap[u.ID] = u
|
||
}
|
||
|
||
out := make([]*tenantdto.AdminTenantUserItem, 0, len(items))
|
||
for _, tu := range items {
|
||
if tu == nil {
|
||
continue
|
||
}
|
||
out = append(out, &tenantdto.AdminTenantUserItem{
|
||
TenantUser: tu,
|
||
User: userMap[tu.UserID],
|
||
})
|
||
}
|
||
|
||
return &requests.Pager{
|
||
Pagination: filter.Pagination,
|
||
Total: total,
|
||
Items: out,
|
||
}, nil
|
||
}
|
||
|
||
// SuperTenantUsersPage 超级管理员分页查询租户成员(脱敏 user 字段,避免泄露 password)。
|
||
func (t *tenant) SuperTenantUsersPage(ctx context.Context, tenantID int64, filter *tenantdto.AdminTenantUserListFilter) (*requests.Pager, error) {
|
||
if tenantID <= 0 {
|
||
return nil, errors.New("tenant_id must be > 0")
|
||
}
|
||
if filter == nil {
|
||
filter = &tenantdto.AdminTenantUserListFilter{}
|
||
}
|
||
|
||
filter.Pagination.Format()
|
||
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
conds := []gen.Condition{tbl.TenantID.Eq(tenantID)}
|
||
if filter.UserID != nil && *filter.UserID > 0 {
|
||
conds = append(conds, tbl.UserID.Eq(*filter.UserID))
|
||
}
|
||
if filter.Role != nil && *filter.Role != "" {
|
||
conds = append(conds, tbl.Role.Contains(types.NewArray([]consts.TenantUserRole{*filter.Role})))
|
||
}
|
||
if filter.Status != nil && *filter.Status != "" {
|
||
conds = append(conds, tbl.Status.Eq(*filter.Status))
|
||
}
|
||
if username := filter.UsernameTrimmed(); username != "" {
|
||
uTbl, _ := models.UserQuery.QueryContext(ctx)
|
||
query = query.LeftJoin(uTbl, uTbl.ID.EqCol(tbl.UserID))
|
||
conds = append(conds, uTbl.Username.Like(database.WrapLike(username)))
|
||
}
|
||
|
||
items, total, err := query.Where(conds...).Order(tbl.ID.Desc()).FindByPage(int(filter.Offset()), int(filter.Limit))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
userIDs := make([]int64, 0, len(items))
|
||
for _, tu := range items {
|
||
if tu == nil {
|
||
continue
|
||
}
|
||
userIDs = append(userIDs, tu.UserID)
|
||
}
|
||
|
||
var users []*models.User
|
||
if len(userIDs) > 0 {
|
||
uTbl, uQuery := models.UserQuery.QueryContext(ctx)
|
||
users, err = uQuery.Where(uTbl.ID.In(userIDs...)).Find()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
userMap := make(map[int64]*models.User, len(users))
|
||
for _, u := range users {
|
||
if u == nil {
|
||
continue
|
||
}
|
||
userMap[u.ID] = u
|
||
}
|
||
|
||
out := make([]*superdto.SuperTenantUserItem, 0, len(items))
|
||
for _, tu := range items {
|
||
if tu == nil {
|
||
continue
|
||
}
|
||
u := userMap[tu.UserID]
|
||
var lite *superdto.SuperUserLite
|
||
if u != nil {
|
||
lite = &superdto.SuperUserLite{
|
||
ID: u.ID,
|
||
Username: u.Username,
|
||
Status: u.Status,
|
||
StatusDescription: u.Status.Description(),
|
||
Roles: u.Roles,
|
||
VerifiedAt: u.VerifiedAt,
|
||
CreatedAt: u.CreatedAt,
|
||
UpdatedAt: u.UpdatedAt,
|
||
}
|
||
}
|
||
out = append(out, &superdto.SuperTenantUserItem{
|
||
TenantUser: tu,
|
||
User: lite,
|
||
})
|
||
}
|
||
|
||
return &requests.Pager{
|
||
Pagination: filter.Pagination,
|
||
Total: total,
|
||
Items: out,
|
||
}, nil
|
||
}
|
||
|
||
func (t *tenant) ContainsUserID(ctx context.Context, tenantID, userID int64) (*models.User, error) {
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
|
||
_, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First()
|
||
if err != nil {
|
||
return nil, errors.Wrapf(err, "ContainsUserID failed, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
|
||
return User.FindByID(ctx, userID)
|
||
}
|
||
|
||
// AddUser
|
||
func (t *tenant) AddUser(ctx context.Context, tenantID, userID int64) error {
|
||
logrus.WithFields(logrus.Fields{
|
||
"tenant_id": tenantID,
|
||
"user_id": userID,
|
||
}).Info("services.tenant.add_user")
|
||
|
||
// 幂等:若成员关系已存在,则直接返回成功,避免重复插入触发唯一约束错误。
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
_, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First()
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.Wrapf(err, "AddUser failed to query existing, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
|
||
// 关键默认值:加入租户默认成为 member,并设置为 verified(避免 DB 默认值与枚举不一致导致脏数据)。
|
||
tenantUser := &models.TenantUser{
|
||
TenantID: tenantID,
|
||
UserID: userID,
|
||
Role: types.NewArray([]consts.TenantUserRole{consts.TenantUserRoleMember}),
|
||
Status: consts.UserStatusVerified,
|
||
}
|
||
|
||
if err := tenantUser.Create(ctx); err != nil {
|
||
return errors.Wrapf(err, "AddUser failed, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// RemoveUser
|
||
func (t *tenant) RemoveUser(ctx context.Context, tenantID, userID int64) error {
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
tenantUser, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First()
|
||
if err != nil {
|
||
// 幂等:成员不存在时也返回成功,便于后台重试/批量移除。
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil
|
||
}
|
||
return errors.Wrapf(err, "RemoveUser failed to find, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
|
||
_, err = tenantUser.Delete(ctx)
|
||
if err != nil {
|
||
return errors.Wrapf(err, "RemoveUser failed to delete, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// SetUserRole
|
||
func (t *tenant) SetUserRole(ctx context.Context, tenantID, userID int64, role ...consts.TenantUserRole) error {
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
tenantUser, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First()
|
||
if err != nil {
|
||
return errors.Wrapf(err, "SetUserRole failed to find, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
|
||
// 角色更新:当前约定 role 数组通常只存一个主角色(member/tenant_admin)。
|
||
tenantUser.Role = types.NewArray(role)
|
||
if _, err := tenantUser.Update(ctx); err != nil {
|
||
return errors.Wrapf(err, "SetUserRole failed to update, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Pager
|
||
func (t *tenant) Pager(ctx context.Context, filter *superdto.TenantFilter) (*requests.Pager, error) {
|
||
tbl, query := models.TenantQuery.QueryContext(ctx)
|
||
|
||
conds := []gen.Condition{}
|
||
if filter == nil {
|
||
filter = &superdto.TenantFilter{}
|
||
}
|
||
|
||
if filter.ID != nil && *filter.ID > 0 {
|
||
conds = append(conds, tbl.ID.Eq(*filter.ID))
|
||
}
|
||
if filter.UserID != nil && *filter.UserID > 0 {
|
||
conds = append(conds, tbl.UserID.Eq(*filter.UserID))
|
||
}
|
||
|
||
if name := filter.NameTrimmed(); name != "" {
|
||
conds = append(conds, tbl.Name.Like(database.WrapLike(name)))
|
||
}
|
||
|
||
if code := filter.CodeTrimmed(); code != "" {
|
||
// code 在库内按约定存储为 lower-case;这里统一转小写后做 like。
|
||
conds = append(conds, tbl.Code.Like(database.WrapLike(code)))
|
||
}
|
||
|
||
if filter.Status != nil {
|
||
conds = append(conds, tbl.Status.Eq(*filter.Status))
|
||
}
|
||
|
||
filter.Pagination.Format()
|
||
|
||
if filter.ExpiredAtFrom != nil {
|
||
conds = append(conds, tbl.ExpiredAt.Gte(*filter.ExpiredAtFrom))
|
||
}
|
||
if filter.ExpiredAtTo != nil {
|
||
conds = append(conds, tbl.ExpiredAt.Lte(*filter.ExpiredAtTo))
|
||
}
|
||
if filter.CreatedAtFrom != nil {
|
||
conds = append(conds, tbl.CreatedAt.Gte(*filter.CreatedAtFrom))
|
||
}
|
||
if filter.CreatedAtTo != nil {
|
||
conds = append(conds, tbl.CreatedAt.Lte(*filter.CreatedAtTo))
|
||
}
|
||
|
||
// 排序白名单:避免把任意字符串拼进 SQL 导致注入或慢查询。
|
||
orderBys := make([]field.Expr, 0, 6)
|
||
allowedAsc := map[string]field.Expr{
|
||
"id": tbl.ID.Asc(),
|
||
"code": tbl.Code.Asc(),
|
||
"name": tbl.Name.Asc(),
|
||
"status": tbl.Status.Asc(),
|
||
"expired_at": tbl.ExpiredAt.Asc(),
|
||
"created_at": tbl.CreatedAt.Asc(),
|
||
"updated_at": tbl.UpdatedAt.Asc(),
|
||
}
|
||
allowedDesc := map[string]field.Expr{
|
||
"id": tbl.ID.Desc(),
|
||
"code": tbl.Code.Desc(),
|
||
"name": tbl.Name.Desc(),
|
||
"status": tbl.Status.Desc(),
|
||
"expired_at": tbl.ExpiredAt.Desc(),
|
||
"created_at": tbl.CreatedAt.Desc(),
|
||
"updated_at": tbl.UpdatedAt.Desc(),
|
||
}
|
||
for _, f := range filter.AscFields() {
|
||
f = strings.TrimSpace(f)
|
||
if f == "" {
|
||
continue
|
||
}
|
||
if ob, ok := allowedAsc[f]; ok {
|
||
orderBys = append(orderBys, ob)
|
||
}
|
||
}
|
||
for _, f := range filter.DescFields() {
|
||
f = strings.TrimSpace(f)
|
||
if f == "" {
|
||
continue
|
||
}
|
||
if ob, ok := allowedDesc[f]; ok {
|
||
orderBys = append(orderBys, ob)
|
||
}
|
||
}
|
||
if len(orderBys) == 0 {
|
||
orderBys = append(orderBys, tbl.ID.Desc())
|
||
} else {
|
||
orderBys = append(orderBys, tbl.ID.Desc())
|
||
}
|
||
|
||
mm, total, err := query.Where(conds...).Order(orderBys...).FindByPage(int(filter.Offset()), int(filter.Limit))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
tenantIds := lo.Map(mm, func(item *models.Tenant, _ int) int64 { return item.ID })
|
||
|
||
userCountMapping, err := t.TenantUserCountMapping(ctx, tenantIds)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
incomeMapping, err := t.TenantIncomePaidMapping(ctx, tenantIds)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
items := lo.Map(mm, func(model *models.Tenant, _ int) *superdto.TenantItem {
|
||
return &superdto.TenantItem{
|
||
Tenant: model,
|
||
UserCount: lo.ValueOr(userCountMapping, model.ID, 0),
|
||
IncomeAmountPaidSum: lo.ValueOr(incomeMapping, model.ID, 0),
|
||
StatusDescription: model.Status.Description(),
|
||
}
|
||
})
|
||
|
||
ownerMapping, err := t.TenantOwnerUserMapping(ctx, mm)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, it := range items {
|
||
if it == nil || it.Tenant == nil {
|
||
continue
|
||
}
|
||
it.Owner = ownerMapping[it.Tenant.ID]
|
||
}
|
||
|
||
adminUsersMapping, err := t.TenantAdminUsersMapping(ctx, tenantIds)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, it := range items {
|
||
if it == nil || it.Tenant == nil {
|
||
continue
|
||
}
|
||
it.AdminUsers = adminUsersMapping[it.Tenant.ID]
|
||
}
|
||
|
||
return &requests.Pager{
|
||
Pagination: filter.Pagination,
|
||
Total: total,
|
||
Items: items,
|
||
}, nil
|
||
}
|
||
|
||
func (t *tenant) TenantOwnerUserMapping(ctx context.Context, tenants []*models.Tenant) (map[int64]*superdto.TenantOwnerUserLite, error) {
|
||
result := make(map[int64]*superdto.TenantOwnerUserLite, len(tenants))
|
||
|
||
userIDs := make([]int64, 0, len(tenants))
|
||
tenantIDs := make([]int64, 0, len(tenants))
|
||
for _, te := range tenants {
|
||
if te == nil || te.ID <= 0 {
|
||
continue
|
||
}
|
||
tenantIDs = append(tenantIDs, te.ID)
|
||
if te.UserID > 0 {
|
||
userIDs = append(userIDs, te.UserID)
|
||
}
|
||
}
|
||
for _, tenantID := range tenantIDs {
|
||
result[tenantID] = nil
|
||
}
|
||
userIDs = lo.Uniq(userIDs)
|
||
if len(userIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
uTbl, uQuery := models.UserQuery.QueryContext(ctx)
|
||
users, err := uQuery.Where(uTbl.ID.In(userIDs...)).Find()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
userMap := make(map[int64]*models.User, len(users))
|
||
for _, u := range users {
|
||
if u == nil {
|
||
continue
|
||
}
|
||
userMap[u.ID] = u
|
||
}
|
||
|
||
for _, te := range tenants {
|
||
if te == nil || te.ID <= 0 || te.UserID <= 0 {
|
||
continue
|
||
}
|
||
u := userMap[te.UserID]
|
||
if u == nil {
|
||
continue
|
||
}
|
||
result[te.ID] = &superdto.TenantOwnerUserLite{
|
||
ID: u.ID,
|
||
Username: u.Username,
|
||
}
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// TenantAdminUsersMapping 返回每个租户的管理员用户(用于 superadmin 租户列表展示)。
|
||
func (t *tenant) TenantAdminUsersMapping(ctx context.Context, tenantIDs []int64) (map[int64][]*superdto.TenantAdminUserLite, error) {
|
||
result := make(map[int64][]*superdto.TenantAdminUserLite, len(tenantIDs))
|
||
for _, id := range tenantIDs {
|
||
if id <= 0 {
|
||
continue
|
||
}
|
||
result[id] = nil
|
||
}
|
||
if len(result) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
tuTbl, tuQuery := models.TenantUserQuery.QueryContext(ctx)
|
||
tus, err := tuQuery.Where(
|
||
tuTbl.TenantID.In(tenantIDs...),
|
||
tuTbl.Role.Contains(types.NewArray([]consts.TenantUserRole{consts.TenantUserRoleTenantAdmin})),
|
||
).Find()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
userIDs := make([]int64, 0, len(tus))
|
||
type pair struct {
|
||
tenantID int64
|
||
userID int64
|
||
}
|
||
pairs := make([]pair, 0, len(tus))
|
||
for _, tu := range tus {
|
||
if tu == nil || tu.TenantID <= 0 || tu.UserID <= 0 {
|
||
continue
|
||
}
|
||
userIDs = append(userIDs, tu.UserID)
|
||
pairs = append(pairs, pair{tenantID: tu.TenantID, userID: tu.UserID})
|
||
}
|
||
userIDs = lo.Uniq(userIDs)
|
||
|
||
userMap := map[int64]*models.User{}
|
||
if len(userIDs) > 0 {
|
||
uTbl, uQuery := models.UserQuery.QueryContext(ctx)
|
||
users, err := uQuery.Where(uTbl.ID.In(userIDs...)).Find()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, u := range users {
|
||
if u == nil {
|
||
continue
|
||
}
|
||
userMap[u.ID] = u
|
||
}
|
||
}
|
||
|
||
for _, p := range pairs {
|
||
u := userMap[p.userID]
|
||
if u == nil {
|
||
continue
|
||
}
|
||
result[p.tenantID] = append(result[p.tenantID], &superdto.TenantAdminUserLite{
|
||
ID: u.ID,
|
||
Username: u.Username,
|
||
})
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
func (t *tenant) TenantUserCountMapping(ctx context.Context, tenantIds []int64) (map[int64]int64, error) {
|
||
// 关键语义:返回值必须包含入参中的所有 tenant_id。
|
||
// 即便该租户当前没有成员,也应返回 count=0,便于调用方直接取值而无需额外补全逻辑。
|
||
result := make(map[int64]int64, len(tenantIds))
|
||
for _, id := range tenantIds {
|
||
if id <= 0 {
|
||
continue
|
||
}
|
||
result[id] = 0
|
||
}
|
||
if len(result) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
|
||
var items []struct {
|
||
TenantID int64
|
||
Count int64
|
||
}
|
||
err := query.
|
||
Select(
|
||
tbl.TenantID,
|
||
tbl.UserID.Count().As("count"),
|
||
).
|
||
Where(tbl.TenantID.In(tenantIds...)).
|
||
Group(tbl.TenantID).
|
||
Scan(&items)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
for _, item := range items {
|
||
result[item.TenantID] = item.Count
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// TenantUserBalanceMapping
|
||
func (t *tenant) TenantUserBalanceMapping(ctx context.Context, tenantIds []int64) (map[int64]int64, error) {
|
||
// 关键语义:返回值必须包含入参中的所有 tenant_id。
|
||
// 即便该租户当前没有成员,也应返回 balance=0,保持调用方逻辑一致。
|
||
result := make(map[int64]int64, len(tenantIds))
|
||
for _, id := range tenantIds {
|
||
if id <= 0 {
|
||
continue
|
||
}
|
||
result[id] = 0
|
||
}
|
||
if len(result) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
var items []struct {
|
||
TenantID int64
|
||
Balance int64
|
||
}
|
||
|
||
// 全局余额:按租户维度统计“该租户成员的 users.balance 之和”。
|
||
// 注意:用户可能加入多个租户,因此不同租户的统计会出现重复计入(这符合“按租户视角”统计的直觉)。
|
||
err := models.Q.TenantUser.
|
||
WithContext(ctx).
|
||
UnderlyingDB().
|
||
Table(models.TableNameTenantUser+" tu").
|
||
Select("tu.tenant_id, COALESCE(SUM(u.balance), 0) AS balance").
|
||
Joins("JOIN "+models.TableNameUser+" u ON u.id = tu.user_id AND u.deleted_at IS NULL").
|
||
Where("tu.tenant_id IN ?", tenantIds).
|
||
Group("tu.tenant_id").
|
||
Scan(&items).
|
||
Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
for _, item := range items {
|
||
result[item.TenantID] = item.Balance
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// TenantIncomePaidMapping 按租户维度统计“已支付订单”的累计收入(单位:分,CNY)。
|
||
// 说明:
|
||
// - 仅统计 orders.status = paid 的订单金额;
|
||
// - refunding/refunded 不计入收入(避免把已退/退款中的金额当作收入)。
|
||
func (t *tenant) TenantIncomePaidMapping(ctx context.Context, tenantIDs []int64) (map[int64]int64, error) {
|
||
result := make(map[int64]int64, len(tenantIDs))
|
||
for _, id := range tenantIDs {
|
||
if id <= 0 {
|
||
continue
|
||
}
|
||
result[id] = 0
|
||
}
|
||
if len(result) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
oTbl, oQuery := models.OrderQuery.QueryContext(ctx)
|
||
var rows []struct {
|
||
TenantID int64
|
||
Income int64
|
||
}
|
||
err := oQuery.
|
||
Select(oTbl.TenantID, oTbl.AmountPaid.Sum().As("income")).
|
||
Where(oTbl.TenantID.In(tenantIDs...), oTbl.Status.Eq(consts.OrderStatusPaid)).
|
||
Group(oTbl.TenantID).
|
||
Scan(&rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, row := range rows {
|
||
result[row.TenantID] = row.Income
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// FindByID
|
||
func (t *tenant) FindByID(ctx context.Context, id int64) (*models.Tenant, error) {
|
||
tbl, query := models.TenantQuery.QueryContext(ctx)
|
||
m, err := query.Where(tbl.ID.Eq(id)).First()
|
||
if err != nil {
|
||
return nil, errors.Wrapf(err, "find by id failed, id: %d", id)
|
||
}
|
||
return m, nil
|
||
}
|
||
|
||
func (t *tenant) FindByCode(ctx context.Context, code string) (*models.Tenant, error) {
|
||
code = strings.TrimSpace(code)
|
||
if code == "" {
|
||
return nil, errors.New("tenant code is empty")
|
||
}
|
||
code = strings.ToLower(code)
|
||
|
||
var m models.Tenant
|
||
err := models.Q.Tenant.WithContext(ctx).UnderlyingDB().Where("lower(code) = ?", code).First(&m).Error
|
||
if err != nil {
|
||
return nil, errors.Wrapf(err, "find by code failed, code: %s", code)
|
||
}
|
||
return &m, nil
|
||
}
|
||
|
||
func (t *tenant) FindTenantUser(ctx context.Context, tenantID, userID int64) (*models.TenantUser, error) {
|
||
logrus.WithField("tenant_id", tenantID).WithField("user_id", userID).Info("find tenant user")
|
||
tbl, query := models.TenantUserQuery.QueryContext(ctx)
|
||
m, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First()
|
||
if err != nil {
|
||
return nil, errors.Wrapf(err, "find tenant user failed, tenantID: %d, userID: %d", tenantID, userID)
|
||
}
|
||
return m, nil
|
||
}
|
||
|
||
// AddExpireDuration
|
||
func (t *tenant) AddExpireDuration(ctx context.Context, tenantID int64, duration time.Duration) error {
|
||
logrus.WithField("tenant_id", tenantID).WithField("duration", duration).Info("add expire duration")
|
||
|
||
m, err := t.FindByID(ctx, tenantID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if m.ExpiredAt.Before(time.Now()) {
|
||
m.ExpiredAt = time.Now().Add(duration)
|
||
} else {
|
||
m.ExpiredAt = m.ExpiredAt.Add(duration)
|
||
}
|
||
return m.Save(ctx)
|
||
}
|
||
|
||
// UpdateStatus
|
||
func (t *tenant) UpdateStatus(ctx context.Context, tenantID int64, status consts.TenantStatus) error {
|
||
logrus.WithField("tenant_id", tenantID).WithField("status", status).Info("update tenant status")
|
||
|
||
m, err := t.FindByID(ctx, tenantID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
m.Status = status
|
||
_, err = m.Update(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|