package services

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"fmt"
	"math/big"
	"strconv"
	"strings"
	"sync"
	"time"

	"quyun/v2/app/http/dto"
	"quyun/v2/app/requests"
	"quyun/v2/database"
	"quyun/v2/database/models"
	"quyun/v2/providers/ali"

	"github.com/pkg/errors"
	"github.com/samber/lo"
	log "github.com/sirupsen/logrus"
	"go.ipao.vip/gen"
	"gorm.io/gorm"
)

// @provider
type users struct {
	// mu guards lastSentAtByPhone and codeByPhone (in-memory SMS rate limit and code store).
	mu                sync.Mutex                `inject:"false"`
	lastSentAtByPhone map[string]time.Time      `inject:"false"`
	codeByPhone       map[string]phoneCodeEntry `inject:"false"`

	smsNotifyClient *ali.SMSNotifyClient
}

// Prepare initializes the in-memory SMS rate-limit and verification-code maps.
func (m *users) Prepare() error {
	m.lastSentAtByPhone = make(map[string]time.Time)
	m.codeByPhone = make(map[string]phoneCodeEntry)
	return nil
}

// List returns a paginated list of users ordered by id descending.
//
// filter.Keyword (whitespace-trimmed; empty means "no filter") matches phone and
// username by LIKE, open_id exactly, and the user id when the keyword parses as a
// positive integer. When filter.OnlyBought is true, only users that have at least
// one user_posts row are returned.
func (m *users) List(
	ctx context.Context,
	filter *dto.UserListQuery,
) (*requests.Pager, error) {
	filter.Pagination.Format()

	tbl, query := models.UserQuery.QueryContext(ctx)
	query = query.Order(tbl.ID.Desc())

	// Trim before deciding whether a keyword was supplied, so a whitespace-only
	// keyword disables filtering in both the page query and the count query.
	keyword := ""
	if filter.Keyword != nil {
		keyword = strings.TrimSpace(*filter.Keyword)
	}

	// keywordID stays 0 unless the keyword is also a valid (positive) user id.
	var keywordID int64
	if keyword != "" {
		if id, err := strconv.ParseInt(keyword, 10, 64); err == nil && id > 0 {
			keywordID = id
		}
		like := database.WrapLike(keyword)
		query = query.
			Where(tbl.Phone.Like(like)).
			Or(tbl.Username.Like(like)).
			Or(tbl.OpenID.Eq(keyword))
		if keywordID > 0 {
			query = query.Or(tbl.ID.Eq(keywordID))
		}
	}

	offset := int(filter.Pagination.Offset())
	limit := int(filter.Pagination.Limit)

	if filter.OnlyBought == nil || !*filter.OnlyBought {
		items, cnt, err := query.FindByPage(offset, limit)
		if err != nil {
			return nil, errors.Wrap(err, "query users error")
		}
		return &requests.Pager{
			Items:      items,
			Total:      cnt,
			Pagination: *filter.Pagination,
		}, nil
	}

	// Only users with purchase count > 0: existence filter via JOIN user_posts.
	// FindByPage uses Count() internally, which is unreliable with GROUP BY, so the
	// total is computed separately as COUNT(DISTINCT users.id).
	tblUserPost, _ := models.UserPostQuery.QueryContext(ctx)
	query = query.Join(tblUserPost, tbl.ID.EqCol(tblUserPost.UserID)).Group(tbl.ID)

	items, err := query.Offset(offset).Limit(limit).Find()
	if err != nil {
		return nil, errors.Wrap(err, "query users error")
	}

	// The count must apply exactly the same keyword conditions as the page query
	// above (including the open_id match), otherwise Total disagrees with Items.
	db := _db.WithContext(ctx).Model(&models.User{}).
		Joins("JOIN user_posts ON user_posts.user_id = users.id")
	if keyword != "" {
		like := database.WrapLike(keyword)
		where := "(users.phone LIKE ? OR users.username LIKE ? OR users.open_id = ?)"
		args := []any{like, like, keyword}
		if keywordID > 0 {
			where = "(users.phone LIKE ? OR users.username LIKE ? OR users.open_id = ? OR users.id = ?)"
			args = append(args, keywordID)
		}
		db = db.Where(where, args...)
	}

	var cnt int64
	if err := db.Distinct("users.id").Count(&cnt).Error; err != nil {
		return nil, errors.Wrap(err, "count users error")
	}

	return &requests.Pager{
		Items:      items,
		Total:      cnt,
		Pagination: *filter.Pagination,
	}, nil
}

// BoughtStatistics returns, for each given user id, the number of user_posts rows
// (i.e. bought posts). Users with no rows are absent from the map.
func (m *users) BoughtStatistics(ctx context.Context, userIDs []int64) (map[int64]int64, error) {
	if len(userIDs) == 0 {
		return map[int64]int64{}, nil
	}

	// The admin user list shows purchase counts; aggregate with GROUP BY to avoid N+1 queries.
	tbl, query := models.UserPostQuery.QueryContext(ctx)

	var items []struct {
		Count  int64
		UserID int64
	}
	if err := query.
		Select(
			tbl.UserID.Count().As("count"),
			tbl.UserID,
		).
		Where(tbl.UserID.In(userIDs...)).
		Group(tbl.UserID).
		Scan(&items); err != nil {
		return nil, err
	}

	result := make(map[int64]int64, len(items))
	for _, item := range items {
		result[item.UserID] = item.Count
	}
	return result, nil
}

// PostList returns a paginated list of posts bought by userID, most recent purchase
// first. Total counts the user's user_posts rows; a purchased post whose posts row
// cannot be found is skipped, so Items may be shorter than the page.
func (m *users) PostList(ctx context.Context, userID int64, filter *dto.PostListQuery) (*requests.Pager, error) {
	filter.Format()

	tbl, query := models.UserPostQuery.QueryContext(ctx)
	pagePosts, cnt, err := query.
		Order(tbl.CreatedAt.Desc()).
		Select(tbl.PostID).
		Where(tbl.UserID.Eq(userID)).
		FindByPage(int(filter.Offset()), int(filter.Limit))
	if err != nil {
		return nil, err
	}

	// Non-nil so an empty page still serializes as [] rather than null.
	posts := make([]*models.Post, 0, len(pagePosts))
	if len(pagePosts) > 0 {
		postIDs := lo.Map(pagePosts, func(item *models.UserPost, _ int) int64 { return item.PostID })

		postTbl, postQuery := models.PostQuery.QueryContext(ctx)
		found, err := postQuery.Where(postTbl.ID.In(postIDs...)).Find()
		if err != nil {
			return nil, err
		}

		// Re-order the fetched posts to follow the purchase order of the page.
		byID := lo.KeyBy(found, func(item *models.Post) int64 { return item.ID })
		for _, id := range postIDs {
			if p, ok := byID[id]; ok {
				posts = append(posts, p)
			}
		}
	}

	return &requests.Pager{
		Items:      posts,
		Total:      cnt,
		Pagination: *filter.Pagination,
	}, nil
}

// GetUsersMapByIDs loads the given users keyed by id. It returns a nil map (safe
// to read) when ids is empty; ids that do not exist are simply absent.
func (m *users) GetUsersMapByIDs(ctx context.Context, ids []int64) (map[int64]*models.User, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	tbl, query := models.UserQuery.QueryContext(ctx)
	items, err := query.Where(tbl.ID.In(ids...)).Find()
	if err != nil {
		return nil, errors.Wrapf(err, "failed to get users by ids:%v", ids)
	}

	return lo.KeyBy(items, func(item *models.User) int64 { return item.ID }), nil
}

// BatchCheckHasBought reports, for each id in postIDs, whether userID has bought it.
// Every requested post id is present in the result (false when not bought).
func (m *users) BatchCheckHasBought(ctx context.Context, userID int64, postIDs []int64) (map[int64]bool, error) {
	result := make(map[int64]bool, len(postIDs))
	if len(postIDs) == 0 {
		// Nothing to check; skip the database round-trip.
		return result, nil
	}

	tbl, query := models.UserPostQuery.QueryContext(ctx)
	userPosts, err := query.
		Where(
			tbl.UserID.Eq(userID),
			tbl.PostID.In(postIDs...),
		).
		Find()
	if err != nil {
		return nil, errors.Wrapf(err, "check user has bought failed, user_id: %d, post_ids: %+v", userID, postIDs)
	}

	for _, postID := range postIDs {
		result[postID] = false
	}
	for _, up := range userPosts {
		result[up.PostID] = true
	}
	return result, nil
}

// HasBought reports whether userID has a user_posts row for postID.
func (m *users) HasBought(ctx context.Context, userID, postID int64) (bool, error) {
	tbl, query := models.UserPostQuery.QueryContext(ctx)

	cnt, err := query.
		Where(
			tbl.UserID.Eq(userID),
			tbl.PostID.Eq(postID),
		).
		Count()
	if err != nil {
		return false, errors.Wrap(err, "failed to check user bought")
	}

	return cnt > 0, nil
}

// SetUsername updates only the username column of the given user.
func (m *users) SetUsername(ctx context.Context, userID int64, username string) error {
	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err := query.Where(tbl.ID.Eq(userID)).Update(tbl.Username, username)
	return err
}

// BuyPosts records that userID bought postID at the given price.
func (m *users) BuyPosts(ctx context.Context, userID, postID, price int64) error {
	model := &models.UserPost{UserID: userID, PostID: postID, Price: price}
	return model.Create(ctx)
}

// RevokeUserPosts deletes the user_posts row(s) linking userID and postID.
func (m *users) RevokeUserPosts(ctx context.Context, userID, postID int64) error {
	tbl, query := models.UserPostQuery.QueryContext(ctx)
	_, err := query.Where(
		tbl.UserID.Eq(userID),
		tbl.PostID.Eq(postID),
	).Delete()
	return err
}

// FindByID loads a user by id; the underlying query error is wrapped with the id.
func (m *users) FindByID(ctx context.Context, userID int64) (*models.User, error) {
	tbl, query := models.UserQuery.QueryContext(ctx)
	user, err := query.Where(tbl.ID.Eq(userID)).First()
	if err != nil {
		return nil, errors.Wrapf(err, "find by id failed, id: %d", userID)
	}
	return user, nil
}

// SetBalance overwrites the user's balance with the given value.
func (m *users) SetBalance(ctx context.Context, userID, balance int64) error {
	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err := query.Where(tbl.ID.Eq(userID)).Update(tbl.Balance, balance)
	return err
}

// AddBalance adds the given amount to the user's balance.
func (m *users) AddBalance(ctx context.Context, userID, amount int64) error {
	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err := query.Where(tbl.ID.Eq(userID)).Inc(tbl.Balance, amount)
	return err
}

// DescBalance subtracts amount from the user's balance, failing with
// "balance not enough" when the balance is lower than amount.
func (m *users) DescBalance(ctx context.Context, userID, amount int64) error {
	// Existence check first so an unknown user still yields the FindByID error.
	if _, err := m.FindByID(ctx, userID); err != nil {
		return err
	}

	// Check and decrement in a single conditional UPDATE: a separate read-then-write
	// lets two concurrent deductions both pass the check and overdraw the balance.
	res := _db.WithContext(ctx).
		Model(&models.User{}).
		Where("id = ? AND balance >= ?", userID, amount).
		Update("balance", gorm.Expr("balance - ?", amount))
	if res.Error != nil {
		return res.Error
	}
	if res.RowsAffected == 0 {
		return errors.New("balance not enough")
	}
	return nil
}

// Count returns the number of users matching all given conditions (all users when none).
func (m *users) Count(ctx context.Context, conds ...gen.Condition) (int64, error) {
	_, query := models.UserQuery.QueryContext(ctx)
	if len(conds) > 0 {
		query = query.Where(conds...)
	}
	return query.Count()
}

// FindByPhone loads the user with the exact phone. The error is returned unwrapped;
// callers in this file detect "not found" with errors.Is(err, gorm.ErrRecordNotFound).
func (m *users) FindByPhone(ctx context.Context, phone string) (*models.User, error) {
	tbl, query := models.UserQuery.QueryContext(ctx)
	return query.Where(tbl.Phone.Eq(phone)).First()
}

const (
	// phoneCodeResendInterval is the minimum gap between two codes sent to one phone.
	// The frontend counts down 60s; the backend uses 58s so client/server clock skew
	// does not reject a request made right at the 60s mark.
	phoneCodeResendInterval = 58 * time.Second
	// phoneCodeTTL is how long an issued SMS code remains valid.
	phoneCodeTTL = 5 * time.Minute
	// phoneCodeMaxAttempts is the number of wrong guesses after which the current
	// code is discarded, so a short numeric code cannot be brute-forced within its TTL.
	phoneCodeMaxAttempts = 5
)

// phoneCodeEntry is a pending SMS verification code for one phone.
type phoneCodeEntry struct {
	code      string
	expiresAt time.Time
	// attempts counts failed validations against this code.
	attempts int
}

// ensurePhoneAuthMaps lazily allocates the in-memory maps (e.g. if Prepare was not called).
// Must be called with m.mu held.
func (m *users) ensurePhoneAuthMaps() {
	if m.lastSentAtByPhone == nil {
		m.lastSentAtByPhone = make(map[string]time.Time)
	}
	if m.codeByPhone == nil {
		m.codeByPhone = make(map[string]phoneCodeEntry)
	}
}

// normalizePhone trims surrounding whitespace from a phone number.
func (m *users) normalizePhone(phone string) string {
	return strings.TrimSpace(phone)
}

// isSendTooFrequent reports whether a code was sent to phone less than
// phoneCodeResendInterval before now. Must be called with m.mu held.
func (m *users) isSendTooFrequent(now time.Time, phone string) bool {
	last, ok := m.lastSentAtByPhone[phone]
	if !ok {
		return false
	}
	return now.Sub(last) < phoneCodeResendInterval
}

// gen4Digits returns a cryptographically random 4-digit code in 0000-9999.
func (m *users) gen4Digits() (string, error) {
	n, err := rand.Int(rand.Reader, big.NewInt(10000))
	if err != nil {
		return "", errors.Wrap(err, "failed to generate sms code")
	}
	return fmt.Sprintf("%04d", n.Int64()), nil
}

// SendPhoneCode sends an SMS verification code to a registered phone.
//
// Rate limiting is in memory: one send per phone per phoneCodeResendInterval.
// The code returned by the SMS client replaces any earlier code for the phone and
// expires after phoneCodeTTL. When _db is set, a best-effort audit row is written.
func (m *users) SendPhoneCode(ctx context.Context, phone string) error {
	phone = m.normalizePhone(phone)
	if phone == "" {
		return errors.New("手机号不能为空")
	}

	// Only registered phones may receive codes.
	if _, err := m.FindByPhone(ctx, phone); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("手机号未注册,请联系管理员开通")
		}
		return errors.Wrap(err, "failed to find user by phone")
	}

	now := time.Now()

	// Reserve the send slot under the lock, then release it before the network call:
	// holding m.mu across SendTo would stall every other SendPhoneCode and
	// ValidatePhoneCode call for the duration of the SMS request.
	m.mu.Lock()
	m.ensurePhoneAuthMaps()
	if m.isSendTooFrequent(now, phone) {
		m.mu.Unlock()
		return errors.New("验证码发送过于频繁,请稍后再试")
	}
	prevSentAt, hadPrev := m.lastSentAtByPhone[phone]
	m.lastSentAtByPhone[phone] = now
	m.mu.Unlock()

	code, err := m.smsNotifyClient.SendTo(phone)
	if err != nil {
		// Roll back our reservation (unless something newer replaced it) so a failed
		// send does not count against the rate limit.
		m.mu.Lock()
		if last, ok := m.lastSentAtByPhone[phone]; ok && last.Equal(now) {
			if hadPrev {
				m.lastSentAtByPhone[phone] = prevSentAt
			} else {
				delete(m.lastSentAtByPhone, phone)
			}
		}
		m.mu.Unlock()

		log.WithError(err).Errorf("SendPhoneCode to %s", phone)
		return err
	}

	// Store/overwrite the code: on re-send the latest code wins.
	expiresAt := now.Add(phoneCodeTTL)
	m.mu.Lock()
	m.ensurePhoneAuthMaps()
	m.codeByPhone[phone] = phoneCodeEntry{
		code:      code,
		expiresAt: expiresAt,
	}
	m.mu.Unlock()

	// NOTE(review): this logs the verification code in plaintext; consider dropping
	// the code from the log line in production.
	log.Infof("SendPhoneCode to %s: code=%s", phone, code)

	if _db != nil {
		// Record the send for admin audit/troubleshooting. Best-effort: a failure here
		// must not fail the already-delivered SMS, but it should be visible.
		if err := _db.WithContext(ctx).Create(&models.SmsCodeSend{
			Phone:     phone,
			Code:      code,
			SentAt:    now,
			ExpiresAt: expiresAt,
		}).Error; err != nil {
			log.WithError(err).Warnf("SendPhoneCode to %s: failed to record sms code send", phone)
		}
	}

	return nil
}

// ValidatePhoneCode checks an SMS code for phone and, on success, consumes the code
// and returns the user (for token issuance). A code is discarded after it expires
// or after phoneCodeMaxAttempts wrong guesses.
func (m *users) ValidatePhoneCode(ctx context.Context, phone, code string) (*models.User, error) {
	phone = m.normalizePhone(phone)
	code = strings.TrimSpace(code)
	if phone == "" {
		return nil, errors.New("手机号不能为空")
	}
	if code == "" {
		return nil, errors.New("验证码不能为空")
	}

	// Confirm the phone exists first so code state is never revealed for unregistered phones.
	user, err := m.FindByPhone(ctx, phone)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("手机号未注册,请联系管理员开通")
		}
		return nil, errors.Wrap(err, "failed to find user by phone")
	}

	now := time.Now()

	m.mu.Lock()
	defer m.mu.Unlock()
	m.ensurePhoneAuthMaps()

	entry, ok := m.codeByPhone[phone]
	if !ok {
		return nil, errors.New("验证码已过期或不存在")
	}
	if now.After(entry.expiresAt) {
		delete(m.codeByPhone, phone)
		return nil, errors.New("验证码已过期或不存在")
	}
	// Constant-time comparison avoids leaking matching prefixes via timing.
	if subtle.ConstantTimeCompare([]byte(entry.code), []byte(code)) != 1 {
		entry.attempts++
		if entry.attempts >= phoneCodeMaxAttempts {
			// Too many wrong guesses: invalidate so the code cannot be brute-forced.
			delete(m.codeByPhone, phone)
		} else {
			m.codeByPhone[phone] = entry
		}
		return nil, errors.New("验证码错误")
	}

	// Consume the code on success to prevent replay.
	delete(m.codeByPhone, phone)
	return user, nil
}

// validateElevenDigitPhone checks that an already-trimmed phone is exactly 11 ASCII digits.
func (m *users) validateElevenDigitPhone(phone string) error {
	if phone == "" {
		return errors.New("手机号不能为空")
	}
	if len(phone) != 11 {
		return errors.New("手机号必须为 11 位数字")
	}
	for _, r := range phone {
		if r < '0' || r > '9' {
			return errors.New("手机号必须为 11 位数字")
		}
	}
	return nil
}

// SetPhone sets a user's phone number (admin operation). The phone must be 11 digits
// and must not belong to any other user.
func (m *users) SetPhone(ctx context.Context, userID int64, phone string) error {
	phone = strings.TrimSpace(phone)
	if err := m.validateElevenDigitPhone(phone); err != nil {
		return err
	}

	// Business rule: phones are unique within the system so login/identity checks are unambiguous.
	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err := query.Where(tbl.Phone.Eq(phone), tbl.ID.Neq(userID)).First()
	if err == nil {
		return errors.New("手机号已被其他用户占用")
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return errors.Wrap(err, "failed to check phone uniqueness")
	}

	// Update only the phone column so other fields are not overwritten.
	if _, err := query.Where(tbl.ID.Eq(userID)).Update(tbl.Phone, phone); err != nil {
		return errors.Wrap(err, "failed to update user phone")
	}
	return nil
}

// CreateByPhone creates a new user from a phone number (admin operation). The phone
// is required (11 digits, unused); username is optional and defaults to
// "用户" + the last 4 digits. The open_id is set to "phone:" + phone.
func (m *users) CreateByPhone(ctx context.Context, phone, username string) (*models.User, error) {
	phone = strings.TrimSpace(phone)
	if err := m.validateElevenDigitPhone(phone); err != nil {
		return nil, err
	}

	_, err := m.FindByPhone(ctx, phone)
	if err == nil {
		return nil, errors.New("手机号已被其他用户占用")
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, errors.Wrap(err, "failed to check phone uniqueness")
	}

	openID := "phone:" + phone
	tbl, query := models.UserQuery.QueryContext(ctx)
	if _, err := query.Where(tbl.OpenID.Eq(openID)).First(); err == nil {
		return nil, errors.New("用户已存在")
	} else if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, errors.Wrap(err, "failed to check open_id uniqueness")
	}

	username = strings.TrimSpace(username)
	if username == "" {
		username = "用户" + phone[len(phone)-4:]
	}

	user := &models.User{
		OpenID:   openID,
		Username: username,
		Phone:    phone,
		Balance:  0,
		Avatar:   "",
	}
	if err := _db.WithContext(ctx).Omit("metas", "auth_token").Create(user).Error; err != nil {
		return nil, errors.Wrap(err, "failed to create user")
	}
	return user, nil
}