package services import ( "context" "time" "quyun/v2/app/http/super/dto" "quyun/v2/app/requests" "quyun/v2/database" "quyun/v2/database/models" "quyun/v2/pkg/consts" "github.com/pkg/errors" "github.com/samber/lo" "github.com/sirupsen/logrus" "go.ipao.vip/gen" ) // @provider type tenant struct{} func (t *tenant) ContainsUserID(ctx context.Context, tenantID, userID int64) (*models.User, error) { tbl, query := models.TenantUserQuery.QueryContext(ctx) _, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First() if err != nil { return nil, errors.Wrapf(err, "ContainsUserID failed, tenantID: %d, userID: %d", tenantID, userID) } return User.FindByID(ctx, userID) } // AddUser func (t *tenant) AddUser(ctx context.Context, tenantID, userID int64) error { tenantUser := &models.TenantUser{ TenantID: tenantID, UserID: userID, } if err := tenantUser.Create(ctx); err != nil { return errors.Wrapf(err, "AddUser failed, tenantID: %d, userID: %d", tenantID, userID) } return nil } // RemoveUser func (t *tenant) RemoveUser(ctx context.Context, tenantID, userID int64) error { tbl, query := models.TenantUserQuery.QueryContext(ctx) tenantUser, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First() if err != nil { return errors.Wrapf(err, "RemoveUser failed to find, tenantID: %d, userID: %d", tenantID, userID) } _, err = tenantUser.Delete(ctx) if err != nil { return errors.Wrapf(err, "RemoveUser failed to delete, tenantID: %d, userID: %d", tenantID, userID) } return nil } // SetUserRole func (t *tenant) SetUserRole(ctx context.Context, tenantID, userID int64, role ...consts.TenantUserRole) error { tbl, query := models.TenantUserQuery.QueryContext(ctx) tenantUser, err := query.Where(tbl.TenantID.Eq(tenantID), tbl.UserID.Eq(userID)).First() if err != nil { return errors.Wrapf(err, "SetUserRole failed to find, tenantID: %d, userID: %d", tenantID, userID) } tenantUser.Role = role if _, err := tenantUser.Update(ctx); err != nil { return errors.Wrapf(err, "SetUserRole failed to update, tenantID: %d, userID: %d", tenantID, userID) } return nil } // Pager func (t *tenant) Pager(ctx context.Context, filter *dto.TenantFilter) (*requests.Pager, error) { tbl, query := models.TenantQuery.QueryContext(ctx) conds := []gen.Condition{} if filter.Name != nil { conds = append(conds, tbl.Name.Like(database.WrapLike(*filter.Name))) } if filter.Status != nil { conds = append(conds, tbl.Status.Eq(*filter.Status)) } filter.Pagination.Format() mm, total, err := query.Where(conds...).Order(tbl.ID.Desc()).FindByPage(int(filter.Offset()), int(filter.Limit)) if err != nil { return nil, err } tenantIds := lo.Map(mm, func(item *models.Tenant, _ int) int64 { return item.ID }) userCountMapping, err := t.TenantUserCountMapping(ctx, tenantIds) if err != nil { return nil, err } userBalanceMapping, err := t.TenantUserBalanceMapping(ctx, tenantIds) if err != nil { return nil, err } items := lo.Map(mm, func(model *models.Tenant, _ int) *dto.TenantItem { return &dto.TenantItem{ Tenant: model, UserCount: lo.ValueOr(userCountMapping, model.ID, 0), UserBalance: lo.ValueOr(userBalanceMapping, model.ID, 0), StatusDescription: model.Status.Description(), } }) return &requests.Pager{ Pagination: filter.Pagination, Total: total, Items: items, }, nil } func (t *tenant) TenantUserCountMapping(ctx context.Context, tenantIds []int64) (map[int64]int64, error) { tbl, query := models.TenantUserQuery.QueryContext(ctx) var items []struct { TenantID int64 Count int64 } err := query. Select( tbl.TenantID, tbl.UserID.Count().As("count"), ). Where(tbl.TenantID.In(tenantIds...)). Group(tbl.TenantID). Scan(&items) if err != nil { return nil, err } result := make(map[int64]int64) for _, item := range items { result[item.TenantID] = item.Count } return result, nil } // TenantUserBalanceMapping func (t *tenant) TenantUserBalanceMapping(ctx context.Context, tenantIds []int64) (map[int64]int64, error) { tbl, query := models.TenantUserQuery.QueryContext(ctx) var items []struct { TenantID int64 Balance int64 } err := query. Select( tbl.TenantID, tbl.Balance.Sum().As("balance"), ). Where(tbl.TenantID.In(tenantIds...)). Group(tbl.TenantID). Scan(&items) if err != nil { return nil, err } result := make(map[int64]int64) for _, item := range items { result[item.TenantID] = item.Balance } return result, nil } // FindByID func (t *tenant) FindByID(ctx context.Context, id int64) (*models.Tenant, error) { tbl, query := models.TenantQuery.QueryContext(ctx) m, err := query.Where(tbl.ID.Eq(id)).First() if err != nil { return nil, errors.Wrapf(err, "find by id failed, id: %d", id) } return m, nil } // AddExpireDuration func (t *tenant) AddExpireDuration(ctx context.Context, tenantID int64, duration time.Duration) error { logrus.WithField("tenant_id", tenantID).WithField("duration", duration).Info("add expire duration") m, err := t.FindByID(ctx, tenantID) if err != nil { return err } if m.ExpiredAt.Before(time.Now()) { m.ExpiredAt = time.Now().Add(duration) } else { m.ExpiredAt = m.ExpiredAt.Add(duration) } return m.Save(ctx) } // UpdateStatus func (t *tenant) UpdateStatus(ctx context.Context, tenantID int64, status consts.TenantStatus) error { logrus.WithField("tenant_id", tenantID).WithField("status", status).Info("update tenant status") m, err := t.FindByID(ctx, tenantID) if err != nil { return err } m.Status = status _, err = m.Update(ctx) if err != nil { return err } return nil }