package services

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"fmt"
	"math/big"
	"strings"
	"sync"
	"time"

	"quyun/v2/app/requests"
	"quyun/v2/database/models"

	"github.com/pkg/errors"
	"github.com/samber/lo"
	log "github.com/sirupsen/logrus"
	"go.ipao.vip/gen"
	"gorm.io/gorm"
)

const (
	// phoneCodeSendInterval is the minimum gap between two codes sent to the
	// same phone. The frontend counts down 60s; the backend uses 58s so that
	// small client/server clock differences do not reject a resend that the
	// user sees as allowed at exactly 60s.
	phoneCodeSendInterval = 58 * time.Second

	// phoneCodeTTL is how long an issued code stays valid.
	phoneCodeTTL = 5 * time.Minute

	// phoneCodeMaxAttempts is how many wrong guesses a single code tolerates
	// before it is discarded. Without a cap a 4-digit code could be
	// brute-forced well within phoneCodeTTL.
	phoneCodeMaxAttempts = 5
)

// @provider
type users struct {
	// mu guards lastSentAtByPhone and codeByPhone. Both maps live only in this
	// process's memory and back the SMS-code login flow.
	mu                sync.Mutex                `inject:"false"`
	lastSentAtByPhone map[string]time.Time      `inject:"false"`
	codeByPhone       map[string]phoneCodeEntry `inject:"false"`
}

// Prepare initializes the in-memory phone-code state.
func (m *users) Prepare() error {
	m.lastSentAtByPhone = make(map[string]time.Time)
	m.codeByPhone = make(map[string]phoneCodeEntry)
	return nil
}

// List returns a page of users matching conds, together with the total count.
func (m *users) List(
	ctx context.Context,
	pagination *requests.Pagination,
	conds ...gen.Condition,
) (*requests.Pager, error) {
	pagination.Format()

	_, query := models.UserQuery.QueryContext(ctx)
	items, cnt, err := query.Where(conds...).FindByPage(int(pagination.Offset()), int(pagination.Limit))
	if err != nil {
		return nil, errors.Wrap(err, "query users error")
	}

	return &requests.Pager{
		Items:      items,
		Total:      cnt,
		Pagination: *pagination,
	}, nil
}

// PostList returns a page of the posts purchased by userID.
//
// Paging and the total count are computed over the user's purchase records
// (user_posts). The referenced posts are then loaded in a single query and
// returned in the same order as those records; a post referenced more than
// once on the page is returned once.
//
// NOTE(review): conds is currently not applied to either query — confirm
// what callers expect it to filter before wiring it in.
func (m *users) PostList(
	ctx context.Context,
	userID int64,
	pagination *requests.Pagination,
	conds ...gen.Condition,
) (*requests.Pager, error) {
	pagination.Format()

	tbl, query := models.UserPostQuery.QueryContext(ctx)
	pagePosts, cnt, err := query.Select(tbl.PostID).
		Where(tbl.UserID.Eq(userID)).
		FindByPage(int(pagination.Offset()), int(pagination.Limit))
	if err != nil {
		return nil, errors.Wrapf(err, "query user posts failed, user_id: %d", userID)
	}

	// lo.Uniq keeps the first occurrence, so page order is preserved.
	postIDs := lo.Uniq(lo.Map(pagePosts, func(item *models.UserPost, _ int) int64 {
		return item.PostID
	}))

	items, err := m.findPostsInOrder(ctx, postIDs)
	if err != nil {
		return nil, err
	}

	return &requests.Pager{
		Items:      items,
		Total:      cnt,
		Pagination: *pagination,
	}, nil
}

// findPostsInOrder loads the posts with the given IDs and returns them in the
// order of ids, skipping IDs that have no matching post. An empty ids returns
// an empty (non-nil) slice without querying.
func (m *users) findPostsInOrder(ctx context.Context, ids []int64) ([]*models.Post, error) {
	items := make([]*models.Post, 0, len(ids))
	if len(ids) == 0 {
		return items, nil
	}

	tbl, query := models.PostQuery.QueryContext(ctx)
	posts, err := query.Where(tbl.ID.In(ids...)).Find()
	if err != nil {
		return nil, errors.Wrapf(err, "query posts failed, post_ids: %+v", ids)
	}

	// WHERE id IN (...) has no ordering guarantee; restore the caller's order.
	byID := lo.KeyBy(posts, func(item *models.Post) int64 { return item.ID })
	for _, id := range ids {
		if post, ok := byID[id]; ok {
			items = append(items, post)
		}
	}
	return items, nil
}

// GetUsersMapByIDs returns the users with the given IDs, keyed by user ID.
// IDs with no matching user are absent from the map. An empty ids returns an
// empty (non-nil) map without querying.
func (m *users) GetUsersMapByIDs(ctx context.Context, ids []int64) (map[int64]*models.User, error) {
	if len(ids) == 0 {
		return map[int64]*models.User{}, nil
	}

	tbl, query := models.UserQuery.QueryContext(ctx)
	items, err := query.Where(tbl.ID.In(ids...)).Find()
	if err != nil {
		return nil, errors.Wrapf(err, "failed to get users by ids:%v", ids)
	}

	return lo.KeyBy(items, func(item *models.User) int64 {
		return item.ID
	}), nil
}

// BatchCheckHasBought reports, for each ID in postIDs, whether userID has a
// purchase record for it. Every requested ID is present in the result; an
// empty postIDs returns an empty map without querying.
func (m *users) BatchCheckHasBought(ctx context.Context, userID int64, postIDs []int64) (map[int64]bool, error) {
	result := make(map[int64]bool, len(postIDs))
	if len(postIDs) == 0 {
		return result, nil
	}

	tbl, query := models.UserPostQuery.QueryContext(ctx)
	userPosts, err := query.
		Where(
			tbl.UserID.Eq(userID),
			tbl.PostID.In(postIDs...),
		).
		Find()
	if err != nil {
		return nil, errors.Wrapf(err, "check user has bought failed, user_id: %d, post_ids: %+v", userID, postIDs)
	}

	for _, postID := range postIDs {
		result[postID] = false
	}
	for _, post := range userPosts {
		result[post.PostID] = true
	}
	return result, nil
}

// HasBought reports whether userID has at least one purchase record for postID.
func (m *users) HasBought(ctx context.Context, userID, postID int64) (bool, error) {
	tbl, query := models.UserPostQuery.QueryContext(ctx)
	cnt, err := query.
		Where(
			tbl.UserID.Eq(userID),
			tbl.PostID.Eq(postID),
		).
		Count()
	if err != nil {
		return false, errors.Wrap(err, "failed to check user bought")
	}
	return cnt > 0, nil
}

// SetUsername updates the username of userID.
func (m *users) SetUsername(ctx context.Context, userID int64, username string) error {
	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err := query.
		Where(
			tbl.ID.Eq(userID),
		).
		Update(tbl.Username, username)
	return err
}

// BuyPosts records that userID bought postID for price. It only creates the
// purchase record; it does not touch the user's balance.
func (m *users) BuyPosts(ctx context.Context, userID, postID, price int64) error {
	model := &models.UserPost{UserID: userID, PostID: postID, Price: price}
	return model.Create(ctx)
}

// RevokeUserPosts deletes the purchase records of postID for userID.
func (m *users) RevokeUserPosts(ctx context.Context, userID, postID int64) error {
	tbl, query := models.UserPostQuery.QueryContext(ctx)
	_, err := query.Where(
		tbl.UserID.Eq(userID),
		tbl.PostID.Eq(postID),
	).Delete()
	return err
}

// FindByID returns the user with the given ID. The returned error wraps the
// underlying query error.
func (m *users) FindByID(ctx context.Context, userID int64) (*models.User, error) {
	tbl, query := models.UserQuery.QueryContext(ctx)
	user, err := query.Where(tbl.ID.Eq(userID)).First()
	if err != nil {
		return nil, errors.Wrapf(err, "find by id failed, id: %d", userID)
	}
	return user, nil
}

// SetBalance overwrites the balance of userID with balance.
func (m *users) SetBalance(ctx context.Context, userID, balance int64) error {
	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err := query.Where(tbl.ID.Eq(userID)).Update(tbl.Balance, balance)
	return err
}

// AddBalance increments the balance of userID by amount.
func (m *users) AddBalance(ctx context.Context, userID, amount int64) error {
	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err := query.Where(tbl.ID.Eq(userID)).Inc(tbl.Balance, amount)
	return err
}

// DescBalance deducts amount from the balance of userID, failing with
// "balance not enough" if the current balance is lower than amount.
//
// NOTE(review): the balance check and the decrement are separate statements,
// so two concurrent calls can both pass the check and overdraw the balance.
// Consider a conditional update (WHERE balance >= amount) plus an
// affected-rows check, or a row lock inside a transaction.
func (m *users) DescBalance(ctx context.Context, userID, amount int64) error {
	user, err := m.FindByID(ctx, userID)
	if err != nil {
		return err
	}

	if user.Balance < amount {
		return errors.New("balance not enough")
	}

	tbl, query := models.UserQuery.QueryContext(ctx)
	_, err = query.Where(tbl.ID.Eq(userID)).Inc(tbl.Balance, -amount)
	return err
}

// Count returns the number of users matching conds (all users when no
// conditions are given).
func (m *users) Count(ctx context.Context, conds ...gen.Condition) (int64, error) {
	_, query := models.UserQuery.QueryContext(ctx)
	if len(conds) > 0 {
		query = query.Where(conds...)
	}
	return query.Count()
}

// FindByPhone returns the user registered with phone. The query error is
// returned unchanged so callers can match gorm.ErrRecordNotFound.
func (m *users) FindByPhone(ctx context.Context, phone string) (*models.User, error) {
	tbl, query := models.UserQuery.QueryContext(ctx)
	return query.Where(tbl.Phone.Eq(phone)).First()
}

// phoneCodeEntry is an issued SMS code awaiting validation.
type phoneCodeEntry struct {
	code      string
	expiresAt time.Time
	// attempts counts wrong guesses made against this code.
	attempts int
}

// ensurePhoneAuthMaps lazily allocates the phone-code maps in case Prepare
// was not called. Callers must hold m.mu.
func (m *users) ensurePhoneAuthMaps() {
	if m.lastSentAtByPhone == nil {
		m.lastSentAtByPhone = make(map[string]time.Time)
	}
	if m.codeByPhone == nil {
		m.codeByPhone = make(map[string]phoneCodeEntry)
	}
}

// normalizePhone trims surrounding whitespace from phone.
func (m *users) normalizePhone(phone string) string {
	return strings.TrimSpace(phone)
}

// isSendTooFrequent reports whether a code was sent to phone less than
// phoneCodeSendInterval before now. Callers must hold m.mu.
func (m *users) isSendTooFrequent(now time.Time, phone string) bool {
	last, ok := m.lastSentAtByPhone[phone]
	if !ok {
		return false
	}
	return now.Sub(last) < phoneCodeSendInterval
}

// gen4Digits returns a uniformly random, zero-padded code in 0000-9999 drawn
// from crypto/rand.
func (m *users) gen4Digits() (string, error) {
	n, err := rand.Int(rand.Reader, big.NewInt(10000))
	if err != nil {
		return "", errors.Wrap(err, "failed to generate sms code")
	}
	return fmt.Sprintf("%04d", n.Int64()), nil
}

// SendPhoneCode issues a login code for a registered phone number.
//
// Sends to the same phone are rate-limited in memory to one per
// phoneCodeSendInterval, and each code expires after phoneCodeTTL. A new
// code replaces any previous one for that phone (and resets its attempt
// counter).
func (m *users) SendPhoneCode(ctx context.Context, phone string) error {
	phone = m.normalizePhone(phone)
	if phone == "" {
		return errors.New("手机号不能为空")
	}

	// Only registered phone numbers may receive a code.
	_, err := m.FindByPhone(ctx, phone)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("手机号未注册,请联系管理员开通")
		}
		return errors.Wrap(err, "failed to find user by phone")
	}

	now := time.Now()

	m.mu.Lock()
	defer m.mu.Unlock()

	m.ensurePhoneAuthMaps()

	if m.isSendTooFrequent(now, phone) {
		return errors.New("验证码发送过于频繁,请稍后再试")
	}

	code, err := m.gen4Digits()
	if err != nil {
		return err
	}

	// The latest code always wins: a resend overwrites the previous entry.
	m.codeByPhone[phone] = phoneCodeEntry{
		code:      code,
		expiresAt: now.Add(phoneCodeTTL),
	}
	m.lastSentAtByPhone[phone] = now

	// NOTE(review): this logs the plaintext code at Info level. It is the only
	// way the code leaves the process until an SMS provider is wired in;
	// remove or downgrade it before production use.
	log.Infof("SendPhoneCode to %s: code=%s", phone, code)

	// TODO: send the code through the real SMS provider; for now codes are
	// only issued and validated in memory.
	return nil
}

// ValidatePhoneCode checks code against the one issued for phone and, on
// success, consumes the code and returns the matching user (for token
// issuance).
//
// Each code tolerates phoneCodeMaxAttempts wrong guesses; after that it is
// discarded and a new one must be requested.
func (m *users) ValidatePhoneCode(ctx context.Context, phone, code string) (*models.User, error) {
	phone = m.normalizePhone(phone)
	code = strings.TrimSpace(code)
	if phone == "" {
		return nil, errors.New("手机号不能为空")
	}
	if code == "" {
		return nil, errors.New("验证码不能为空")
	}

	// Confirm the phone exists first so unknown numbers never reveal
	// anything about code state.
	user, err := m.FindByPhone(ctx, phone)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("手机号未注册,请联系管理员开通")
		}
		return nil, errors.Wrap(err, "failed to find user by phone")
	}

	now := time.Now()

	m.mu.Lock()
	defer m.mu.Unlock()

	m.ensurePhoneAuthMaps()

	entry, ok := m.codeByPhone[phone]
	if !ok {
		return nil, errors.New("验证码已过期或不存在")
	}
	if now.After(entry.expiresAt) {
		delete(m.codeByPhone, phone)
		return nil, errors.New("验证码已过期或不存在")
	}

	// Constant-time comparison avoids leaking matching prefixes via timing.
	if subtle.ConstantTimeCompare([]byte(entry.code), []byte(code)) != 1 {
		entry.attempts++
		if entry.attempts >= phoneCodeMaxAttempts {
			// Burn the code so it cannot be brute-forced within its TTL.
			delete(m.codeByPhone, phone)
			return nil, errors.New("验证码错误次数过多,请重新获取验证码")
		}
		m.codeByPhone[phone] = entry
		return nil, errors.New("验证码错误")
	}

	// Consume the code on success so it cannot be replayed.
	delete(m.codeByPhone, phone)
	return user, nil
}