364 lines
9.4 KiB
Go
364 lines
9.4 KiB
Go
package services
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"fmt"
|
||
"math/big"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"quyun/v2/app/requests"
|
||
"quyun/v2/database/models"
|
||
|
||
"github.com/pkg/errors"
|
||
"github.com/samber/lo"
|
||
log "github.com/sirupsen/logrus"
|
||
"go.ipao.vip/gen"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// @provider
|
||
type users struct {
|
||
mu sync.Mutex `inject:"false"`
|
||
|
||
lastSentAtByPhone map[string]time.Time `inject:"false"`
|
||
codeByPhone map[string]phoneCodeEntry `inject:"false"`
|
||
}
|
||
|
||
// prepare
|
||
func (m *users) Prepare() error {
|
||
m.lastSentAtByPhone = make(map[string]time.Time)
|
||
m.codeByPhone = make(map[string]phoneCodeEntry)
|
||
|
||
return nil
|
||
}
|
||
|
||
// List returns a paginated list of users
|
||
func (m *users) List(
|
||
ctx context.Context,
|
||
pagination *requests.Pagination,
|
||
conds ...gen.Condition,
|
||
) (*requests.Pager, error) {
|
||
pagination.Format()
|
||
_, query := models.UserQuery.QueryContext(ctx)
|
||
|
||
items, cnt, err := query.Where(conds...).FindByPage(int(pagination.Offset()), int(pagination.Limit))
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, "query users error")
|
||
}
|
||
|
||
return &requests.Pager{
|
||
Items: items,
|
||
Total: cnt,
|
||
Pagination: *pagination,
|
||
}, nil
|
||
}
|
||
|
||
// PostList returns a paginated list of posts for a user
|
||
func (m *users) PostList(
|
||
ctx context.Context,
|
||
userId int64,
|
||
pagination *requests.Pagination,
|
||
conds ...gen.Condition,
|
||
) (*requests.Pager, error) {
|
||
pagination.Format()
|
||
// stmt := SELECT(tbl.AllColumns).
|
||
// FROM(tbl.
|
||
// RIGHT_JOIN(
|
||
// tblUserPosts,
|
||
// tblUserPosts.PostID.EQ(tbl.ID),
|
||
// ),
|
||
// ).
|
||
// WHERE(CondTrue(cond...)).
|
||
// ORDER_BY(tblUserPosts.ID.DESC()).
|
||
// LIMIT(pagination.Limit).
|
||
// OFFSET(pagination.Offset)
|
||
// m.log().Infof("sql: %s", stmt.DebugSql())
|
||
|
||
tbl, query := models.UserPostQuery.QueryContext(ctx)
|
||
pagePosts, cnt, err := query.Select(tbl.PostID).
|
||
Where(tbl.UserID.Eq(userId)).
|
||
FindByPage(int(pagination.Offset()), int(pagination.Limit))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
postIds := lo.Map(pagePosts, func(item *models.UserPost, _ int) int64 { return item.PostID })
|
||
|
||
postTbl, postQuery := models.PostQuery.QueryContext(ctx)
|
||
items, err := postQuery.Where(postTbl.ID.In(postIds...)).Find()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &requests.Pager{
|
||
Items: items,
|
||
Total: cnt,
|
||
Pagination: *pagination,
|
||
}, nil
|
||
}
|
||
|
||
// GetUsersMapByIDs
|
||
func (m *users) GetUsersMapByIDs(ctx context.Context, ids []int64) (map[int64]*models.User, error) {
|
||
if len(ids) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
tbl, query := models.UserQuery.QueryContext(ctx)
|
||
items, err := query.Where(tbl.ID.In(ids...)).Find()
|
||
if err != nil {
|
||
return nil, errors.Wrapf(err, "failed to get users by ids:%v", ids)
|
||
}
|
||
|
||
return lo.KeyBy(items, func(item *models.User) int64 {
|
||
return item.ID
|
||
}), nil
|
||
}
|
||
|
||
// BatchCheckHasBought checks if the user has bought the given post IDs
|
||
func (m *users) BatchCheckHasBought(ctx context.Context, userId int64, postIDs []int64) (map[int64]bool, error) {
|
||
tbl, query := models.UserPostQuery.QueryContext(ctx)
|
||
userPosts, err := query.
|
||
Where(
|
||
tbl.UserID.Eq(userId),
|
||
tbl.PostID.In(postIDs...),
|
||
).
|
||
Find()
|
||
if err != nil {
|
||
return nil, errors.Wrapf(err, "check user has bought failed, user_id: %d, post_ids: %+v", userId, postIDs)
|
||
}
|
||
|
||
result := make(map[int64]bool)
|
||
for _, postID := range postIDs {
|
||
result[postID] = false
|
||
}
|
||
|
||
for _, post := range userPosts {
|
||
result[post.PostID] = true
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// HasBought
|
||
func (m *users) HasBought(ctx context.Context, userID, postID int64) (bool, error) {
|
||
tbl, query := models.UserPostQuery.QueryContext(ctx)
|
||
cnt, err := query.
|
||
Where(
|
||
tbl.UserID.Eq(userID),
|
||
tbl.PostID.Eq(postID),
|
||
).
|
||
Count()
|
||
if err != nil {
|
||
return false, errors.Wrap(err, "failed to check user bought")
|
||
}
|
||
return cnt > 0, nil
|
||
}
|
||
|
||
// SetUsername
|
||
func (m *users) SetUsername(ctx context.Context, userID int64, username string) error {
|
||
tbl, query := models.UserQuery.QueryContext(ctx)
|
||
_, err := query.
|
||
Where(
|
||
tbl.ID.Eq(userID),
|
||
).
|
||
Update(tbl.Username, username)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// BuyPosts
|
||
func (m *users) BuyPosts(ctx context.Context, userID, postID, price int64) error {
|
||
model := &models.UserPost{UserID: userID, PostID: postID, Price: price}
|
||
return model.Create(ctx)
|
||
}
|
||
|
||
// RevokePosts
|
||
func (m *users) RevokeUserPosts(ctx context.Context, userID, postID int64) error {
|
||
tbl, query := models.UserPostQuery.QueryContext(ctx)
|
||
_, err := query.Where(
|
||
tbl.UserID.Eq(userID),
|
||
tbl.PostID.Eq(postID),
|
||
).Delete()
|
||
return err
|
||
}
|
||
|
||
// FindByID
|
||
func (m *users) FindByID(ctx context.Context, userID int64) (*models.User, error) {
|
||
tbl, query := models.UserQuery.QueryContext(ctx)
|
||
user, err := query.Where(tbl.ID.Eq(userID)).First()
|
||
if err != nil {
|
||
return nil, errors.Wrapf(err, "find by id failed, id: %d", userID)
|
||
}
|
||
return user, nil
|
||
}
|
||
|
||
// SetBalance
|
||
func (m *users) SetBalance(ctx context.Context, userID, balance int64) error {
|
||
tbl, query := models.UserQuery.QueryContext(ctx)
|
||
_, err := query.Where(tbl.ID.Eq(userID)).Update(tbl.Balance, balance)
|
||
|
||
return err
|
||
}
|
||
|
||
// AddBalance adds the given amount to the user's balance
|
||
func (m *users) AddBalance(ctx context.Context, userID, amount int64) error {
|
||
tbl, query := models.UserQuery.QueryContext(ctx)
|
||
_, err := query.Where(tbl.ID.Eq(userID)).Inc(tbl.Balance, amount)
|
||
return err
|
||
}
|
||
|
||
// Desc desc the given amount to the user's balance
|
||
func (m *users) DescBalance(ctx context.Context, userID, amount int64) error {
|
||
user, err := m.FindByID(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if user.Balance < amount {
|
||
return errors.New("balance not enough")
|
||
}
|
||
tbl, query := models.UserQuery.QueryContext(ctx)
|
||
_, err = query.Where(tbl.ID.Eq(userID)).Inc(tbl.Balance, -amount)
|
||
return err
|
||
}
|
||
|
||
// Count
|
||
func (m *users) Count(ctx context.Context, conds ...gen.Condition) (int64, error) {
|
||
_, query := models.UserQuery.QueryContext(ctx)
|
||
if len(conds) > 0 {
|
||
query = query.Where(conds...)
|
||
}
|
||
return query.Count()
|
||
}
|
||
|
||
// FindByPhone
|
||
func (m *users) FindByPhone(ctx context.Context, phone string) (*models.User, error) {
|
||
tbl, query := models.UserQuery.QueryContext(ctx)
|
||
return query.Where(tbl.Phone.Eq(phone)).First()
|
||
}
|
||
|
||
// phoneCodeEntry is a verification code stored per phone number together with
// its expiry time.
type phoneCodeEntry struct {
	code      string    // the code that was issued
	expiresAt time.Time // instant after which the code is treated as expired
}
|
||
|
||
func (m *users) ensurePhoneAuthMaps() {
|
||
if m.lastSentAtByPhone == nil {
|
||
m.lastSentAtByPhone = make(map[string]time.Time)
|
||
}
|
||
if m.codeByPhone == nil {
|
||
m.codeByPhone = make(map[string]phoneCodeEntry)
|
||
}
|
||
}
|
||
|
||
func (m *users) normalizePhone(phone string) string {
|
||
return strings.TrimSpace(phone)
|
||
}
|
||
|
||
func (m *users) isSendTooFrequent(now time.Time, phone string) bool {
|
||
last, ok := m.lastSentAtByPhone[phone]
|
||
if !ok {
|
||
return false
|
||
}
|
||
// 前端倒计时 60s;后端用 58s 做保护,避免客户端/服务端时间误差导致“刚到 60s 仍被拒绝”。
|
||
return now.Sub(last) < 58*time.Second
|
||
}
|
||
|
||
func (m *users) gen4Digits() (string, error) {
|
||
// 0000-9999
|
||
n, err := rand.Int(rand.Reader, big.NewInt(10000))
|
||
if err != nil {
|
||
return "", errors.Wrap(err, "failed to generate sms code")
|
||
}
|
||
return fmt.Sprintf("%04d", n.Int64()), nil
|
||
}
|
||
|
||
// SendPhoneCode 发送短信验证码(内存限流:同一手机号 58s 内仅允许发送一次;验证码 5 分钟过期)。
|
||
func (m *users) SendPhoneCode(ctx context.Context, phone string) error {
|
||
phone = m.normalizePhone(phone)
|
||
if phone == "" {
|
||
return errors.New("手机号不能为空")
|
||
}
|
||
|
||
// 前置校验:手机号必须已注册
|
||
_, err := m.FindByPhone(ctx, phone)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("手机号未注册,请联系管理员开通")
|
||
}
|
||
return errors.Wrap(err, "failed to find user by phone")
|
||
}
|
||
|
||
now := time.Now()
|
||
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
m.ensurePhoneAuthMaps()
|
||
|
||
if m.isSendTooFrequent(now, phone) {
|
||
return errors.New("验证码发送过于频繁,请稍后再试")
|
||
}
|
||
|
||
code, err := m.gen4Digits()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 生成/覆盖验证码:同一手机号再次发送时以最新验证码为准
|
||
m.codeByPhone[phone] = phoneCodeEntry{
|
||
code: code,
|
||
expiresAt: now.Add(5 * time.Minute),
|
||
}
|
||
m.lastSentAtByPhone[phone] = now
|
||
// log phone and code
|
||
log.Infof("SendPhoneCode to %s: code=%s", phone, code)
|
||
|
||
// TODO: 这里应调用实际短信服务商发送 code;当前仅做内存发码与校验支撑。
|
||
return nil
|
||
}
|
||
|
||
// ValidatePhoneCode 校验短信验证码,成功后删除验证码并返回用户信息(用于生成 token)。
|
||
func (m *users) ValidatePhoneCode(ctx context.Context, phone, code string) (*models.User, error) {
|
||
phone = m.normalizePhone(phone)
|
||
code = strings.TrimSpace(code)
|
||
if phone == "" {
|
||
return nil, errors.New("手机号不能为空")
|
||
}
|
||
if code == "" {
|
||
return nil, errors.New("验证码不能为空")
|
||
}
|
||
|
||
// 先确认手机号存在,避免对不存在手机号暴露验证码状态
|
||
user, err := m.FindByPhone(ctx, phone)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("手机号未注册,请联系管理员开通")
|
||
}
|
||
return nil, errors.Wrap(err, "failed to find user by phone")
|
||
}
|
||
|
||
now := time.Now()
|
||
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
m.ensurePhoneAuthMaps()
|
||
|
||
entry, ok := m.codeByPhone[phone]
|
||
if !ok {
|
||
return nil, errors.New("验证码已过期或不存在")
|
||
}
|
||
if now.After(entry.expiresAt) {
|
||
delete(m.codeByPhone, phone)
|
||
return nil, errors.New("验证码已过期或不存在")
|
||
}
|
||
if entry.code != code {
|
||
return nil, errors.New("验证码错误")
|
||
}
|
||
|
||
// 验证通过后删除验证码,防止重复验证(防重放)。
|
||
delete(m.codeByPhone, phone)
|
||
|
||
return user, nil
|
||
}
|