diff --git a/backend/app/commands/seed/seed.go b/backend/app/commands/seed/seed.go index 83f419a..190e92b 100644 --- a/backend/app/commands/seed/seed.go +++ b/backend/app/commands/seed/seed.go @@ -120,13 +120,13 @@ func Serve(cmd *cobra.Command, args []string) error { UserID: creator.ID, Title: title, Description: fmt.Sprintf("这是关于 %s 的详细介绍...", title), - Genre: "京剧", - Status: consts.ContentStatusPublished, - Visibility: consts.ContentVisibilityPublic, - Views: int32(rand.Intn(10000)), - Likes: int32(rand.Intn(1000)), - } - models.ContentQuery.WithContext(ctx).Create(c) + Genre: "京剧", + Status: consts.ContentStatusPublished, + Visibility: consts.ContentVisibilityPublic, + Views: int32(rand.Intn(10000)), + Likes: int32(rand.Intn(1000)), + } + models.ContentQuery.WithContext(ctx).Create(c) // Price models.ContentPriceQuery.WithContext(ctx).Create(&models.ContentPrice{ TenantID: tenant.ID, diff --git a/backend/app/http/v1/common.go b/backend/app/http/v1/common.go index 7fe5855..a0e74ba 100644 --- a/backend/app/http/v1/common.go +++ b/backend/app/http/v1/common.go @@ -5,8 +5,10 @@ import ( "quyun/v2/app/http/v1/dto" "quyun/v2/app/services" + "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" + "github.com/spf13/cast" ) // @provider @@ -30,5 +32,6 @@ func (c *Common) Upload(ctx fiber.Ctx, file *multipart.FileHeader, typeArg *stri if typeArg != nil { val = *typeArg } - return services.Common.Upload(ctx, file, val) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Common.Upload(ctx.Context(), uid, file, val) } diff --git a/backend/app/http/v1/content.go b/backend/app/http/v1/content.go index d1510a4..acb844d 100644 --- a/backend/app/http/v1/content.go +++ b/backend/app/http/v1/content.go @@ -4,8 +4,10 @@ import ( "quyun/v2/app/http/v1/dto" "quyun/v2/app/requests" "quyun/v2/app/services" + "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" + "github.com/spf13/cast" ) // @provider @@ -45,7 +47,8 @@ func (c *Content) List( // @Success 200 {object} dto.ContentDetail // @Bind id path func (c *Content) Get(ctx fiber.Ctx, id string) (*dto.ContentDetail, error) { - return services.Content.Get(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.Get(ctx.Context(), uid, id) } // Get comments for a content @@ -62,7 +65,8 @@ func (c *Content) Get(ctx fiber.Ctx, id string) (*dto.ContentDetail, error) { // @Bind id path // @Bind page query func (c *Content) ListComments(ctx fiber.Ctx, id string, page int) (*requests.Pager, error) { - return services.Content.ListComments(ctx, id, page) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.ListComments(ctx.Context(), uid, id, page) } // Post a comment @@ -79,7 +83,8 @@ func (c *Content) ListComments(ctx fiber.Ctx, id string, page int) (*requests.Pa // @Bind id path // @Bind form body func (c *Content) CreateComment(ctx fiber.Ctx, id string, form *dto.CommentCreateForm) error { - return services.Content.CreateComment(ctx, id, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.CreateComment(ctx.Context(), uid, id, form) } // Like a comment @@ -94,7 +99,8 @@ func (c *Content) CreateComment(ctx fiber.Ctx, id string, form *dto.CommentCreat // @Success 200 {string} string "Liked" // @Bind id path func (c *Content) LikeComment(ctx fiber.Ctx, id string) error { - return services.Content.LikeComment(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.LikeComment(ctx.Context(), uid, id) } // List curated topics diff --git a/backend/app/http/v1/creator.go b/backend/app/http/v1/creator.go index 1ce5369..c0d9b1b 100644 --- a/backend/app/http/v1/creator.go +++ b/backend/app/http/v1/creator.go @@ -3,8 +3,10 @@ package v1 import ( "quyun/v2/app/http/v1/dto" "quyun/v2/app/services" + "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" + "github.com/spf13/cast" ) // @provider @@ -22,7 +24,8 @@ type Creator struct{} // @Success 200 {string} string "Application submitted" // @Bind form body func (c *Creator) Apply(ctx fiber.Ctx, form *dto.ApplyForm) error { - return services.Creator.Apply(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.Apply(ctx.Context(), uid, form) } // Get creator dashboard stats @@ -35,7 +38,8 @@ func (c *Creator) Apply(ctx fiber.Ctx, form *dto.ApplyForm) error { // @Produce json // @Success 200 {object} dto.DashboardStats func (c *Creator) Dashboard(ctx fiber.Ctx) (*dto.DashboardStats, error) { - return services.Creator.Dashboard(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.Dashboard(ctx.Context(), uid) } // List creator contents @@ -52,7 +56,8 @@ func (c *Creator) Dashboard(ctx fiber.Ctx) (*dto.DashboardStats, error) { // @Success 200 {array} dto.ContentItem // @Bind filter query func (c *Creator) ListContents(ctx fiber.Ctx, filter *dto.CreatorContentListFilter) ([]dto.ContentItem, error) { - return services.Creator.ListContents(ctx, filter) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.ListContents(ctx.Context(), uid, filter) } // Create/Publish content @@ -67,7 +72,8 @@ func (c *Creator) ListContents(ctx fiber.Ctx, filter *dto.CreatorContentListFilt // @Success 200 {string} string "Created" // @Bind form body func (c *Creator) CreateContent(ctx fiber.Ctx, form *dto.ContentCreateForm) error { - return services.Creator.CreateContent(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.CreateContent(ctx.Context(), uid, form) } // Update content @@ -84,7 +90,8 @@ func (c *Creator) CreateContent(ctx fiber.Ctx, form *dto.ContentCreateForm) erro // @Bind id path // @Bind form body func (c *Creator) UpdateContent(ctx fiber.Ctx, id string, form *dto.ContentUpdateForm) error { - return services.Creator.UpdateContent(ctx, id, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.UpdateContent(ctx.Context(), uid, id, form) } // Delete content @@ -99,7 +106,8 @@ func (c *Creator) UpdateContent(ctx fiber.Ctx, id string, form *dto.ContentUpdat // @Success 200 {string} string "Deleted" // @Bind id path func (c *Creator) DeleteContent(ctx fiber.Ctx, id string) error { - return services.Creator.DeleteContent(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.DeleteContent(ctx.Context(), uid, id) } // List sales orders @@ -115,7 +123,8 @@ func (c *Creator) DeleteContent(ctx fiber.Ctx, id string) error { // @Success 200 {array} dto.Order // @Bind filter query func (c *Creator) ListOrders(ctx fiber.Ctx, filter *dto.CreatorOrderListFilter) ([]dto.Order, error) { - return services.Creator.ListOrders(ctx, filter) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.ListOrders(ctx.Context(), uid, filter) } // Process refund @@ -132,7 +141,8 @@ func (c *Creator) ListOrders(ctx fiber.Ctx, filter *dto.CreatorOrderListFilter) // @Bind id path // @Bind form body func (c *Creator) Refund(ctx fiber.Ctx, id string, form *dto.RefundForm) error { - return services.Creator.ProcessRefund(ctx, id, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.ProcessRefund(ctx.Context(), uid, id, form) } // Get channel settings @@ -145,7 +155,8 @@ func (c *Creator) Refund(ctx fiber.Ctx, id string, form *dto.RefundForm) error { // @Produce json // @Success 200 {object} dto.Settings func (c *Creator) GetSettings(ctx fiber.Ctx) (*dto.Settings, error) { - return services.Creator.GetSettings(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.GetSettings(ctx.Context(), uid) } // Update channel settings @@ -160,7 +171,8 @@ func (c *Creator) GetSettings(ctx fiber.Ctx) (*dto.Settings, error) { // @Success 200 {string} string "Updated" // @Bind form body func (c *Creator) UpdateSettings(ctx fiber.Ctx, form *dto.Settings) error { - return services.Creator.UpdateSettings(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.UpdateSettings(ctx.Context(), uid, form) } // List payout accounts @@ -173,7 +185,8 @@ func (c *Creator) UpdateSettings(ctx fiber.Ctx, form *dto.Settings) error { // @Produce json // @Success 200 {array} dto.PayoutAccount func (c *Creator) ListPayoutAccounts(ctx fiber.Ctx) ([]dto.PayoutAccount, error) { - return services.Creator.ListPayoutAccounts(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.ListPayoutAccounts(ctx.Context(), uid) } // Add payout account @@ -188,7 +201,8 @@ func (c *Creator) ListPayoutAccounts(ctx fiber.Ctx) ([]dto.PayoutAccount, error) // @Success 200 {string} string "Added" // @Bind form body func (c *Creator) AddPayoutAccount(ctx fiber.Ctx, form *dto.PayoutAccount) error { - return services.Creator.AddPayoutAccount(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.AddPayoutAccount(ctx.Context(), uid, form) } // Remove payout account @@ -203,7 +217,8 @@ func (c *Creator) AddPayoutAccount(ctx fiber.Ctx, form *dto.PayoutAccount) error // @Success 200 {string} string "Removed" // @Bind id query func (c *Creator) RemovePayoutAccount(ctx fiber.Ctx, id string) error { - return services.Creator.RemovePayoutAccount(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.RemovePayoutAccount(ctx.Context(), uid, id) } // Request withdrawal @@ -218,5 +233,6 @@ func (c *Creator) RemovePayoutAccount(ctx fiber.Ctx, id string) error { // @Success 200 {string} string "Withdrawal requested" // @Bind form body func (c *Creator) Withdraw(ctx fiber.Ctx, form *dto.WithdrawForm) error { - return services.Creator.Withdraw(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Creator.Withdraw(ctx.Context(), uid, form) } diff --git a/backend/app/http/v1/routes.gen.go b/backend/app/http/v1/routes.gen.go index 66a1fa9..1a91d8c 100644 --- a/backend/app/http/v1/routes.gen.go +++ b/backend/app/http/v1/routes.gen.go @@ -8,6 +8,7 @@ import ( "mime/multipart" "quyun/v2/app/http/v1/dto" "quyun/v2/app/middlewares" + "quyun/v2/database/models" "github.com/gofiber/fiber/v3" log "github.com/sirupsen/logrus" @@ -222,8 +223,9 @@ func (r *Routes) Register(router fiber.Router) { PathParam[string]("contentId"), )) r.log.Debugf("Registering route: Get /v1/me -> user.Me") - router.Get("/v1/me"[len(r.Path()):], DataFunc0( + router.Get("/v1/me"[len(r.Path()):], DataFunc1( r.user.Me, + Local[*models.User]("__ctx_user"), )) r.log.Debugf("Registering route: Get /v1/me/coupons -> user.MyCoupons") router.Get("/v1/me/coupons"[len(r.Path()):], DataFunc1( diff --git a/backend/app/http/v1/tenant.go b/backend/app/http/v1/tenant.go index 1b25f88..701a23a 100644 --- a/backend/app/http/v1/tenant.go +++ b/backend/app/http/v1/tenant.go @@ -3,8 +3,10 @@ package v1 import ( "quyun/v2/app/http/v1/dto" "quyun/v2/app/services" + "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" + "github.com/spf13/cast" ) // @provider @@ -22,7 +24,8 @@ type Tenant struct{} // @Success 200 {object} dto.TenantProfile // @Bind id path func (t *Tenant) Get(ctx fiber.Ctx, id string) (*dto.TenantProfile, error) { - return services.Tenant.GetPublicProfile(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Tenant.GetPublicProfile(ctx.Context(), uid, id) } // Follow a tenant @@ -37,7 +40,8 @@ func (t *Tenant) Get(ctx fiber.Ctx, id string) (*dto.TenantProfile, error) { // @Success 200 {string} string "Followed" // @Bind id path func (t *Tenant) Follow(ctx fiber.Ctx, id string) error { - return services.Tenant.Follow(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Tenant.Follow(ctx.Context(), uid, id) } // Unfollow a tenant @@ -52,5 +56,6 @@ func (t *Tenant) Follow(ctx fiber.Ctx, id string) error { // @Success 200 {string} string "Unfollowed" // @Bind id path func (t *Tenant) Unfollow(ctx fiber.Ctx, id string) error { - return services.Tenant.Unfollow(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Tenant.Unfollow(ctx.Context(), uid, id) } diff --git a/backend/app/http/v1/transaction.go b/backend/app/http/v1/transaction.go index accd9be..bcc1eab 100644 --- a/backend/app/http/v1/transaction.go +++ b/backend/app/http/v1/transaction.go @@ -3,8 +3,10 @@ package v1 import ( "quyun/v2/app/http/v1/dto" "quyun/v2/app/services" + "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" + "github.com/spf13/cast" ) // @provider @@ -22,7 +24,8 @@ type Transaction struct{} // @Success 200 {object} dto.OrderCreateResponse // @Bind form body func (t *Transaction) Create(ctx fiber.Ctx, form *dto.OrderCreateForm) (*dto.OrderCreateResponse, error) { - return services.Order.Create(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Order.Create(ctx.Context(), uid, form) } // Pay for order @@ -39,7 +42,8 @@ func (t *Transaction) Create(ctx fiber.Ctx, form *dto.OrderCreateForm) (*dto.Ord // @Bind id path // @Bind form body func (t *Transaction) Pay(ctx fiber.Ctx, id string, form *dto.OrderPayForm) (*dto.OrderPayResponse, error) { - return services.Order.Pay(ctx, id, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Order.Pay(ctx.Context(), uid, id, form) } // Check order payment status diff --git a/backend/app/http/v1/user.go b/backend/app/http/v1/user.go index d84867d..473b2e5 100644 --- a/backend/app/http/v1/user.go +++ b/backend/app/http/v1/user.go @@ -5,8 +5,11 @@ import ( auth_dto "quyun/v2/app/http/v1/dto" "quyun/v2/app/requests" "quyun/v2/app/services" + "quyun/v2/database/models" + "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" + "github.com/spf13/cast" ) // @provider @@ -21,8 +24,10 @@ type User struct{} // @Accept json // @Produce json // @Success 200 {object} auth_dto.User -func (u *User) Me(ctx fiber.Ctx) (*auth_dto.User, error) { - return services.User.Me(ctx) +// @Bind user local key(__ctx_user) +func (u *User) Me(ctx fiber.Ctx, user *models.User) (*auth_dto.User, error) { + // uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.User.ToAuthUserDTO(user), nil } // Update user profile @@ -37,7 +42,8 @@ func (u *User) Me(ctx fiber.Ctx) (*auth_dto.User, error) { // @Success 200 {string} string "Updated" // @Bind form body func (u *User) Update(ctx fiber.Ctx, form *dto.UserUpdate) error { - return services.User.Update(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.User.Update(ctx.Context(), uid, form) } // Submit real-name authentication @@ -52,7 +58,8 @@ func (u *User) Update(ctx fiber.Ctx, form *dto.UserUpdate) error { // @Success 200 {string} string "Submitted" // @Bind form body func (u *User) RealName(ctx fiber.Ctx, form *dto.RealNameForm) error { - return services.User.RealName(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.User.RealName(ctx.Context(), uid, form) } // Get wallet balance and transactions @@ -65,7 +72,8 @@ func (u *User) RealName(ctx fiber.Ctx, form *dto.RealNameForm) error { // @Produce json // @Success 200 {object} dto.WalletResponse func (u *User) Wallet(ctx fiber.Ctx) (*dto.WalletResponse, error) { - return services.Wallet.GetWallet(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Wallet.GetWallet(ctx.Context(), uid) } // Recharge wallet @@ -80,7 +88,8 @@ func (u *User) Wallet(ctx fiber.Ctx) (*dto.WalletResponse, error) { // @Success 200 {object} dto.RechargeResponse // @Bind form body func (u *User) Recharge(ctx fiber.Ctx, form *dto.RechargeForm) (*dto.RechargeResponse, error) { - return services.Wallet.Recharge(ctx, form) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Wallet.Recharge(ctx.Context(), uid, form) } // List user orders @@ -95,7 +104,8 @@ func (u *User) Recharge(ctx fiber.Ctx, form *dto.RechargeForm) (*dto.RechargeRes // @Success 200 {array} dto.Order // @Bind status query func (u *User) ListOrders(ctx fiber.Ctx, status string) ([]dto.Order, error) { - return services.Order.ListUserOrders(ctx, status) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Order.ListUserOrders(ctx.Context(), uid, status) } // Get user order detail @@ -110,7 +120,8 @@ func (u *User) ListOrders(ctx fiber.Ctx, status string) ([]dto.Order, error) { // @Success 200 {object} dto.Order // @Bind id path func (u *User) GetOrder(ctx fiber.Ctx, id string) (*dto.Order, error) { - return services.Order.GetUserOrder(ctx, id) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Order.GetUserOrder(ctx.Context(), uid, id) } // Get purchased content @@ -123,7 +134,8 @@ func (u *User) GetOrder(ctx fiber.Ctx, id string) (*dto.Order, error) { // @Produce json // @Success 200 {array} dto.ContentItem func (u *User) Library(ctx fiber.Ctx) ([]dto.ContentItem, error) { - return services.Content.GetLibrary(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.GetLibrary(ctx.Context(), uid) } // Get favorites @@ -136,7 +148,8 @@ func (u *User) Library(ctx fiber.Ctx) ([]dto.ContentItem, error) { // @Produce json // @Success 200 {array} dto.ContentItem func (u *User) Favorites(ctx fiber.Ctx) ([]dto.ContentItem, error) { - return services.Content.GetFavorites(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.GetFavorites(ctx.Context(), uid) } // Add to favorites @@ -151,7 +164,8 @@ func (u *User) Favorites(ctx fiber.Ctx) ([]dto.ContentItem, error) { // @Success 200 {string} string "Added" // @Bind contentId query func (u *User) AddFavorite(ctx fiber.Ctx, contentId string) error { - return services.Content.AddFavorite(ctx, contentId) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.AddFavorite(ctx.Context(), uid, contentId) } // Remove from favorites @@ -166,7 +180,8 @@ func (u *User) AddFavorite(ctx fiber.Ctx, contentId string) error { // @Success 200 {string} string "Removed" // @Bind contentId path func (u *User) RemoveFavorite(ctx fiber.Ctx, contentId string) error { - return services.Content.RemoveFavorite(ctx, contentId) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.RemoveFavorite(ctx.Context(), uid, contentId) } // Get liked contents @@ -179,7 +194,8 @@ func (u *User) RemoveFavorite(ctx fiber.Ctx, contentId string) error { // @Produce json // @Success 200 {array} dto.ContentItem func (u *User) Likes(ctx fiber.Ctx) ([]dto.ContentItem, error) { - return services.Content.GetLikes(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.GetLikes(ctx.Context(), uid) } // Like content @@ -194,7 +210,8 @@ func (u *User) Likes(ctx fiber.Ctx) ([]dto.ContentItem, error) { // @Success 200 {string} string "Liked" // @Bind contentId query func (u *User) AddLike(ctx fiber.Ctx, contentId string) error { - return services.Content.AddLike(ctx, contentId) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.AddLike(ctx.Context(), uid, contentId) } // Unlike content @@ -209,7 +226,8 @@ func (u *User) AddLike(ctx fiber.Ctx, contentId string) error { // @Success 200 {string} string "Unliked" // @Bind contentId path func (u *User) RemoveLike(ctx fiber.Ctx, contentId string) error { - return services.Content.RemoveLike(ctx, contentId) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Content.RemoveLike(ctx.Context(), uid, contentId) } // Get following tenants @@ -222,7 +240,8 @@ func (u *User) RemoveLike(ctx fiber.Ctx, contentId string) error { // @Produce json // @Success 200 {array} dto.TenantProfile func (u *User) Following(ctx fiber.Ctx) ([]dto.TenantProfile, error) { - return services.Tenant.ListFollowed(ctx) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Tenant.ListFollowed(ctx.Context(), uid) } // Get notifications @@ -239,7 +258,8 @@ func (u *User) Following(ctx fiber.Ctx) ([]dto.TenantProfile, error) { // @Bind typeArg query key(type) // @Bind page query func (u *User) Notifications(ctx fiber.Ctx, typeArg string, page int) (*requests.Pager, error) { - return services.Notification.List(ctx, page, typeArg) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Notification.List(ctx.Context(), uid, page, typeArg) } // List my coupons @@ -254,5 +274,6 @@ func (u *User) Notifications(ctx fiber.Ctx, typeArg string, page int) (*requests // @Success 200 {array} dto.UserCouponItem // @Bind status query func (u *User) MyCoupons(ctx fiber.Ctx, status string) ([]dto.UserCouponItem, error) { - return services.Coupon.ListUserCoupons(ctx, status) + uid := cast.ToInt64(ctx.Locals(consts.CtxKeyUser)) + return services.Coupon.ListUserCoupons(ctx.Context(), uid, status) } diff --git a/backend/app/middlewares/middlewares.go b/backend/app/middlewares/middlewares.go index 96bc75d..e5b2a14 100644 --- a/backend/app/middlewares/middlewares.go +++ b/backend/app/middlewares/middlewares.go @@ -2,6 +2,7 @@ package middlewares import ( "quyun/v2/app/errorx" + "quyun/v2/app/services" "quyun/v2/pkg/consts" "quyun/v2/providers/jwt" @@ -35,10 +36,20 @@ func (m *Middlewares) Auth(ctx fiber.Ctx) error { return errorx.ErrUnauthorized.WithCause(err).WithMsg("Invalid token") } + // get user model + user, err := services.User.GetModelByID(ctx.Context(), claims.UserID) + if err != nil { + return errorx.ErrUnauthorized.WithCause(err).WithMsg("UserNotFound") + } + // Set Context - ctx.Locals(consts.CtxKeyUser, claims.UserID) + ctx.Locals(consts.CtxKeyUser, user) if claims.TenantID > 0 { - ctx.Locals(consts.CtxKeyTenant, claims.TenantID) + tenant, err := services.Tenant.GetModelByID(ctx, claims.TenantID) + if err != nil { + return errorx.ErrUnauthorized.WithCause(err).WithMsg("TenantNotFound") + } + ctx.Locals(consts.CtxKeyTenant, tenant) } return ctx.Next() diff --git a/backend/app/services/common.go b/backend/app/services/common.go index 2f47b0e..e726633 100644 --- a/backend/app/services/common.go +++ b/backend/app/services/common.go @@ -22,13 +22,7 @@ type common struct { storage *storage.Storage } -func (s *common) Upload(ctx context.Context, file *multipart.FileHeader, typeArg string) (*common_dto.UploadResult, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return nil, errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) - +func (s *common) Upload(ctx context.Context, userID int64, file *multipart.FileHeader, typeArg string) (*common_dto.UploadResult, error) { // Mock Upload to S3/MinIO (Here we just generate key, actual upload handling via direct upload or stream is better) // But this Upload endpoint accepts file. So we save it. // We need to use storage provider to save it? @@ -61,7 +55,7 @@ func (s *common) Upload(ctx context.Context, file *multipart.FileHeader, typeArg url := s.GetAssetURL(objectKey) // ... rest ... - t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(uid)).First() + t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First() var tid int64 = 0 if err == nil { tid = t.ID @@ -69,7 +63,7 @@ func (s *common) Upload(ctx context.Context, file *multipart.FileHeader, typeArg asset := &models.MediaAsset{ TenantID: tid, - UserID: uid, + UserID: userID, Type: consts.MediaAssetType(typeArg), Status: consts.MediaAssetStatusUploaded, Provider: "local", diff --git a/backend/app/services/content.go b/backend/app/services/content.go index 5bd5653..680c939 100644 --- a/backend/app/services/content.go +++ b/backend/app/services/content.go @@ -81,7 +81,7 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte }, nil } -func (s *content) Get(ctx context.Context, id string) (*content_dto.ContentDetail, error) { +func (s *content) Get(ctx context.Context, userID int64, id string) (*content_dto.ContentDetail, error) { cid := cast.ToInt64(id) // Increment Views @@ -110,8 +110,8 @@ func (s *content) Get(ctx context.Context, id string) (*content_dto.ContentDetai isFavorited := false hasAccess := false - if userID := ctx.Value(consts.CtxKeyUser); userID != nil { - uid := cast.ToInt64(userID) + if userID > 0 { + uid := userID // Interaction isLiked, _ = models.UserContentActionQuery.WithContext(ctx). Where(models.UserContentActionQuery.UserID.Eq(uid), @@ -167,7 +167,7 @@ func (s *content) Get(ctx context.Context, id string) (*content_dto.ContentDetai return detail, nil } -func (s *content) ListComments(ctx context.Context, id string, page int) (*requests.Pager, error) { +func (s *content) ListComments(ctx context.Context, userID int64, id string, page int) (*requests.Pager, error) { cid := cast.ToInt64(id) tbl, q := models.CommentQuery.QueryContext(ctx) @@ -187,8 +187,8 @@ func (s *content) ListComments(ctx context.Context, id string, page int) (*reque // User likes likedMap := make(map[int64]bool) - if userID := ctx.Value(consts.CtxKeyUser); userID != nil { - uid := cast.ToInt64(userID) + if userID > 0 { + uid := userID ids := make([]int64, len(list)) for i, v := range list { ids[i] = v.ID @@ -225,12 +225,11 @@ func (s *content) ListComments(ctx context.Context, id string, page int) (*reque }, nil } -func (s *content) CreateComment(ctx context.Context, id string, form *content_dto.CommentCreateForm) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *content) CreateComment(ctx context.Context, userID int64, id string, form *content_dto.CommentCreateForm) error { + if userID == 0 { return errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID cid := cast.ToInt64(id) c, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid)).First() @@ -252,12 +251,11 @@ func (s *content) CreateComment(ctx context.Context, id string, form *content_dt return nil } -func (s *content) LikeComment(ctx context.Context, id string) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *content) LikeComment(ctx context.Context, userID int64, id string) error { + if userID == 0 { return errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID cmid := cast.ToInt64(id) // Fetch comment for author @@ -292,12 +290,11 @@ func (s *content) LikeComment(ctx context.Context, id string) error { return nil } -func (s *content) GetLibrary(ctx context.Context) ([]user_dto.ContentItem, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tbl, q := models.ContentAccessQuery.QueryContext(ctx) accessList, err := q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive)).Find() @@ -334,28 +331,28 @@ func (s *content) GetLibrary(ctx context.Context) ([]user_dto.ContentItem, error return data, nil } -func (s *content) GetFavorites(ctx context.Context) ([]user_dto.ContentItem, error) { - return s.getInteractList(ctx, "favorite") +func (s *content) GetFavorites(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) { + return s.getInteractList(ctx, userID, "favorite") } -func (s *content) AddFavorite(ctx context.Context, contentId string) error { - return s.addInteract(ctx, contentId, "favorite") +func (s *content) AddFavorite(ctx context.Context, userID int64, contentId string) error { + return s.addInteract(ctx, userID, contentId, "favorite") } -func (s *content) RemoveFavorite(ctx context.Context, contentId string) error { - return s.removeInteract(ctx, contentId, "favorite") +func (s *content) RemoveFavorite(ctx context.Context, userID int64, contentId string) error { + return s.removeInteract(ctx, userID, contentId, "favorite") } -func (s *content) GetLikes(ctx context.Context) ([]user_dto.ContentItem, error) { - return s.getInteractList(ctx, "like") +func (s *content) GetLikes(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) { + return s.getInteractList(ctx, userID, "like") } -func (s *content) AddLike(ctx context.Context, contentId string) error { - return s.addInteract(ctx, contentId, "like") +func (s *content) AddLike(ctx context.Context, userID int64, contentId string) error { + return s.addInteract(ctx, userID, contentId, "like") } -func (s *content) RemoveLike(ctx context.Context, contentId string) error { - return s.removeInteract(ctx, contentId, "like") +func (s *content) RemoveLike(ctx context.Context, userID int64, contentId string) error { + return s.removeInteract(ctx, userID, contentId, "like") } func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) { @@ -480,12 +477,11 @@ func (s *content) toMediaURLs(assets []*models.ContentAsset) []content_dto.Media return urls } -func (s *content) addInteract(ctx context.Context, contentId, typ string) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *content) addInteract(ctx context.Context, userID int64, contentId, typ string) error { + if userID == 0 { return errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID cid := cast.ToInt64(contentId) // Fetch content for author @@ -529,12 +525,11 @@ func (s *content) addInteract(ctx context.Context, contentId, typ string) error return nil } -func (s *content) removeInteract(ctx context.Context, contentId, typ string) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *content) removeInteract(ctx context.Context, userID int64, contentId, typ string) error { + if userID == 0 { return errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID cid := cast.ToInt64(contentId) return models.Q.Transaction(func(tx *models.Query) error { @@ -556,12 +551,11 @@ func (s *content) removeInteract(ctx context.Context, contentId, typ string) err }) } -func (s *content) getInteractList(ctx context.Context, typ string) ([]user_dto.ContentItem, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *content) getInteractList(ctx context.Context, userID int64, typ string) ([]user_dto.ContentItem, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tbl, q := models.UserContentActionQuery.QueryContext(ctx) actions, err := q.Where(tbl.UserID.Eq(uid), tbl.Type.Eq(typ)).Find() diff --git a/backend/app/services/coupon.go b/backend/app/services/coupon.go index 44d5d05..ad997e1 100644 --- a/backend/app/services/coupon.go +++ b/backend/app/services/coupon.go @@ -15,12 +15,11 @@ import ( // @provider type coupon struct{} -func (s *coupon) ListUserCoupons(ctx context.Context, status string) ([]coupon_dto.UserCouponItem, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *coupon) ListUserCoupons(ctx context.Context, userID int64, status string) ([]coupon_dto.UserCouponItem, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tbl, q := models.UserCouponQuery.QueryContext(ctx) q = q.Where(tbl.UserID.Eq(uid)) diff --git a/backend/app/services/creator.go b/backend/app/services/creator.go index d1b4b56..9c07d85 100644 --- a/backend/app/services/creator.go +++ b/backend/app/services/creator.go @@ -20,12 +20,11 @@ import ( // @provider type creator struct{} -func (s *creator) Apply(ctx context.Context, form *creator_dto.ApplyForm) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.ApplyForm) error { + if userID == 0 { return errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tbl, q := models.TenantQuery.QueryContext(ctx) // Check if already has a tenant @@ -62,8 +61,8 @@ func (s *creator) Apply(ctx context.Context, form *creator_dto.ApplyForm) error return nil } -func (s *creator) Dashboard(ctx context.Context) (*creator_dto.DashboardStats, error) { - tid, err := s.getTenantID(ctx) +func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.DashboardStats, error) { + tid, err := s.getTenantID(ctx, userID) if err != nil { return nil, err } @@ -95,8 +94,8 @@ func (s *creator) Dashboard(ctx context.Context) (*creator_dto.DashboardStats, e return stats, nil } -func (s *creator) ListContents(ctx context.Context, filter *creator_dto.CreatorContentListFilter) ([]creator_dto.ContentItem, error) { - tid, err := s.getTenantID(ctx) +func (s *creator) ListContents(ctx context.Context, userID int64, filter *creator_dto.CreatorContentListFilter) ([]creator_dto.ContentItem, error) { + tid, err := s.getTenantID(ctx, userID) if err != nil { return nil, err } @@ -133,12 +132,12 @@ func (s *creator) ListContents(ctx context.Context, filter *creator_dto.CreatorC return data, nil } -func (s *creator) CreateContent(ctx context.Context, form *creator_dto.ContentCreateForm) error { - tid, err := s.getTenantID(ctx) +func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator_dto.ContentCreateForm) error { + tid, err := s.getTenantID(ctx, userID) if err != nil { return err } - uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser)) + uid := userID return models.Q.Transaction(func(tx *models.Query) error { // 1. Create Content @@ -187,13 +186,13 @@ func (s *creator) CreateContent(ctx context.Context, form *creator_dto.ContentCr }) } -func (s *creator) UpdateContent(ctx context.Context, id string, form *creator_dto.ContentUpdateForm) error { - tid, err := s.getTenantID(ctx) +func (s *creator) UpdateContent(ctx context.Context, userID int64, id string, form *creator_dto.ContentUpdateForm) error { + tid, err := s.getTenantID(ctx, userID) if err != nil { return err } cid := cast.ToInt64(id) - uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser)) + uid := userID return models.Q.Transaction(func(tx *models.Query) error { // 1. Check Ownership @@ -257,9 +256,9 @@ func (s *creator) UpdateContent(ctx context.Context, id string, form *creator_dt }) } -func (s *creator) DeleteContent(ctx context.Context, id string) error { +func (s *creator) DeleteContent(ctx context.Context, userID int64, id string) error { cid := cast.ToInt64(id) - tid, err := s.getTenantID(ctx) + tid, err := s.getTenantID(ctx, userID) if err != nil { return err } @@ -271,8 +270,8 @@ func (s *creator) DeleteContent(ctx context.Context, id string) error { return nil } -func (s *creator) ListOrders(ctx context.Context, filter *creator_dto.CreatorOrderListFilter) ([]creator_dto.Order, error) { - tid, err := s.getTenantID(ctx) +func (s *creator) ListOrders(ctx context.Context, userID int64, filter *creator_dto.CreatorOrderListFilter) ([]creator_dto.Order, error) { + tid, err := s.getTenantID(ctx, userID) if err != nil { return nil, err } @@ -302,13 +301,13 @@ func (s *creator) ListOrders(ctx context.Context, filter *creator_dto.CreatorOrd return data, nil } -func (s *creator) ProcessRefund(ctx context.Context, id string, form *creator_dto.RefundForm) error { - tid, err := s.getTenantID(ctx) +func (s *creator) ProcessRefund(ctx context.Context, userID int64, id string, form *creator_dto.RefundForm) error { + tid, err := s.getTenantID(ctx, userID) if err != nil { return err } oid := cast.ToInt64(id) - uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser)) // Creator ID + uid := userID // Creator ID // Fetch Order o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(oid), models.OrderQuery.TenantID.Eq(tid)).First() @@ -402,8 +401,8 @@ func (s *creator) ProcessRefund(ctx context.Context, id string, form *creator_dt return errorx.ErrBadRequest.WithMsg("无效的操作") } -func (s *creator) GetSettings(ctx context.Context) (*creator_dto.Settings, error) { - tid, err := s.getTenantID(ctx) +func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.Settings, error) { + tid, err := s.getTenantID(ctx, userID) if err != nil { return nil, err } @@ -418,12 +417,12 @@ func (s *creator) GetSettings(ctx context.Context) (*creator_dto.Settings, error }, nil } -func (s *creator) UpdateSettings(ctx context.Context, form *creator_dto.Settings) error { +func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creator_dto.Settings) error { return nil } -func (s *creator) ListPayoutAccounts(ctx context.Context) ([]creator_dto.PayoutAccount, error) { - tid, err := s.getTenantID(ctx) +func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creator_dto.PayoutAccount, error) { + tid, err := s.getTenantID(ctx, userID) if err != nil { return nil, err } @@ -446,12 +445,12 @@ func (s *creator) ListPayoutAccounts(ctx context.Context) ([]creator_dto.PayoutA return data, nil } -func (s *creator) AddPayoutAccount(ctx context.Context, form *creator_dto.PayoutAccount) error { - tid, err := s.getTenantID(ctx) +func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *creator_dto.PayoutAccount) error { + tid, err := s.getTenantID(ctx, userID) if err != nil { return err } - uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser)) + uid := userID pa := &models.PayoutAccount{ TenantID: tid, @@ -467,8 +466,8 @@ func (s *creator) AddPayoutAccount(ctx context.Context, form *creator_dto.Payout return nil } -func (s *creator) RemovePayoutAccount(ctx context.Context, id string) error { - tid, err := s.getTenantID(ctx) +func (s *creator) RemovePayoutAccount(ctx context.Context, userID int64, id string) error { + tid, err := s.getTenantID(ctx, userID) if err != nil { return err } @@ -483,12 +482,12 @@ func (s *creator) RemovePayoutAccount(ctx context.Context, id string) error { return nil } -func (s *creator) Withdraw(ctx context.Context, form *creator_dto.WithdrawForm) error { - tid, err := s.getTenantID(ctx) +func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto.WithdrawForm) error { + tid, err := s.getTenantID(ctx, userID) if err != nil { return err } - uid := cast.ToInt64(ctx.Value(consts.CtxKeyUser)) + uid := userID amount := int64(form.Amount * 100) if amount <= 0 { @@ -552,12 +551,11 @@ func (s *creator) Withdraw(ctx context.Context, form *creator_dto.WithdrawForm) // Helpers -func (s *creator) getTenantID(ctx context.Context) (int64, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error) { + if userID == 0 { return 0, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID // Simple check: User owns tenant t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(uid)).First() diff --git a/backend/app/services/notification.go b/backend/app/services/notification.go index 6ecb30c..38f4227 100644 --- a/backend/app/services/notification.go +++ b/backend/app/services/notification.go @@ -20,15 +20,9 @@ type notification struct { job *job.Job } -func (s *notification) List(ctx context.Context, page int, typeArg string) (*requests.Pager, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return nil, errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) - +func (s *notification) List(ctx context.Context, userID int64, page int, typeArg string) (*requests.Pager, error) { tbl, q := models.NotificationQuery.QueryContext(ctx) - q = q.Where(tbl.UserID.Eq(uid)) + q = q.Where(tbl.UserID.Eq(userID)) if typeArg != "" && typeArg != "all" { q = q.Where(tbl.Type.Eq(typeArg)) @@ -66,16 +60,11 @@ func (s *notification) List(ctx context.Context, page int, typeArg string) (*req }, nil } -func (s *notification) MarkRead(ctx context.Context, id string) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) +func (s *notification) MarkRead(ctx context.Context, userID int64, id string) error { nid := cast.ToInt64(id) _, err := models.NotificationQuery.WithContext(ctx). - Where(models.NotificationQuery.ID.Eq(nid), models.NotificationQuery.UserID.Eq(uid)). + Where(models.NotificationQuery.ID.Eq(nid), models.NotificationQuery.UserID.Eq(userID)). UpdateSimple(models.NotificationQuery.IsRead.Value(true)) if err != nil { return errorx.ErrDatabaseError.WithCause(err) diff --git a/backend/app/services/order.go b/backend/app/services/order.go index 232dd78..e47bb53 100644 --- a/backend/app/services/order.go +++ b/backend/app/services/order.go @@ -21,12 +21,11 @@ import ( // @provider type order struct{} -func (s *order) ListUserOrders(ctx context.Context, status string) ([]user_dto.Order, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *order) ListUserOrders(ctx context.Context, userID int64, status string) ([]user_dto.Order, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tbl, q := models.OrderQuery.QueryContext(ctx) q = q.Where(tbl.UserID.Eq(uid)) @@ -48,12 +47,11 @@ func (s *order) ListUserOrders(ctx context.Context, status string) ([]user_dto.O return data, nil } -func (s *order) GetUserOrder(ctx context.Context, id string) (*user_dto.Order, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *order) GetUserOrder(ctx context.Context, userID int64, id string) (*user_dto.Order, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID oid := cast.ToInt64(id) tbl, q := models.OrderQuery.QueryContext(ctx) @@ -72,12 +70,11 @@ func (s *order) GetUserOrder(ctx context.Context, id string) (*user_dto.Order, e return &dto, nil } -func (s *order) Create(ctx context.Context, form *transaction_dto.OrderCreateForm) (*transaction_dto.OrderCreateResponse, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *order) Create(ctx context.Context, userID int64, form *transaction_dto.OrderCreateForm) (*transaction_dto.OrderCreateResponse, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID cid := cast.ToInt64(form.ContentID) // 1. Fetch Content & Price @@ -169,12 +166,11 @@ func (s *order) Create(ctx context.Context, form *transaction_dto.OrderCreateFor }, nil } -func (s *order) Pay(ctx context.Context, id string, form *transaction_dto.OrderPayForm) (*transaction_dto.OrderPayResponse, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *order) Pay(ctx context.Context, userID int64, id string, form *transaction_dto.OrderPayForm) (*transaction_dto.OrderPayResponse, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID oid := cast.ToInt64(id) // Fetch Order diff --git a/backend/app/services/tenant.go b/backend/app/services/tenant.go index c7fbdae..b0ee820 100644 --- a/backend/app/services/tenant.go +++ b/backend/app/services/tenant.go @@ -17,7 +17,7 @@ import ( // @provider type tenant struct{} -func (s *tenant) GetPublicProfile(ctx context.Context, id string) (*dto.TenantProfile, error) { +func (s *tenant) GetPublicProfile(ctx context.Context, userID int64, id string) (*dto.TenantProfile, error) { tid := cast.ToInt64(id) t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tid)).First() if err != nil { @@ -29,12 +29,14 @@ func (s *tenant) GetPublicProfile(ctx context.Context, id string) (*dto.TenantPr // Stats followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(tid)).Count() - contents, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.TenantID.Eq(tid), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).Count() + contents, _ := models.ContentQuery.WithContext(ctx). + Where(models.ContentQuery.TenantID.Eq(tid), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)). + Count() // Following status isFollowing := false - if userID := ctx.Value(consts.CtxKeyUser); userID != nil { - uid := cast.ToInt64(userID) + if userID > 0 { + uid := userID isFollowing, _ = models.TenantUserQuery.WithContext(ctx). Where(models.TenantUserQuery.TenantID.Eq(tid), models.TenantUserQuery.UserID.Eq(uid)). Exists() @@ -52,12 +54,11 @@ func (s *tenant) GetPublicProfile(ctx context.Context, id string) (*dto.TenantPr }, nil } -func (s *tenant) Follow(ctx context.Context, id string) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *tenant) Follow(ctx context.Context, userID int64, id string) error { + if userID == 0 { return errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tid := cast.ToInt64(id) // Check if tenant exists @@ -83,12 +84,11 @@ func (s *tenant) Follow(ctx context.Context, id string) error { return nil } -func (s *tenant) Unfollow(ctx context.Context, id string) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *tenant) Unfollow(ctx context.Context, userID int64, id string) error { + if userID == 0 { return errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tid := cast.ToInt64(id) _, err := models.TenantUserQuery.WithContext(ctx). @@ -100,12 +100,11 @@ func (s *tenant) Unfollow(ctx context.Context, id string) error { return nil } -func (s *tenant) ListFollowed(ctx context.Context) ([]dto.TenantProfile, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { +func (s *tenant) ListFollowed(ctx context.Context, userID int64) ([]dto.TenantProfile, error) { + if userID == 0 { return nil, errorx.ErrUnauthorized } - uid := cast.ToInt64(userID) + uid := userID tbl, q := models.TenantUserQuery.QueryContext(ctx) list, err := q.Where(tbl.UserID.Eq(uid)).Find() @@ -122,8 +121,12 @@ func (s *tenant) ListFollowed(ctx context.Context) ([]dto.TenantProfile, error) } // Stats - followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(tu.TenantID)).Count() - contents, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.TenantID.Eq(tu.TenantID), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).Count() + followers, _ := models.TenantUserQuery.WithContext(ctx). + Where(models.TenantUserQuery.TenantID.Eq(tu.TenantID)). + Count() + contents, _ := models.ContentQuery.WithContext(ctx). + Where(models.ContentQuery.TenantID.Eq(tu.TenantID), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)). + Count() data = append(data, dto.TenantProfile{ ID: cast.ToString(t.ID), @@ -139,3 +142,16 @@ func (s *tenant) ListFollowed(ctx context.Context) ([]dto.TenantProfile, error) return data, nil } + +// GetModelByID 获取指定 ID 的model +func (s *tenant) GetModelByID(ctx context.Context, id int64) (*models.Tenant, error) { + tbl, query := models.TenantQuery.QueryContext(ctx) + u, err := query.Where(tbl.ID.Eq(id)).First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errorx.ErrRecordNotFound + } + return nil, errorx.ErrDatabaseError.WithCause(err) + } + return u, nil +} diff --git a/backend/app/services/user.go b/backend/app/services/user.go index f094c36..e288cc0 100644 --- a/backend/app/services/user.go +++ b/backend/app/services/user.go @@ -77,40 +77,36 @@ func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.L return &auth_dto.LoginResponse{ Token: token, - User: s.toAuthUserDTO(u), + User: s.ToAuthUserDTO(u), }, nil } -// Me 获取当前用户信息 -func (s *user) Me(ctx context.Context) (*auth_dto.User, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return nil, errorx.ErrUnauthorized - } - - uid := cast.ToInt64(userID) +// GetModelByID 获取指定 ID 的用户model +func (s *user) GetModelByID(ctx context.Context, userID int64) (*models.User, error) { tbl, query := models.UserQuery.QueryContext(ctx) - u, err := query.Where(tbl.ID.Eq(uid)).First() + u, err := query.Where(tbl.ID.Eq(userID)).First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errorx.ErrRecordNotFound } return nil, errorx.ErrDatabaseError.WithCause(err) } + return u, nil +} - return s.toAuthUserDTO(u), nil +// Me 获取当前用户信息 +func (s *user) Me(ctx context.Context, userID int64) (*auth_dto.User, error) { + u, err := s.GetModelByID(ctx, userID) + if err != nil { + return nil, err + } + return s.ToAuthUserDTO(u), nil } // Update 更新用户信息 -func (s *user) Update(ctx context.Context, form *user_dto.UserUpdate) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) - +func (s *user) Update(ctx context.Context, userID int64, form *user_dto.UserUpdate) error { tbl, query := models.UserQuery.QueryContext(ctx) - _, err := query.Where(tbl.ID.Eq(uid)).Updates(&models.User{ + _, err := query.Where(tbl.ID.Eq(userID)).Updates(&models.User{ Nickname: form.Nickname, Avatar: form.Avatar, Gender: form.Gender, @@ -125,13 +121,7 @@ func (s *user) Update(ctx context.Context, form *user_dto.UserUpdate) error { } // RealName 实名认证 -func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) - +func (s *user) RealName(ctx context.Context, userID int64, form *user_dto.RealNameForm) error { // Mock Verification if len(form.IDCard) != 18 { return errorx.ErrBadRequest.WithMsg("身份证号格式错误") @@ -141,7 +131,7 @@ func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error } tbl, query := models.UserQuery.QueryContext(ctx) - u, err := query.Where(tbl.ID.Eq(uid)).First() + u, err := query.Where(tbl.ID.Eq(userID)).First() if err != nil { return errorx.ErrRecordNotFound } @@ -159,7 +149,7 @@ func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error b, _ := json.Marshal(metaMap) - _, err = query.Where(tbl.ID.Eq(uid)).Updates(&models.User{ + _, err = query.Where(tbl.ID.Eq(userID)).Updates(&models.User{ IsRealNameVerified: true, VerifiedAt: time.Now(), Metas: types.JSON(b), @@ -171,15 +161,9 @@ func (s *user) RealName(ctx context.Context, form *user_dto.RealNameForm) error } // GetNotifications 获取通知 -func (s *user) GetNotifications(ctx context.Context, typeArg string) ([]user_dto.Notification, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return nil, errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) - +func (s *user) GetNotifications(ctx context.Context, userID int64, typeArg string) ([]user_dto.Notification, error) { tbl, query := models.NotificationQuery.QueryContext(ctx) - query = query.Where(tbl.UserID.Eq(uid)) + query = query.Where(tbl.UserID.Eq(userID)) if typeArg != "" && typeArg != "all" { query = query.Where(tbl.Type.Eq(typeArg)) } @@ -203,7 +187,7 @@ func (s *user) GetNotifications(ctx context.Context, typeArg string) ([]user_dto return result, nil } -func (s *user) toAuthUserDTO(u *models.User) *auth_dto.User { +func (s *user) ToAuthUserDTO(u *models.User) *auth_dto.User { return &auth_dto.User{ ID: cast.ToString(u.ID), Phone: u.Phone, diff --git a/backend/app/services/wallet.go b/backend/app/services/wallet.go index 8ca97c7..36fa978 100644 --- a/backend/app/services/wallet.go +++ b/backend/app/services/wallet.go @@ -20,15 +20,9 @@ import ( // @provider type wallet struct{} -func (s *wallet) GetWallet(ctx context.Context) (*user_dto.WalletResponse, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return nil, errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) - +func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletResponse, error) { // Get Balance - u, err := models.UserQuery.WithContext(ctx).Where(models.UserQuery.ID.Eq(uid)).First() + u, err := models.UserQuery.WithContext(ctx).Where(models.UserQuery.ID.Eq(userID)).First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errorx.ErrRecordNotFound @@ -39,7 +33,7 @@ func (s *wallet) GetWallet(ctx context.Context) (*user_dto.WalletResponse, error // Get Transactions (Orders) // Both purchase (expense) and recharge (income - if paid) tbl, q := models.OrderQuery.QueryContext(ctx) - orders, err := q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.OrderStatusPaid)). + orders, err := q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid)). Order(tbl.CreatedAt.Desc()). Limit(20). // Limit to recent 20 Find() @@ -74,13 +68,7 @@ func (s *wallet) GetWallet(ctx context.Context) (*user_dto.WalletResponse, error }, nil } -func (s *wallet) Recharge(ctx context.Context, form *user_dto.RechargeForm) (*user_dto.RechargeResponse, error) { - userID := ctx.Value(consts.CtxKeyUser) - if userID == nil { - return nil, errorx.ErrUnauthorized - } - uid := cast.ToInt64(userID) - +func (s *wallet) Recharge(ctx context.Context, userID int64, form *user_dto.RechargeForm) (*user_dto.RechargeResponse, error) { amount := int64(form.Amount * 100) if amount <= 0 { return nil, errorx.ErrBadRequest.WithMsg("金额无效") @@ -89,7 +77,7 @@ func (s *wallet) Recharge(ctx context.Context, form *user_dto.RechargeForm) (*us // Create Recharge Order order := &models.Order{ TenantID: 0, // Platform / System - UserID: uid, + UserID: userID, Type: consts.OrderTypeRecharge, Status: consts.OrderStatusCreated, Currency: consts.CurrencyCNY, diff --git a/backend/pkg/consts/context_keys.go b/backend/pkg/consts/context_keys.go index af1918e..304dfb5 100644 --- a/backend/pkg/consts/context_keys.go +++ b/backend/pkg/consts/context_keys.go @@ -1,8 +1,8 @@ package consts const ( - CtxKeyTenant = "tenant" - CtxKeyClaims = "claims" - CtxKeyUser = "user" - CtxKeyTenantUser = "tenant_user" + CtxKeyTenant = "__ctx_tenant" + CtxKeyClaims = "__ctx_claims" + CtxKeyUser = "__ctx_user" + CtxKeyTenantUser = "__ctx_tenant_user" )