Files
quyun/backend_v1/app/services/users.go
Rogee 6c9063a2c3
Some checks failed
build quyun / Build (push) Failing after 1m26s
feat: 添加手动设置短信验证码功能及相关前端支持
2025-12-23 23:59:01 +08:00

600 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"strconv"
"strings"
"sync"
"time"
"quyun/v2/app/http/dto"
"quyun/v2/app/requests"
"quyun/v2/database"
"quyun/v2/database/models"
"quyun/v2/providers/ali"
"github.com/pkg/errors"
"github.com/samber/lo"
log "github.com/sirupsen/logrus"
"go.ipao.vip/gen"
"gorm.io/gorm"
)
// users is the user service: user lookups and admin updates, balance and
// purchase (user_posts) bookkeeping, and SMS verification-code state used for
// phone login.
//
// The verification-code and throttling state lives only in the in-process maps
// below (guarded by mu); SmsCodeSend rows are written separately as an audit log.
// @provider
type users struct {
	// mu guards lastSentAtByPhone and codeByPhone.
	mu sync.Mutex `inject:"false"`
	// lastSentAtByPhone records, per normalized phone, when a code was last sent
	// or manually set; it drives the resend throttle.
	lastSentAtByPhone map[string]time.Time `inject:"false"`
	// codeByPhone holds the current verification code (and its expiry) per
	// normalized phone.
	codeByPhone map[string]phoneCodeEntry `inject:"false"`
	// smsNotifyClient delivers SMS codes; presumably injected by the provider
	// framework (it has no inject:"false" tag) — confirm.
	smsNotifyClient *ali.SMSNotifyClient
}
// Prepare allocates the in-memory phone-auth maps. It never fails; the error
// return presumably exists to satisfy a provider lifecycle hook — confirm.
func (m *users) Prepare() error {
	m.codeByPhone = map[string]phoneCodeEntry{}
	m.lastSentAtByPhone = map[string]time.Time{}
	return nil
}
// List returns a paginated list of users, highest ID first.
//
// When filter.Keyword is non-blank after trimming, a user matches if its phone
// or username contains the keyword, its open_id equals the keyword, or, when the
// keyword parses as a positive integer, its ID equals it.
//
// When filter.OnlyBought is true, only users with at least one user_posts row are
// returned. That path groups by users.id, so the total comes from a separate
// COUNT(DISTINCT users.id) query that applies the same keyword conditions as the
// item query.
func (m *users) List(
	ctx context.Context,
	filter *dto.UserListQuery,
) (*requests.Pager, error) {
	filter.Pagination.Format()

	tbl, query := models.UserQuery.QueryContext(ctx)
	query = query.Order(tbl.ID.Desc())

	// Trim before testing for emptiness so a whitespace-only keyword means
	// "no filter" instead of degenerating into LIKE '%%' / open_id = ''.
	keyword := ""
	if filter.Keyword != nil {
		keyword = strings.TrimSpace(*filter.Keyword)
	}

	// keywordID > 0 means the keyword also matches users.id exactly.
	var keywordID int64
	if keyword != "" {
		if id, err := strconv.ParseInt(keyword, 10, 64); err == nil && id > 0 {
			keywordID = id
		}

		like := database.WrapLike(keyword)
		query = query.
			Where(tbl.Phone.Like(like)).
			Or(tbl.Username.Like(like)).
			Or(tbl.OpenID.Eq(keyword))
		if keywordID > 0 {
			query = query.Or(tbl.ID.Eq(keywordID))
		}
	}

	if filter.OnlyBought != nil && *filter.OnlyBought {
		// Keep only users with at least one purchase by JOINing user_posts and
		// grouping by users.id. FindByPage counts with Count(), which is unreliable
		// under GROUP BY, so the page and the total are queried separately.
		tblUserPost, _ := models.UserPostQuery.QueryContext(ctx)
		query = query.Join(tblUserPost, tbl.ID.EqCol(tblUserPost.UserID)).Group(tbl.ID)

		offset := int(filter.Pagination.Offset())
		limit := int(filter.Pagination.Limit)
		items, err := query.Offset(offset).Limit(limit).Find()
		if err != nil {
			return nil, errors.Wrap(err, "query users error")
		}

		db := _db.WithContext(ctx).Model(&models.User{}).
			Joins("JOIN user_posts ON user_posts.user_id = users.id")
		if keyword != "" {
			// These conditions must mirror the item query above exactly, including
			// open_id. Otherwise Total disagrees with the rows returned in Items.
			like := database.WrapLike(keyword)
			where := "users.phone LIKE ? OR users.username LIKE ? OR users.open_id = ?"
			args := []any{like, like, keyword}
			if keywordID > 0 {
				where += " OR users.id = ?"
				args = append(args, keywordID)
			}
			db = db.Where("("+where+")", args...)
		}

		var cnt int64
		if err := db.Distinct("users.id").Count(&cnt).Error; err != nil {
			return nil, errors.Wrap(err, "count users error")
		}
		return &requests.Pager{
			Items:      items,
			Total:      cnt,
			Pagination: *filter.Pagination,
		}, nil
	}

	items, cnt, err := query.FindByPage(int(filter.Pagination.Offset()), int(filter.Pagination.Limit))
	if err != nil {
		return nil, errors.Wrap(err, "query users error")
	}
	return &requests.Pager{
		Items:      items,
		Total:      cnt,
		Pagination: *filter.Pagination,
	}, nil
}
// BoughtStatistics returns, for the given user IDs, how many user_posts rows
// (purchased posts) each user has. Users with no purchases are absent from the
// map. An empty userIDs yields an empty map without touching the database.
func (m *users) BoughtStatistics(ctx context.Context, userIDs []int64) (map[int64]int64, error) {
	counts := map[int64]int64{}
	if len(userIDs) == 0 {
		return counts, nil
	}

	// A single GROUP BY over the whole ID set keeps the admin user list free of
	// per-row (N+1) count queries.
	up, q := models.UserPostQuery.QueryContext(ctx)
	var rows []struct {
		Count  int64
		UserID int64
	}
	err := q.
		Select(up.UserID.Count().As("count"), up.UserID).
		Where(up.UserID.In(userIDs...)).
		Group(up.UserID).
		Scan(&rows)
	if err != nil {
		return nil, err
	}

	for _, row := range rows {
		counts[row.UserID] = row.Count
	}
	return counts, nil
}
// PostList returns one page of the posts a user has bought, ordered by
// user_posts.created_at descending. Total is the user's total number of
// user_posts rows.
//
// Posts not returned by the posts lookup (e.g. removed posts) are skipped, so a
// page may hold fewer items than the page size. Items is always a non-nil slice.
func (m *users) PostList(ctx context.Context, userId int64, filter *dto.PostListQuery) (*requests.Pager, error) {
	filter.Format()

	tbl, query := models.UserPostQuery.QueryContext(ctx)
	pagePosts, cnt, err := query.
		Order(tbl.CreatedAt.Desc()).
		Select(tbl.PostID).
		Where(tbl.UserID.Eq(userId)).
		FindByPage(int(filter.Offset()), int(filter.Limit))
	if err != nil {
		return nil, err
	}

	items := make([]*models.Post, 0, len(pagePosts))
	if len(pagePosts) == 0 {
		// Nothing on this page: skip the posts lookup entirely.
		return &requests.Pager{
			Items:      items,
			Total:      cnt,
			Pagination: *filter.Pagination,
		}, nil
	}

	postIds := lo.Map(pagePosts, func(item *models.UserPost, _ int) int64 { return item.PostID })
	postTbl, postQuery := models.PostQuery.QueryContext(ctx)
	posts, err := postQuery.Where(postTbl.ID.In(postIds...)).Find()
	if err != nil {
		return nil, err
	}

	// The IN query returns rows in arbitrary order; re-emit them in purchase order.
	postByID := lo.KeyBy(posts, func(item *models.Post) int64 { return item.ID })
	for _, id := range postIds {
		if p, ok := postByID[id]; ok {
			items = append(items, p)
		}
	}

	return &requests.Pager{
		Items:      items,
		Total:      cnt,
		Pagination: *filter.Pagination,
	}, nil
}
// GetUsersMapByIDs loads the users with the given IDs and returns them keyed by
// ID. IDs with no matching user are simply absent from the map.
//
// An empty ids slice returns an empty, non-nil map without querying. Reads from
// a nil map would also work, but callers that add entries would panic.
func (m *users) GetUsersMapByIDs(ctx context.Context, ids []int64) (map[int64]*models.User, error) {
	if len(ids) == 0 {
		return map[int64]*models.User{}, nil
	}

	tbl, query := models.UserQuery.QueryContext(ctx)
	items, err := query.Where(tbl.ID.In(ids...)).Find()
	if err != nil {
		return nil, errors.Wrapf(err, "failed to get users by ids:%v", ids)
	}
	return lo.KeyBy(items, func(item *models.User) int64 {
		return item.ID
	}), nil
}
// BatchCheckHasBought reports, for each ID in postIDs, whether userId has a
// user_posts row for it. Every requested ID appears in the result; the value is
// false when the post was not bought. An empty postIDs returns an empty map
// without querying the database.
func (m *users) BatchCheckHasBought(ctx context.Context, userId int64, postIDs []int64) (map[int64]bool, error) {
	result := make(map[int64]bool, len(postIDs))
	if len(postIDs) == 0 {
		return result, nil
	}

	tbl, query := models.UserPostQuery.QueryContext(ctx)
	userPosts, err := query.
		Select(tbl.PostID). // only post_id is needed to mark purchases
		Where(
			tbl.UserID.Eq(userId),
			tbl.PostID.In(postIDs...),
		).
		Find()
	if err != nil {
		return nil, errors.Wrapf(err, "check user has bought failed, user_id: %d, post_ids: %+v", userId, postIDs)
	}

	for _, postID := range postIDs {
		result[postID] = false
	}
	for _, post := range userPosts {
		result[post.PostID] = true
	}
	return result, nil
}
// HasBought reports whether userID has at least one user_posts row for postID.
func (m *users) HasBought(ctx context.Context, userID, postID int64) (bool, error) {
	up, q := models.UserPostQuery.QueryContext(ctx)
	total, err := q.Where(up.UserID.Eq(userID), up.PostID.Eq(postID)).Count()
	if err != nil {
		return false, errors.Wrap(err, "failed to check user bought")
	}
	bought := total > 0
	return bought, nil
}
// SetUsername overwrites the username column of the user with the given ID.
// The query error, if any, is returned unwrapped.
func (m *users) SetUsername(ctx context.Context, userID int64, username string) error {
	u, q := models.UserQuery.QueryContext(ctx)
	if _, err := q.Where(u.ID.Eq(userID)).Update(u.Username, username); err != nil {
		return err
	}
	return nil
}
// BuyPosts records that userID bought postID for price by creating a UserPost
// record through the model's Create method.
func (m *users) BuyPosts(ctx context.Context, userID, postID, price int64) error {
	purchase := &models.UserPost{
		UserID: userID,
		PostID: postID,
		Price:  price,
	}
	return purchase.Create(ctx)
}
// RevokeUserPosts deletes the user_posts row(s) linking userID to postID,
// revoking the purchase. Only the query error is reported; deleting a pairing
// that does not exist is not treated as an error.
func (m *users) RevokeUserPosts(ctx context.Context, userID, postID int64) error {
	up, q := models.UserPostQuery.QueryContext(ctx)
	if _, err := q.Where(up.UserID.Eq(userID), up.PostID.Eq(postID)).Delete(); err != nil {
		return err
	}
	return nil
}
// FindByID loads the user with the given ID. On failure the query error is
// wrapped with the ID. It presumably wraps gorm.ErrRecordNotFound when the user
// does not exist, so errors.Is still works.
func (m *users) FindByID(ctx context.Context, userID int64) (*models.User, error) {
	u, q := models.UserQuery.QueryContext(ctx)
	found, err := q.Where(u.ID.Eq(userID)).First()
	if err == nil {
		return found, nil
	}
	return nil, errors.Wrapf(err, "find by id failed, id: %d", userID)
}
// SetBalance overwrites the user's balance with the given absolute value.
func (m *users) SetBalance(ctx context.Context, userID, balance int64) error {
	u, q := models.UserQuery.QueryContext(ctx)
	if _, err := q.Where(u.ID.Eq(userID)).Update(u.Balance, balance); err != nil {
		return err
	}
	return nil
}
// AddBalance increments the user's balance by amount in a single UPDATE.
// No bounds check is applied to amount.
func (m *users) AddBalance(ctx context.Context, userID, amount int64) error {
	u, q := models.UserQuery.QueryContext(ctx)
	if _, err := q.Where(u.ID.Eq(userID)).Inc(u.Balance, amount); err != nil {
		return err
	}
	return nil
}
// DescBalance deducts amount from the user's balance.
//
// It returns FindByID's error if the user cannot be loaded. It returns
// "balance not enough" when the balance is lower than amount.
//
// The deduction is a single conditional UPDATE (WHERE balance >= amount), so the
// check and the decrement are atomic in the database. Two concurrent deductions
// cannot both succeed against a balance that only covers one. If the row is not
// updated (e.g. another deduction won the race), "balance not enough" is
// returned.
func (m *users) DescBalance(ctx context.Context, userID, amount int64) error {
	user, err := m.FindByID(ctx, userID)
	if err != nil {
		return err
	}
	// Cheap early rejection. On its own it is a read-then-write race, hence the
	// guarded UPDATE below.
	if user.Balance < amount {
		return errors.New("balance not enough")
	}

	tbl, query := models.UserQuery.QueryContext(ctx)
	info, err := query.
		Where(tbl.ID.Eq(userID), tbl.Balance.Gte(amount)).
		Inc(tbl.Balance, -amount)
	if err != nil {
		return err
	}
	if info.RowsAffected == 0 {
		return errors.New("balance not enough")
	}
	return nil
}
// Count returns the number of users matching all of conds, or every user when
// no conditions are given.
func (m *users) Count(ctx context.Context, conds ...gen.Condition) (int64, error) {
	_, q := models.UserQuery.QueryContext(ctx)
	if len(conds) == 0 {
		return q.Count()
	}
	return q.Where(conds...).Count()
}
// FindByPhone returns the user whose phone equals phone exactly; no trimming or
// normalization happens here. The query error is returned unwrapped, so callers
// can test it with errors.Is(err, gorm.ErrRecordNotFound).
func (m *users) FindByPhone(ctx context.Context, phone string) (*models.User, error) {
	u, q := models.UserQuery.QueryContext(ctx)
	user, err := q.Where(u.Phone.Eq(phone)).First()
	return user, err
}
// phoneCodeEntry is the verification code currently issued for one phone number.
type phoneCodeEntry struct {
	// code is the verification code as sent by SMS or set by an admin.
	code string
	// expiresAt is the instant after which the code is rejected.
	expiresAt time.Time
}
// ensurePhoneAuthMaps lazily allocates the phone-auth maps in case Prepare has
// not run. Callers are expected to hold m.mu (all visible call sites do).
func (m *users) ensurePhoneAuthMaps() {
	if m.codeByPhone == nil {
		m.codeByPhone = map[string]phoneCodeEntry{}
	}
	if m.lastSentAtByPhone == nil {
		m.lastSentAtByPhone = map[string]time.Time{}
	}
}
// normalizePhone canonicalizes a phone number for use as a map key. Currently it
// only strips leading and trailing whitespace.
func (m *users) normalizePhone(phone string) string {
	normalized := strings.TrimSpace(phone)
	return normalized
}
// isSendTooFrequent reports whether a code was sent or set for phone less than
// 58 seconds before now. Callers are expected to hold m.mu.
func (m *users) isSendTooFrequent(now time.Time, phone string) bool {
	// The frontend counts down 60s. The backend uses 58s as a safety margin so that
	// client/server clock skew does not reject a request made right as 60s elapse.
	const minResendInterval = 58 * time.Second

	if last, seen := m.lastSentAtByPhone[phone]; seen {
		return now.Sub(last) < minResendInterval
	}
	return false
}
// gen4Digits returns a random code in the range 0000-9999, zero-padded to four
// characters, drawn from crypto/rand.
func (m *users) gen4Digits() (string, error) {
	upperExclusive := big.NewInt(10000)
	v, err := rand.Int(rand.Reader, upperExclusive)
	if err != nil {
		return "", errors.Wrap(err, "failed to generate sms code")
	}
	return fmt.Sprintf("%04d", v.Int64()), nil
}
// SetPhoneCode manually sets the SMS verification code for phone (admin
// operation); no SMS is sent.
//
// Inputs:
//   - phone is trimmed and must belong to a registered user.
//   - code is trimmed and must be exactly 4 ASCII digits.
//   - ttl <= 0 defaults to 5 minutes.
//
// The resend throttle is not checked, but the phone's last-sent time is updated
// to now. A SmsCodeSend audit record is inserted and returned.
//
// The audit record is persisted before the code is published to the in-memory
// map. A database failure (or an uninitialized database) therefore returns an
// error without leaving a usable code behind.
func (m *users) SetPhoneCode(ctx context.Context, phone, code string, ttl time.Duration) (*models.SmsCodeSend, error) {
	phone = m.normalizePhone(phone)
	code = strings.TrimSpace(code)
	if phone == "" {
		return nil, errors.New("手机号不能为空")
	}
	if code == "" {
		return nil, errors.New("验证码不能为空")
	}
	if len(code) != 4 {
		return nil, errors.New("验证码必须为 4 位数字")
	}
	for _, r := range code {
		if r < '0' || r > '9' {
			return nil, errors.New("验证码必须为 4 位数字")
		}
	}

	// Precondition: the phone number must already be registered.
	if _, err := m.FindByPhone(ctx, phone); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("手机号未注册,请联系管理员开通")
		}
		return nil, errors.Wrap(err, "failed to find user by phone")
	}

	// Check before touching any state, so a failure here has no side effects.
	if _db == nil {
		return nil, errors.New("db not initialized")
	}

	now := time.Now()
	if ttl <= 0 {
		ttl = 5 * time.Minute
	}
	expiresAt := now.Add(ttl)

	record := &models.SmsCodeSend{
		Phone:     phone,
		Code:      code,
		SentAt:    now,
		ExpiresAt: expiresAt,
	}
	if err := _db.WithContext(ctx).Create(record).Error; err != nil {
		return nil, err
	}

	// Publish the code only after it has been recorded.
	m.mu.Lock()
	m.ensurePhoneAuthMaps()
	m.codeByPhone[phone] = phoneCodeEntry{code: code, expiresAt: expiresAt}
	m.lastSentAtByPhone[phone] = now
	m.mu.Unlock()

	log.Infof("SetPhoneCode to %s: code=%s", phone, code)
	return record, nil
}
// SendPhoneCode sends an SMS verification code to phone and remembers it for
// ValidatePhoneCode.
//
// Throttling is in memory only: a phone gets at most one send per 58 seconds.
// Each code expires 5 minutes after it is sent, and a new send replaces any
// earlier code for the same phone. The phone must belong to a registered user.
//
// m.mu is not held across the SMS provider call or the audit insert. The send
// slot is reserved under the lock first, so concurrent requests for the same
// phone are still throttled. The slot is released again if the provider call
// fails, so the user can retry immediately. Sends and validations for other
// phones are therefore not blocked behind a slow network call.
func (m *users) SendPhoneCode(ctx context.Context, phone string) error {
	phone = m.normalizePhone(phone)
	if phone == "" {
		return errors.New("手机号不能为空")
	}

	// Precondition: the phone number must already be registered.
	if _, err := m.FindByPhone(ctx, phone); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("手机号未注册,请联系管理员开通")
		}
		return errors.Wrap(err, "failed to find user by phone")
	}

	now := time.Now()

	// Reserve the send slot.
	m.mu.Lock()
	m.ensurePhoneAuthMaps()
	if m.isSendTooFrequent(now, phone) {
		m.mu.Unlock()
		return errors.New("验证码发送过于频繁,请稍后再试")
	}
	prevSentAt, hadPrev := m.lastSentAtByPhone[phone]
	m.lastSentAtByPhone[phone] = now
	m.mu.Unlock()

	code, err := m.smsNotifyClient.SendTo(phone)
	if err != nil {
		// Roll back the reservation, unless something else (e.g. SetPhoneCode)
		// has updated the slot in the meantime.
		m.mu.Lock()
		if last, ok := m.lastSentAtByPhone[phone]; ok && last.Equal(now) {
			if hadPrev {
				m.lastSentAtByPhone[phone] = prevSentAt
			} else {
				delete(m.lastSentAtByPhone, phone)
			}
		}
		m.mu.Unlock()
		log.WithError(err).Errorf("SendPhoneCode to %s", phone)
		return err
	}

	// The latest successfully sent code wins for this phone.
	expiresAt := now.Add(5 * time.Minute)
	m.mu.Lock()
	m.ensurePhoneAuthMaps()
	m.codeByPhone[phone] = phoneCodeEntry{
		code:      code,
		expiresAt: expiresAt,
	}
	m.mu.Unlock()

	log.Infof("SendPhoneCode to %s: code=%s", phone, code)

	if _db != nil {
		// Best-effort audit record for admin review and troubleshooting. A failure
		// must not fail the send, but it is logged rather than silently dropped.
		if err := _db.WithContext(ctx).Create(&models.SmsCodeSend{
			Phone:     phone,
			Code:      code,
			SentAt:    now,
			ExpiresAt: expiresAt,
		}).Error; err != nil {
			log.WithError(err).Warnf("SendPhoneCode: failed to record sms code send for %s", phone)
		}
	}
	return nil
}
// ValidatePhoneCode checks code against the in-memory code for phone (both
// trimmed). On success it consumes the code and returns the user, for issuing a
// token. Consuming the code prevents replay. An expired code is removed when
// encountered.
//
// NOTE(review): a wrong code neither consumes the entry nor counts attempts.
// Short codes (SetPhoneCode uses 4 digits) can therefore be guessed repeatedly
// until they expire; consider adding an attempt limit.
func (m *users) ValidatePhoneCode(ctx context.Context, phone, code string) (*models.User, error) {
	phone, code = m.normalizePhone(phone), strings.TrimSpace(code)
	switch {
	case phone == "":
		return nil, errors.New("手机号不能为空")
	case code == "":
		return nil, errors.New("验证码不能为空")
	}

	// Resolve the user first, so an unregistered phone learns nothing about code state.
	user, err := m.FindByPhone(ctx, phone)
	if err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.Wrap(err, "failed to find user by phone")
		}
		return nil, errors.New("手机号未注册,请联系管理员开通")
	}

	checkedAt := time.Now()
	m.mu.Lock()
	defer m.mu.Unlock()
	m.ensurePhoneAuthMaps()

	entry, found := m.codeByPhone[phone]
	switch {
	case !found:
		return nil, errors.New("验证码已过期或不存在")
	case checkedAt.After(entry.expiresAt):
		delete(m.codeByPhone, phone)
		return nil, errors.New("验证码已过期或不存在")
	case entry.code != code:
		return nil, errors.New("验证码错误")
	}

	delete(m.codeByPhone, phone)
	return user, nil
}
// SetPhone sets a user's phone number (admin operation). phone is trimmed, must
// be exactly 11 ASCII digits, and must not belong to another user. Only the phone
// column is updated.
//
// NOTE(review): the uniqueness check and the update are separate statements; a
// database unique constraint would be needed to rule out concurrent duplicates.
func (m *users) SetPhone(ctx context.Context, userID int64, phone string) error {
	phone = strings.TrimSpace(phone)
	if phone == "" {
		return errors.New("手机号不能为空")
	}

	// Any non-ASCII-digit byte (including bytes of multi-byte runes) fails the check.
	valid := len(phone) == 11
	for i := 0; valid && i < len(phone); i++ {
		valid = '0' <= phone[i] && phone[i] <= '9'
	}
	if !valid {
		return errors.New("手机号必须为 11 位数字")
	}

	// Phone numbers should be unique within the system so login and identity
	// checks are unambiguous.
	u, q := models.UserQuery.QueryContext(ctx)
	_, err := q.Where(u.Phone.Eq(phone), u.ID.Neq(userID)).First()
	switch {
	case err == nil:
		return errors.New("手机号已被其他用户占用")
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return errors.Wrap(err, "failed to check phone uniqueness")
	}

	if _, err := q.Where(u.ID.Eq(userID)).Update(u.Phone, phone); err != nil {
		return errors.Wrap(err, "failed to update user phone")
	}
	return nil
}
// CreateByPhone creates a new user from a phone number (admin operation).
//
// Inputs:
//   - phone is trimmed, must be 11 ASCII digits, and must not already be used by a user.
//   - username is trimmed; when empty it defaults to "用户" plus the last four
//     digits of the phone.
//
// The new user's open_id is "phone:<phone>"; creation fails if that open_id
// already exists. The metas and auth_token columns are omitted from the INSERT.
func (m *users) CreateByPhone(ctx context.Context, phone, username string) (*models.User, error) {
	phone = strings.TrimSpace(phone)
	if phone == "" {
		return nil, errors.New("手机号不能为空")
	}

	// Any non-ASCII-digit byte (including bytes of multi-byte runes) fails the check.
	valid := len(phone) == 11
	for i := 0; valid && i < len(phone); i++ {
		valid = '0' <= phone[i] && phone[i] <= '9'
	}
	if !valid {
		return nil, errors.New("手机号必须为 11 位数字")
	}

	switch _, err := m.FindByPhone(ctx, phone); {
	case err == nil:
		return nil, errors.New("手机号已被其他用户占用")
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.Wrap(err, "failed to check phone uniqueness")
	}

	openID := "phone:" + phone
	u, q := models.UserQuery.QueryContext(ctx)
	switch _, err := q.Where(u.OpenID.Eq(openID)).First(); {
	case err == nil:
		return nil, errors.New("用户已存在")
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.Wrap(err, "failed to check open_id uniqueness")
	}

	if username = strings.TrimSpace(username); username == "" {
		username = "用户" + phone[len(phone)-4:]
	}

	created := &models.User{
		OpenID:   openID,
		Username: username,
		Phone:    phone,
		Balance:  0,
		Avatar:   "",
	}
	if err := _db.WithContext(ctx).Omit("metas", "auth_token").Create(created).Error; err != nil {
		return nil, errors.Wrap(err, "failed to create user")
	}
	return created, nil
}