feat: Refactor user context handling and service methods

- Updated middleware to fetch user and tenant models by ID and set them in context.
- Refactored common service methods to accept userID as a parameter instead of extracting from context.
- Modified content service methods to include userID as a parameter for better clarity and performance.
- Adjusted coupon, creator, notification, order, tenant, user, and wallet services to utilize userID directly.
- Enhanced context key constants for improved readability and maintainability.
This commit is contained in:
2025-12-30 22:49:26 +08:00
parent 619f7a69a7
commit 54de243fa1
19 changed files with 278 additions and 252 deletions

View File

@@ -22,13 +22,7 @@ type common struct {
storage *storage.Storage
}
func (s *common) Upload(ctx context.Context, file *multipart.FileHeader, typeArg string) (*common_dto.UploadResult, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *common) Upload(ctx context.Context, userID int64, file *multipart.FileHeader, typeArg string) (*common_dto.UploadResult, error) {
// Mock Upload to S3/MinIO (Here we just generate key, actual upload handling via direct upload or stream is better)
// But this Upload endpoint accepts file. So we save it.
// We need to use storage provider to save it?
@@ -61,7 +55,7 @@ func (s *common) Upload(ctx context.Context, file *multipart.FileHeader, typeArg
url := s.GetAssetURL(objectKey)
// ... rest ...
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(uid)).First()
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First()
var tid int64 = 0
if err == nil {
tid = t.ID
@@ -69,7 +63,7 @@ func (s *common) Upload(ctx context.Context, file *multipart.FileHeader, typeArg
asset := &models.MediaAsset{
TenantID: tid,
UserID: uid,
UserID: userID,
Type: consts.MediaAssetType(typeArg),
Status: consts.MediaAssetStatusUploaded,
Provider: "local",

View File

@@ -81,7 +81,7 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte
}, nil
}
func (s *content) Get(ctx context.Context, id string) (*content_dto.ContentDetail, error) {
func (s *content) Get(ctx context.Context, userID int64, id string) (*content_dto.ContentDetail, error) {
cid := cast.ToInt64(id)
// Increment Views
@@ -110,8 +110,8 @@ func (s *content) Get(ctx context.Context, id string) (*content_dto.ContentDetai
isFavorited := false
hasAccess := false
if userID := ctx.Value(consts.CtxKeyUser); userID != nil {
uid := cast.ToInt64(userID)
if userID > 0 {
uid := userID
// Interaction
isLiked, _ = models.UserContentActionQuery.WithContext(ctx).
Where(models.UserContentActionQuery.UserID.Eq(uid),
@@ -167,7 +167,7 @@ func (s *content) Get(ctx context.Context, id string) (*content_dto.ContentDetai
return detail, nil
}
func (s *content) ListComments(ctx context.Context, id string, page int) (*requests.Pager, error) {
func (s *content) ListComments(ctx context.Context, userID int64, id string, page int) (*requests.Pager, error) {
cid := cast.ToInt64(id)
tbl, q := models.CommentQuery.QueryContext(ctx)
@@ -187,8 +187,8 @@ func (s *content) ListComments(ctx context.Context, id string, page int) (*reque
// User likes
likedMap := make(map[int64]bool)
if userID := ctx.Value(consts.CtxKeyUser); userID != nil {
uid := cast.ToInt64(userID)
if userID > 0 {
uid := userID
ids := make([]int64, len(list))
for i, v := range list {
ids[i] = v.ID
@@ -225,12 +225,11 @@ func (s *content) ListComments(ctx context.Context, id string, page int) (*reque
}, nil
}
func (s *content) CreateComment(ctx context.Context, id string, form *content_dto.CommentCreateForm) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *content) CreateComment(ctx context.Context, userID int64, id string, form *content_dto.CommentCreateForm) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
cid := cast.ToInt64(id)
c, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid)).First()
@@ -252,12 +251,11 @@ func (s *content) CreateComment(ctx context.Context, id string, form *content_dt
return nil
}
func (s *content) LikeComment(ctx context.Context, id string) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *content) LikeComment(ctx context.Context, userID int64, id string) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
cmid := cast.ToInt64(id)
// Fetch comment for author
@@ -292,12 +290,11 @@ func (s *content) LikeComment(ctx context.Context, id string) error {
return nil
}
func (s *content) GetLibrary(ctx context.Context) ([]user_dto.ContentItem, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tbl, q := models.ContentAccessQuery.QueryContext(ctx)
accessList, err := q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive)).Find()
@@ -334,28 +331,28 @@ func (s *content) GetLibrary(ctx context.Context) ([]user_dto.ContentItem, error
return data, nil
}
func (s *content) GetFavorites(ctx context.Context) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, "favorite")
func (s *content) GetFavorites(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, userID, "favorite")
}
func (s *content) AddFavorite(ctx context.Context, contentId string) error {
return s.addInteract(ctx, contentId, "favorite")
func (s *content) AddFavorite(ctx context.Context, userID int64, contentId string) error {
return s.addInteract(ctx, userID, contentId, "favorite")
}
func (s *content) RemoveFavorite(ctx context.Context, contentId string) error {
return s.removeInteract(ctx, contentId, "favorite")
func (s *content) RemoveFavorite(ctx context.Context, userID int64, contentId string) error {
return s.removeInteract(ctx, userID, contentId, "favorite")
}
func (s *content) GetLikes(ctx context.Context) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, "like")
func (s *content) GetLikes(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, userID, "like")
}
func (s *content) AddLike(ctx context.Context, contentId string) error {
return s.addInteract(ctx, contentId, "like")
func (s *content) AddLike(ctx context.Context, userID int64, contentId string) error {
return s.addInteract(ctx, userID, contentId, "like")
}
func (s *content) RemoveLike(ctx context.Context, contentId string) error {
return s.removeInteract(ctx, contentId, "like")
func (s *content) RemoveLike(ctx context.Context, userID int64, contentId string) error {
return s.removeInteract(ctx, userID, contentId, "like")
}
func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) {
@@ -480,12 +477,11 @@ func (s *content) toMediaURLs(assets []*models.ContentAsset) []content_dto.Media
return urls
}
func (s *content) addInteract(ctx context.Context, contentId, typ string) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *content) addInteract(ctx context.Context, userID int64, contentId, typ string) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
cid := cast.ToInt64(contentId)
// Fetch content for author
@@ -529,12 +525,11 @@ func (s *content) addInteract(ctx context.Context, contentId, typ string) error
return nil
}
func (s *content) removeInteract(ctx context.Context, contentId, typ string) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *content) removeInteract(ctx context.Context, userID int64, contentId, typ string) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
cid := cast.ToInt64(contentId)
return models.Q.Transaction(func(tx *models.Query) error {
@@ -556,12 +551,11 @@ func (s *content) removeInteract(ctx context.Context, contentId, typ string) err
})
}
func (s *content) getInteractList(ctx context.Context, typ string) ([]user_dto.ContentItem, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *content) getInteractList(ctx context.Context, userID int64, typ string) ([]user_dto.ContentItem, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tbl, q := models.UserContentActionQuery.QueryContext(ctx)
actions, err := q.Where(tbl.UserID.Eq(uid), tbl.Type.Eq(typ)).Find()

View File

@@ -15,12 +15,11 @@ import (
// @provider
type coupon struct{}
func (s *coupon) ListUserCoupons(ctx context.Context, status string) ([]coupon_dto.UserCouponItem, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *coupon) ListUserCoupons(ctx context.Context, userID int64, status string) ([]coupon_dto.UserCouponItem, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tbl, q := models.UserCouponQuery.QueryContext(ctx)
q = q.Where(tbl.UserID.Eq(uid))

View File

@@ -20,12 +20,11 @@ import (
// @provider
type creator struct{}
func (s *creator) Apply(ctx context.Context, form *creator_dto.ApplyForm) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.ApplyForm) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tbl, q := models.TenantQuery.QueryContext(ctx)
// Check if already has a tenant
@@ -62,8 +61,8 @@ func (s *creator) Apply(ctx context.Context, form *creator_dto.ApplyForm) error
return nil
}
func (s *creator) Dashboard(ctx context.Context) (*creator_dto.DashboardStats, error) {
tid, err := s.getTenantID(ctx)
func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.DashboardStats, error) {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return nil, err
}
@@ -95,8 +94,8 @@ func (s *creator) Dashboard(ctx context.Context) (*creator_dto.DashboardStats, e
return stats, nil
}
func (s *creator) ListContents(ctx context.Context, filter *creator_dto.CreatorContentListFilter) ([]creator_dto.ContentItem, error) {
tid, err := s.getTenantID(ctx)
func (s *creator) ListContents(ctx context.Context, userID int64, filter *creator_dto.CreatorContentListFilter) ([]creator_dto.ContentItem, error) {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return nil, err
}
@@ -133,12 +132,12 @@ func (s *creator) ListContents(ctx context.Context, filter *creator_dto.CreatorC
return data, nil
}
func (s *creator) CreateContent(ctx context.Context, form *creator_dto.ContentCreateForm) error {
tid, err := s.getTenantID(ctx)
func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator_dto.ContentCreateForm) error {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return err
}
uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser))
uid := userID
return models.Q.Transaction(func(tx *models.Query) error {
// 1. Create Content
@@ -187,13 +186,13 @@ func (s *creator) CreateContent(ctx context.Context, form *creator_dto.ContentCr
})
}
func (s *creator) UpdateContent(ctx context.Context, id string, form *creator_dto.ContentUpdateForm) error {
tid, err := s.getTenantID(ctx)
func (s *creator) UpdateContent(ctx context.Context, userID int64, id string, form *creator_dto.ContentUpdateForm) error {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return err
}
cid := cast.ToInt64(id)
uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser))
uid := userID
return models.Q.Transaction(func(tx *models.Query) error {
// 1. Check Ownership
@@ -257,9 +256,9 @@ func (s *creator) UpdateContent(ctx context.Context, id string, form *creator_dt
})
}
func (s *creator) DeleteContent(ctx context.Context, id string) error {
func (s *creator) DeleteContent(ctx context.Context, userID int64, id string) error {
cid := cast.ToInt64(id)
tid, err := s.getTenantID(ctx)
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return err
}
@@ -271,8 +270,8 @@ func (s *creator) DeleteContent(ctx context.Context, id string) error {
return nil
}
func (s *creator) ListOrders(ctx context.Context, filter *creator_dto.CreatorOrderListFilter) ([]creator_dto.Order, error) {
tid, err := s.getTenantID(ctx)
func (s *creator) ListOrders(ctx context.Context, userID int64, filter *creator_dto.CreatorOrderListFilter) ([]creator_dto.Order, error) {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return nil, err
}
@@ -302,13 +301,13 @@ func (s *creator) ListOrders(ctx context.Context, filter *creator_dto.CreatorOrd
return data, nil
}
func (s *creator) ProcessRefund(ctx context.Context, id string, form *creator_dto.RefundForm) error {
tid, err := s.getTenantID(ctx)
func (s *creator) ProcessRefund(ctx context.Context, userID int64, id string, form *creator_dto.RefundForm) error {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return err
}
oid := cast.ToInt64(id)
uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser)) // Creator ID
uid := userID // Creator ID
// Fetch Order
o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(oid), models.OrderQuery.TenantID.Eq(tid)).First()
@@ -402,8 +401,8 @@ func (s *creator) ProcessRefund(ctx context.Context, id string, form *creator_dt
return errorx.ErrBadRequest.WithMsg("无效的操作")
}
func (s *creator) GetSettings(ctx context.Context) (*creator_dto.Settings, error) {
tid, err := s.getTenantID(ctx)
func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.Settings, error) {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return nil, err
}
@@ -418,12 +417,12 @@ func (s *creator) GetSettings(ctx context.Context) (*creator_dto.Settings, error
}, nil
}
func (s *creator) UpdateSettings(ctx context.Context, form *creator_dto.Settings) error {
func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creator_dto.Settings) error {
return nil
}
func (s *creator) ListPayoutAccounts(ctx context.Context) ([]creator_dto.PayoutAccount, error) {
tid, err := s.getTenantID(ctx)
func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creator_dto.PayoutAccount, error) {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return nil, err
}
@@ -446,12 +445,12 @@ func (s *creator) ListPayoutAccounts(ctx context.Context) ([]creator_dto.PayoutA
return data, nil
}
func (s *creator) AddPayoutAccount(ctx context.Context, form *creator_dto.PayoutAccount) error {
tid, err := s.getTenantID(ctx)
func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *creator_dto.PayoutAccount) error {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return err
}
uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser))
uid := userID
pa := &models.PayoutAccount{
TenantID: tid,
@@ -467,8 +466,8 @@ func (s *creator) AddPayoutAccount(ctx context.Context, form *creator_dto.Payout
return nil
}
func (s *creator) RemovePayoutAccount(ctx context.Context, id string) error {
tid, err := s.getTenantID(ctx)
func (s *creator) RemovePayoutAccount(ctx context.Context, userID int64, id string) error {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return err
}
@@ -483,12 +482,12 @@ func (s *creator) RemovePayoutAccount(ctx context.Context, id string) error {
return nil
}
func (s *creator) Withdraw(ctx context.Context, form *creator_dto.WithdrawForm) error {
tid, err := s.getTenantID(ctx)
func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto.WithdrawForm) error {
tid, err := s.getTenantID(ctx, userID)
if err != nil {
return err
}
uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser))
uid := userID
amount := int64(form.Amount * 100)
if amount <= 0 {
@@ -552,12 +551,11 @@ func (s *creator) Withdraw(ctx context.Context, form *creator_dto.WithdrawForm)
// Helpers
func (s *creator) getTenantID(ctx context.Context) (int64, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error) {
if userID == 0 {
return 0, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
// Simple check: User owns tenant
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(uid)).First()

View File

@@ -20,15 +20,9 @@ type notification struct {
job *job.Job
}
func (s *notification) List(ctx context.Context, page int, typeArg string) (*requests.Pager, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *notification) List(ctx context.Context, userID int64, page int, typeArg string) (*requests.Pager, error) {
tbl, q := models.NotificationQuery.QueryContext(ctx)
q = q.Where(tbl.UserID.Eq(uid))
q = q.Where(tbl.UserID.Eq(userID))
if typeArg != "" && typeArg != "all" {
q = q.Where(tbl.Type.Eq(typeArg))
@@ -66,16 +60,11 @@ func (s *notification) List(ctx context.Context, page int, typeArg string) (*req
}, nil
}
func (s *notification) MarkRead(ctx context.Context, id string) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *notification) MarkRead(ctx context.Context, userID int64, id string) error {
nid := cast.ToInt64(id)
_, err := models.NotificationQuery.WithContext(ctx).
Where(models.NotificationQuery.ID.Eq(nid), models.NotificationQuery.UserID.Eq(uid)).
Where(models.NotificationQuery.ID.Eq(nid), models.NotificationQuery.UserID.Eq(userID)).
UpdateSimple(models.NotificationQuery.IsRead.Value(true))
if err != nil {
return errorx.ErrDatabaseError.WithCause(err)

View File

@@ -21,12 +21,11 @@ import (
// @provider
type order struct{}
func (s *order) ListUserOrders(ctx context.Context, status string) ([]user_dto.Order, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *order) ListUserOrders(ctx context.Context, userID int64, status string) ([]user_dto.Order, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tbl, q := models.OrderQuery.QueryContext(ctx)
q = q.Where(tbl.UserID.Eq(uid))
@@ -48,12 +47,11 @@ func (s *order) ListUserOrders(ctx context.Context, status string) ([]user_dto.O
return data, nil
}
func (s *order) GetUserOrder(ctx context.Context, id string) (*user_dto.Order, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *order) GetUserOrder(ctx context.Context, userID int64, id string) (*user_dto.Order, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
oid := cast.ToInt64(id)
tbl, q := models.OrderQuery.QueryContext(ctx)
@@ -72,12 +70,11 @@ func (s *order) GetUserOrder(ctx context.Context, id string) (*user_dto.Order, e
return &dto, nil
}
func (s *order) Create(ctx context.Context, form *transaction_dto.OrderCreateForm) (*transaction_dto.OrderCreateResponse, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *order) Create(ctx context.Context, userID int64, form *transaction_dto.OrderCreateForm) (*transaction_dto.OrderCreateResponse, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
cid := cast.ToInt64(form.ContentID)
// 1. Fetch Content & Price
@@ -169,12 +166,11 @@ func (s *order) Create(ctx context.Context, form *transaction_dto.OrderCreateFor
}, nil
}
func (s *order) Pay(ctx context.Context, id string, form *transaction_dto.OrderPayForm) (*transaction_dto.OrderPayResponse, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *order) Pay(ctx context.Context, userID int64, id string, form *transaction_dto.OrderPayForm) (*transaction_dto.OrderPayResponse, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
oid := cast.ToInt64(id)
// Fetch Order

View File

@@ -17,7 +17,7 @@ import (
// @provider
type tenant struct{}
func (s *tenant) GetPublicProfile(ctx context.Context, id string) (*dto.TenantProfile, error) {
func (s *tenant) GetPublicProfile(ctx context.Context, userID int64, id string) (*dto.TenantProfile, error) {
tid := cast.ToInt64(id)
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tid)).First()
if err != nil {
@@ -29,12 +29,14 @@ func (s *tenant) GetPublicProfile(ctx context.Context, id string) (*dto.TenantPr
// Stats
followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(tid)).Count()
contents, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.TenantID.Eq(tid), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).Count()
contents, _ := models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.TenantID.Eq(tid), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
Count()
// Following status
isFollowing := false
if userID := ctx.Value(consts.CtxKeyUser); userID != nil {
uid := cast.ToInt64(userID)
if userID > 0 {
uid := userID
isFollowing, _ = models.TenantUserQuery.WithContext(ctx).
Where(models.TenantUserQuery.TenantID.Eq(tid), models.TenantUserQuery.UserID.Eq(uid)).
Exists()
@@ -52,12 +54,11 @@ func (s *tenant) GetPublicProfile(ctx context.Context, id string) (*dto.TenantPr
}, nil
}
func (s *tenant) Follow(ctx context.Context, id string) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *tenant) Follow(ctx context.Context, userID int64, id string) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tid := cast.ToInt64(id)
// Check if tenant exists
@@ -83,12 +84,11 @@ func (s *tenant) Follow(ctx context.Context, id string) error {
return nil
}
func (s *tenant) Unfollow(ctx context.Context, id string) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *tenant) Unfollow(ctx context.Context, userID int64, id string) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tid := cast.ToInt64(id)
_, err := models.TenantUserQuery.WithContext(ctx).
@@ -100,12 +100,11 @@ func (s *tenant) Unfollow(ctx context.Context, id string) error {
return nil
}
func (s *tenant) ListFollowed(ctx context.Context) ([]dto.TenantProfile, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
func (s *tenant) ListFollowed(ctx context.Context, userID int64) ([]dto.TenantProfile, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
uid := userID
tbl, q := models.TenantUserQuery.QueryContext(ctx)
list, err := q.Where(tbl.UserID.Eq(uid)).Find()
@@ -122,8 +121,12 @@ func (s *tenant) ListFollowed(ctx context.Context) ([]dto.TenantProfile, error)
}
// Stats
followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(tu.TenantID)).Count()
contents, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.TenantID.Eq(tu.TenantID), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).Count()
followers, _ := models.TenantUserQuery.WithContext(ctx).
Where(models.TenantUserQuery.TenantID.Eq(tu.TenantID)).
Count()
contents, _ := models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.TenantID.Eq(tu.TenantID), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
Count()
data = append(data, dto.TenantProfile{
ID: cast.ToString(t.ID),
@@ -139,3 +142,16 @@ func (s *tenant) ListFollowed(ctx context.Context) ([]dto.TenantProfile, error)
return data, nil
}
// GetModelByID 获取指定 ID 的model
func (s *tenant) GetModelByID(ctx context.Context, id int64) (*models.Tenant, error) {
tbl, query := models.TenantQuery.QueryContext(ctx)
u, err := query.Where(tbl.ID.Eq(id)).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.ErrRecordNotFound
}
return nil, errorx.ErrDatabaseError.WithCause(err)
}
return u, nil
}

View File

@@ -77,40 +77,36 @@ func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.L
return &auth_dto.LoginResponse{
Token: token,
User: s.toAuthUserDTO(u),
User: s.ToAuthUserDTO(u),
}, nil
}
// Me 获取当前用户信息
func (s *user) Me(ctx context.Context) (*auth_dto.User, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
// GetModelByID 获取指定 ID 的用户model
func (s *user) GetModelByID(ctx context.Context, userID int64) (*models.User, error) {
tbl, query := models.UserQuery.QueryContext(ctx)
u, err := query.Where(tbl.ID.Eq(uid)).First()
u, err := query.Where(tbl.ID.Eq(userID)).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.ErrRecordNotFound
}
return nil, errorx.ErrDatabaseError.WithCause(err)
}
return u, nil
}
return s.toAuthUserDTO(u), nil
// Me 获取当前用户信息
func (s *user) Me(ctx context.Context, userID int64) (*auth_dto.User, error) {
u, err := s.GetModelByID(ctx, userID)
if err != nil {
return nil, err
}
return s.ToAuthUserDTO(u), nil
}
// Update 更新用户信息
func (s *user) Update(ctx context.Context, form *user_dto.UserUpdate) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *user) Update(ctx context.Context, userID int64, form *user_dto.UserUpdate) error {
tbl, query := models.UserQuery.QueryContext(ctx)
_, err := query.Where(tbl.ID.Eq(uid)).Updates(&models.User{
_, err := query.Where(tbl.ID.Eq(userID)).Updates(&models.User{
Nickname: form.Nickname,
Avatar: form.Avatar,
Gender: form.Gender,
@@ -125,13 +121,7 @@ func (s *user) Update(ctx context.Context, form *user_dto.UserUpdate) error {
}
// RealName 实名认证
func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *user) RealName(ctx context.Context, userID int64, form *user_dto.RealNameForm) error {
// Mock Verification
if len(form.IDCard) != 18 {
return errorx.ErrBadRequest.WithMsg("身份证号格式错误")
@@ -141,7 +131,7 @@ func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error
}
tbl, query := models.UserQuery.QueryContext(ctx)
u, err := query.Where(tbl.ID.Eq(uid)).First()
u, err := query.Where(tbl.ID.Eq(userID)).First()
if err != nil {
return errorx.ErrRecordNotFound
}
@@ -159,7 +149,7 @@ func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error
b, _ := json.Marshal(metaMap)
_, err = query.Where(tbl.ID.Eq(uid)).Updates(&models.User{
_, err = query.Where(tbl.ID.Eq(userID)).Updates(&models.User{
IsRealNameVerified: true,
VerifiedAt: time.Now(),
Metas: types.JSON(b),
@@ -171,15 +161,9 @@ func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error
}
// GetNotifications 获取通知
func (s *user) GetNotifications(ctx context.Context, typeArg string) ([]user_dto.Notification, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *user) GetNotifications(ctx context.Context, userID int64, typeArg string) ([]user_dto.Notification, error) {
tbl, query := models.NotificationQuery.QueryContext(ctx)
query = query.Where(tbl.UserID.Eq(uid))
query = query.Where(tbl.UserID.Eq(userID))
if typeArg != "" && typeArg != "all" {
query = query.Where(tbl.Type.Eq(typeArg))
}
@@ -203,7 +187,7 @@ func (s *user) GetNotifications(ctx context.Context, typeArg string) ([]user_dto
return result, nil
}
func (s *user) toAuthUserDTO(u *models.User) *auth_dto.User {
func (s *user) ToAuthUserDTO(u *models.User) *auth_dto.User {
return &auth_dto.User{
ID: cast.ToString(u.ID),
Phone: u.Phone,

View File

@@ -20,15 +20,9 @@ import (
// @provider
type wallet struct{}
func (s *wallet) GetWallet(ctx context.Context) (*user_dto.WalletResponse, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletResponse, error) {
// Get Balance
u, err := models.UserQuery.WithContext(ctx).Where(models.UserQuery.ID.Eq(uid)).First()
u, err := models.UserQuery.WithContext(ctx).Where(models.UserQuery.ID.Eq(userID)).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.ErrRecordNotFound
@@ -39,7 +33,7 @@ func (s *wallet) GetWallet(ctx context.Context) (*user_dto.WalletResponse, error
// Get Transactions (Orders)
// Both purchase (expense) and recharge (income - if paid)
tbl, q := models.OrderQuery.QueryContext(ctx)
orders, err := q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.OrderStatusPaid)).
orders, err := q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid)).
Order(tbl.CreatedAt.Desc()).
Limit(20). // Limit to recent 20
Find()
@@ -74,13 +68,7 @@ func (s *wallet) GetWallet(ctx context.Context) (*user_dto.WalletResponse, error
}, nil
}
func (s *wallet) Recharge(ctx context.Context, form *user_dto.RechargeForm) (*user_dto.RechargeResponse, error) {
userID := ctx.Value(consts.CtxKeyUser)
if userID == nil {
return nil, errorx.ErrUnauthorized
}
uid := cast.ToInt64(userID)
func (s *wallet) Recharge(ctx context.Context, userID int64, form *user_dto.RechargeForm) (*user_dto.RechargeResponse, error) {
amount := int64(form.Amount * 100)
if amount <= 0 {
return nil, errorx.ErrBadRequest.WithMsg("金额无效")
@@ -89,7 +77,7 @@ func (s *wallet) Recharge(ctx context.Context, form *user_dto.RechargeForm) (*us
// Create Recharge Order
order := &models.Order{
TenantID: 0, // Platform / System
UserID: uid,
UserID: userID,
Type: consts.OrderTypeRecharge,
Status: consts.OrderStatusCreated,
Currency: consts.CurrencyCNY,