diff --git a/backend/app/commands/testx/testing.go b/backend/app/commands/testx/testing.go index 96024db..056b614 100644 --- a/backend/app/commands/testx/testing.go +++ b/backend/app/commands/testx/testing.go @@ -115,7 +115,8 @@ func Serve(providers container.Providers, t *testing.T, invoke any) { dig.In Initials []contracts.Initial `group:"initials"` Job *job.Job - }) error { + }, + ) error { _ = p.Initials ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) diff --git a/backend/app/http/v1/auth/auth.go b/backend/app/http/v1/auth/auth.go index bd6fa30..c142817 100644 --- a/backend/app/http/v1/auth/auth.go +++ b/backend/app/http/v1/auth/auth.go @@ -3,6 +3,8 @@ package auth import ( "quyun/v2/app/http/v1/dto" "quyun/v2/app/services" + "quyun/v2/database/models" + "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" ) @@ -12,7 +14,7 @@ type Auth struct{} // SendOTP sends an OTP to the provided phone number. // -// @Router /v1/auth/otp [post] +// @Router /t/:tenantCode/v1/auth/otp [post] // @Summary Send OTP // @Description Send OTP to phone number // @Tags Auth @@ -27,7 +29,7 @@ func (a *Auth) SendOTP(ctx fiber.Ctx, form *dto.SendOTPForm) error { // Login logs in or registers a user with OTP. // -// @Router /v1/auth/login [post] +// @Router /t/:tenantCode/v1/auth/login [post] // @Summary Login or Register with OTP // @Description Login or register user using phone number and OTP // @Tags Auth @@ -37,5 +39,11 @@ func (a *Auth) SendOTP(ctx fiber.Ctx, form *dto.SendOTPForm) error { // @Success 200 {object} dto.LoginResponse // @Bind form body func (a *Auth) Login(ctx fiber.Ctx, form *dto.LoginForm) (*dto.LoginResponse, error) { - return services.User.LoginWithOTP(ctx, form.Phone, form.OTP) + tenantID := int64(0) + if t := ctx.Locals(consts.CtxKeyTenant); t != nil { + if tenant, ok := t.(*models.Tenant); ok { + tenantID = tenant.ID + } + } + return services.User.LoginWithOTP(ctx, tenantID, form.Phone, form.OTP) } diff --git a/backend/app/http/v1/auth/routes.gen.go b/backend/app/http/v1/auth/routes.gen.go index f163611..9e9b8fa 100644 --- a/backend/app/http/v1/auth/routes.gen.go +++ b/backend/app/http/v1/auth/routes.gen.go @@ -42,13 +42,13 @@ func (r *Routes) Name() string { // Each route is registered with its corresponding controller action and parameter bindings. func (r *Routes) Register(router fiber.Router) { // Register routes for controller: Auth - r.log.Debugf("Registering route: Post /v1/auth/login -> auth.Login") - router.Post("/v1/auth/login"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/auth/login -> auth.Login") + router.Post("/t/:tenantCode/v1/auth/login"[len(r.Path()):], DataFunc1( r.auth.Login, Body[dto.LoginForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/auth/otp -> auth.SendOTP") - router.Post("/v1/auth/otp"[len(r.Path()):], Func1( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/auth/otp -> auth.SendOTP") + router.Post("/t/:tenantCode/v1/auth/otp"[len(r.Path()):], Func1( r.auth.SendOTP, Body[dto.SendOTPForm]("form"), )) diff --git a/backend/app/http/v1/common.go b/backend/app/http/v1/common.go index 475f85d..c915728 100644 --- a/backend/app/http/v1/common.go +++ b/backend/app/http/v1/common.go @@ -13,7 +13,7 @@ import ( // @provider type Common struct{} -// @Router /v1/upload [post] +// @Router /t/:tenantCode/v1/upload [post] // @Summary Upload file // @Description Upload file // @Tags Common @@ -31,16 +31,17 @@ func (c *Common) Upload( file *multipart.FileHeader, form *dto.UploadForm, ) (*dto.UploadResult, error) { + tenantID := getTenantID(ctx) val := "" if form != nil { val = form.Type } - return services.Common.Upload(ctx, user.ID, file, val) + return services.Common.Upload(ctx, tenantID, user.ID, file, val) } // Get options (enums) // -// @Router /v1/common/options [get] +// @Router /t/:tenantCode/v1/common/options [get] // @Summary Get options // @Description Get global options (enums) // @Tags Common @@ -53,7 +54,7 @@ func (c *Common) GetOptions(ctx fiber.Ctx) (*dto.OptionsResponse, error) { // Check file hash for deduplication // -// @Router /v1/upload/check [get] +// @Router /t/:tenantCode/v1/upload/check [get] // @Summary Check hash // @Description Check if file hash exists // @Tags Common @@ -64,10 +65,11 @@ func (c *Common) GetOptions(ctx fiber.Ctx) (*dto.OptionsResponse, error) { // @Bind user local key(__ctx_user) // @Bind hash query func (c *Common) CheckHash(ctx fiber.Ctx, user *models.User, hash string) (*dto.UploadResult, error) { - return services.Common.CheckHash(ctx, user.ID, hash) + tenantID := getTenantID(ctx) + return services.Common.CheckHash(ctx, tenantID, user.ID, hash) } -// @Router /v1/upload/init [post] +// @Router /t/:tenantCode/v1/upload/init [post] // @Summary Init multipart upload // @Description Initialize multipart upload // @Tags Common @@ -78,10 +80,11 @@ func (c *Common) CheckHash(ctx fiber.Ctx, user *models.User, hash string) (*dto. // @Bind user local key(__ctx_user) // @Bind form body func (c *Common) InitUpload(ctx fiber.Ctx, user *models.User, form *dto.UploadInitForm) (*dto.UploadInitResponse, error) { - return services.Common.InitUpload(ctx.Context(), user.ID, form) + tenantID := getTenantID(ctx) + return services.Common.InitUpload(ctx.Context(), tenantID, user.ID, form) } -// @Router /v1/upload/part [post] +// @Router /t/:tenantCode/v1/upload/part [post] // @Summary Upload part // @Description Upload a part // @Tags Common @@ -94,10 +97,11 @@ func (c *Common) InitUpload(ctx fiber.Ctx, user *models.User, form *dto.UploadIn // @Bind file file // @Bind form body func (c *Common) UploadPart(ctx fiber.Ctx, user *models.User, file *multipart.FileHeader, form *dto.UploadPartForm) error { - return services.Common.UploadPart(ctx.Context(), user.ID, file, form) + tenantID := getTenantID(ctx) + return services.Common.UploadPart(ctx.Context(), tenantID, user.ID, file, form) } -// @Router /v1/upload/complete [post] +// @Router /t/:tenantCode/v1/upload/complete [post] // @Summary Complete upload // @Description Complete multipart upload // @Tags Common @@ -108,10 +112,11 @@ func (c *Common) UploadPart(ctx fiber.Ctx, user *models.User, file *multipart.Fi // @Bind user local key(__ctx_user) // @Bind form body func (c *Common) CompleteUpload(ctx fiber.Ctx, user *models.User, form *dto.UploadCompleteForm) (*dto.UploadResult, error) { - return services.Common.CompleteUpload(ctx.Context(), user.ID, form) + tenantID := getTenantID(ctx) + return services.Common.CompleteUpload(ctx.Context(), tenantID, user.ID, form) } -// @Router /v1/upload/:uploadId [delete] +// @Router /t/:tenantCode/v1/upload/:uploadId [delete] // @Summary Abort upload // @Description Abort multipart upload // @Tags Common @@ -122,10 +127,11 @@ func (c *Common) CompleteUpload(ctx fiber.Ctx, user *models.User, form *dto.Uplo // @Bind user local key(__ctx_user) // @Bind uploadId path func (c *Common) AbortUpload(ctx fiber.Ctx, user *models.User, uploadId string) error { - return services.Common.AbortUpload(ctx.Context(), user.ID, uploadId) + tenantID := getTenantID(ctx) + return services.Common.AbortUpload(ctx.Context(), tenantID, user.ID, uploadId) } -// @Router /v1/media-assets/:id [delete] +// @Router /t/:tenantCode/v1/media-assets/:id [delete] // @Summary Delete media asset // @Description Delete media asset // @Tags Common @@ -136,7 +142,8 @@ func (c *Common) AbortUpload(ctx fiber.Ctx, user *models.User, uploadId string) // @Bind user local key(__ctx_user) // @Bind id path func (c *Common) DeleteMediaAsset(ctx fiber.Ctx, user *models.User, id int64) error { - return services.Common.DeleteMediaAsset(ctx.Context(), user.ID, id) + tenantID := getTenantID(ctx) + return services.Common.DeleteMediaAsset(ctx.Context(), tenantID, user.ID, id) } // Upload file diff --git a/backend/app/http/v1/content.go b/backend/app/http/v1/content.go index 1caba52..512b827 100644 --- a/backend/app/http/v1/content.go +++ b/backend/app/http/v1/content.go @@ -1,11 +1,10 @@ package v1 import ( + "quyun/v2/app/errorx" "quyun/v2/app/http/v1/dto" "quyun/v2/app/requests" "quyun/v2/app/services" - "quyun/v2/database/models" - "quyun/v2/pkg/consts" "github.com/gofiber/fiber/v3" ) @@ -15,7 +14,7 @@ type Content struct{} // List contents (Explore / Search) // -// @Router /v1/contents [get] +// @Router /t/:tenantCode/v1/contents [get] // @Summary List contents // @Description List contents with filtering and pagination // @Tags Content @@ -32,12 +31,19 @@ func (c *Content) List( ctx fiber.Ctx, filter *dto.ContentListFilter, ) (*requests.Pager, error) { - return services.Content.List(ctx, filter) + tenantID := getTenantID(ctx) + if tenantID > 0 { + if filter.TenantID != nil && *filter.TenantID != tenantID { + return nil, errorx.ErrForbidden.WithMsg("租户不匹配") + } + filter.TenantID = &tenantID + } + return services.Content.List(ctx, tenantID, filter) } // Get content detail // -// @Router /v1/contents/:id [get] +// @Router /t/:tenantCode/v1/contents/:id [get] // @Summary Get content detail // @Description Get content detail by ID // @Tags Content @@ -47,13 +53,14 @@ func (c *Content) List( // @Success 200 {object} dto.ContentDetail // @Bind id path func (c *Content) Get(ctx fiber.Ctx, id int64) (*dto.ContentDetail, error) { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.Get(ctx, uid, id) + return services.Content.Get(ctx, tenantID, uid, id) } // Get comments for a content // -// @Router /v1/contents/:id/comments [get] +// @Router /t/:tenantCode/v1/contents/:id/comments [get] // @Summary Get comments // @Description Get comments for a content // @Tags Content @@ -65,13 +72,14 @@ func (c *Content) Get(ctx fiber.Ctx, id int64) (*dto.ContentDetail, error) { // @Bind id path // @Bind page query func (c *Content) ListComments(ctx fiber.Ctx, id int64, page int) (*requests.Pager, error) { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.ListComments(ctx, uid, id, page) + return services.Content.ListComments(ctx, tenantID, uid, id, page) } // Post a comment // -// @Router /v1/contents/:id/comments [post] +// @Router /t/:tenantCode/v1/contents/:id/comments [post] // @Summary Post comment // @Description Post a comment to a content // @Tags Content @@ -83,13 +91,14 @@ func (c *Content) ListComments(ctx fiber.Ctx, id int64, page int) (*requests.Pag // @Bind id path // @Bind form body func (c *Content) CreateComment(ctx fiber.Ctx, id int64, form *dto.CommentCreateForm) error { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.CreateComment(ctx, uid, id, form) + return services.Content.CreateComment(ctx, tenantID, uid, id, form) } // Like a comment // -// @Router /v1/comments/:id/like [post] +// @Router /t/:tenantCode/v1/comments/:id/like [post] // @Summary Like comment // @Description Like a comment // @Tags Content @@ -99,65 +108,70 @@ func (c *Content) CreateComment(ctx fiber.Ctx, id int64, form *dto.CommentCreate // @Success 200 {string} string "Liked" // @Bind id path func (c *Content) LikeComment(ctx fiber.Ctx, id int64) error { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.LikeComment(ctx, uid, id) + return services.Content.LikeComment(ctx, tenantID, uid, id) } // Add like // -// @Router /v1/contents/:id/like [post] +// @Router /t/:tenantCode/v1/contents/:id/like [post] // @Summary Add like // @Tags Content // @Param id path int64 true "Content ID" // @Success 200 {string} string "Liked" // @Bind id path func (c *Content) AddLike(ctx fiber.Ctx, id int64) error { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.AddLike(ctx, uid, id) + return services.Content.AddLike(ctx, tenantID, uid, id) } // Remove like // -// @Router /v1/contents/:id/like [delete] +// @Router /t/:tenantCode/v1/contents/:id/like [delete] // @Summary Remove like // @Tags Content // @Param id path int64 true "Content ID" // @Success 200 {string} string "Unliked" // @Bind id path func (c *Content) RemoveLike(ctx fiber.Ctx, id int64) error { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.RemoveLike(ctx, uid, id) + return services.Content.RemoveLike(ctx, tenantID, uid, id) } // Add favorite // -// @Router /v1/contents/:id/favorite [post] +// @Router /t/:tenantCode/v1/contents/:id/favorite [post] // @Summary Add favorite // @Tags Content // @Param id path int64 true "Content ID" // @Success 200 {string} string "Favorited" // @Bind id path func (c *Content) AddFavorite(ctx fiber.Ctx, id int64) error { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.AddFavorite(ctx, uid, id) + return services.Content.AddFavorite(ctx, tenantID, uid, id) } // Remove favorite // -// @Router /v1/contents/:id/favorite [delete] +// @Router /t/:tenantCode/v1/contents/:id/favorite [delete] // @Summary Remove favorite // @Tags Content // @Param id path int64 true "Content ID" // @Success 200 {string} string "Unfavorited" // @Bind id path func (c *Content) RemoveFavorite(ctx fiber.Ctx, id int64) error { + tenantID := getTenantID(ctx) uid := getUserID(ctx) - return services.Content.RemoveFavorite(ctx, uid, id) + return services.Content.RemoveFavorite(ctx, tenantID, uid, id) } // List curated topics // -// @Router /v1/topics [get] +// @Router /t/:tenantCode/v1/topics [get] // @Summary List topics // @Description List curated topics // @Tags Content @@ -165,14 +179,6 @@ func (c *Content) RemoveFavorite(ctx fiber.Ctx, id int64) error { // @Produce json // @Success 200 {array} dto.Topic func (c *Content) ListTopics(ctx fiber.Ctx) ([]dto.Topic, error) { - return services.Content.ListTopics(ctx) -} - -func getUserID(ctx fiber.Ctx) int64 { - if u := ctx.Locals(consts.CtxKeyUser); u != nil { - if user, ok := u.(*models.User); ok { - return user.ID - } - } - return 0 + tenantID := getTenantID(ctx) + return services.Content.ListTopics(ctx, tenantID) } diff --git a/backend/app/http/v1/creator.go b/backend/app/http/v1/creator.go index 3303f21..f4612e4 100644 --- a/backend/app/http/v1/creator.go +++ b/backend/app/http/v1/creator.go @@ -14,7 +14,7 @@ type Creator struct{} // Apply to become a creator // -// @Router /v1/creator/apply [post] +// @Router /t/:tenantCode/v1/creator/apply [post] // @Summary Apply creator // @Description Apply to become a creator // @Tags CreatorCenter @@ -25,12 +25,13 @@ type Creator struct{} // @Bind user local key(__ctx_user) // @Bind form body func (c *Creator) Apply(ctx fiber.Ctx, user *models.User, form *dto.ApplyForm) error { - return services.Creator.Apply(ctx, user.ID, form) + tenantID := getTenantID(ctx) + return services.Creator.Apply(ctx, tenantID, user.ID, form) } // Get creator dashboard stats // -// @Router /v1/creator/dashboard [get] +// @Router /t/:tenantCode/v1/creator/dashboard [get] // @Summary Dashboard stats // @Description Get creator dashboard stats // @Tags CreatorCenter @@ -39,12 +40,13 @@ func (c *Creator) Apply(ctx fiber.Ctx, user *models.User, form *dto.ApplyForm) e // @Success 200 {object} dto.DashboardStats // @Bind user local key(__ctx_user) func (c *Creator) Dashboard(ctx fiber.Ctx, user *models.User) (*dto.DashboardStats, error) { - return services.Creator.Dashboard(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Creator.Dashboard(ctx, tenantID, user.ID) } // Get content details for edit // -// @Router /v1/creator/contents/:id [get] +// @Router /t/:tenantCode/v1/creator/contents/:id [get] // @Summary Get content // @Description Get content details for edit // @Tags CreatorCenter @@ -55,12 +57,13 @@ func (c *Creator) Dashboard(ctx fiber.Ctx, user *models.User) (*dto.DashboardSta // @Bind user local key(__ctx_user) // @Bind id path func (c *Creator) GetContent(ctx fiber.Ctx, user *models.User, id int64) (*dto.ContentEditDTO, error) { - return services.Creator.GetContent(ctx, user.ID, id) + tenantID := getTenantID(ctx) + return services.Creator.GetContent(ctx, tenantID, user.ID, id) } // List creator contents // -// @Router /v1/creator/contents [get] +// @Router /t/:tenantCode/v1/creator/contents [get] // @Summary List contents // @Description List creator contents // @Tags CreatorCenter @@ -77,12 +80,13 @@ func (c *Creator) ListContents( user *models.User, filter *dto.CreatorContentListFilter, ) (*requests.Pager, error) { - return services.Creator.ListContents(ctx, user.ID, filter) + tenantID := getTenantID(ctx) + return services.Creator.ListContents(ctx, tenantID, user.ID, filter) } // Create/Publish content // -// @Router /v1/creator/contents [post] +// @Router /t/:tenantCode/v1/creator/contents [post] // @Summary Create content // @Description Create/Publish content // @Tags CreatorCenter @@ -93,12 +97,13 @@ func (c *Creator) ListContents( // @Bind user local key(__ctx_user) // @Bind form body func (c *Creator) CreateContent(ctx fiber.Ctx, user *models.User, form *dto.ContentCreateForm) error { - return services.Creator.CreateContent(ctx, user.ID, form) + tenantID := getTenantID(ctx) + return services.Creator.CreateContent(ctx, tenantID, user.ID, form) } // Update content // -// @Router /v1/creator/contents/:id [put] +// @Router /t/:tenantCode/v1/creator/contents/:id [put] // @Summary Update content // @Description Update content // @Tags CreatorCenter @@ -111,12 +116,13 @@ func (c *Creator) CreateContent(ctx fiber.Ctx, user *models.User, form *dto.Cont // @Bind id path // @Bind form body func (c *Creator) UpdateContent(ctx fiber.Ctx, user *models.User, id int64, form *dto.ContentUpdateForm) error { - return services.Creator.UpdateContent(ctx, user.ID, id, form) + tenantID := getTenantID(ctx) + return services.Creator.UpdateContent(ctx, tenantID, user.ID, id, form) } // Delete content // -// @Router /v1/creator/contents/:id [delete] +// @Router /t/:tenantCode/v1/creator/contents/:id [delete] // @Summary Delete content // @Description Delete content // @Tags CreatorCenter @@ -127,12 +133,13 @@ func (c *Creator) UpdateContent(ctx fiber.Ctx, user *models.User, id int64, form // @Bind user local key(__ctx_user) // @Bind id path func (c *Creator) DeleteContent(ctx fiber.Ctx, user *models.User, id int64) error { - return services.Creator.DeleteContent(ctx, user.ID, id) + tenantID := getTenantID(ctx) + return services.Creator.DeleteContent(ctx, tenantID, user.ID, id) } // List sales orders // -// @Router /v1/creator/orders [get] +// @Router /t/:tenantCode/v1/creator/orders [get] // @Summary List sales orders // @Description List sales orders // @Tags CreatorCenter @@ -148,12 +155,13 @@ func (c *Creator) ListOrders( user *models.User, filter *dto.CreatorOrderListFilter, ) ([]dto.Order, error) { - return services.Creator.ListOrders(ctx, user.ID, filter) + tenantID := getTenantID(ctx) + return services.Creator.ListOrders(ctx, tenantID, user.ID, filter) } // Process refund // -// @Router /v1/creator/orders/:id/refund [post] +// @Router /t/:tenantCode/v1/creator/orders/:id/refund [post] // @Summary Process refund // @Description Process refund // @Tags CreatorCenter @@ -166,12 +174,13 @@ func (c *Creator) ListOrders( // @Bind id path // @Bind form body func (c *Creator) Refund(ctx fiber.Ctx, user *models.User, id int64, form *dto.RefundForm) error { - return services.Creator.ProcessRefund(ctx, user.ID, id, form) + tenantID := getTenantID(ctx) + return services.Creator.ProcessRefund(ctx, tenantID, user.ID, id, form) } // Get channel settings // -// @Router /v1/creator/settings [get] +// @Router /t/:tenantCode/v1/creator/settings [get] // @Summary Get settings // @Description Get channel settings // @Tags CreatorCenter @@ -180,12 +189,13 @@ func (c *Creator) Refund(ctx fiber.Ctx, user *models.User, id int64, form *dto.R // @Success 200 {object} dto.Settings // @Bind user local key(__ctx_user) func (c *Creator) GetSettings(ctx fiber.Ctx, user *models.User) (*dto.Settings, error) { - return services.Creator.GetSettings(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Creator.GetSettings(ctx, tenantID, user.ID) } // Update channel settings // -// @Router /v1/creator/settings [put] +// @Router /t/:tenantCode/v1/creator/settings [put] // @Summary Update settings // @Description Update channel settings // @Tags CreatorCenter @@ -196,12 +206,13 @@ func (c *Creator) GetSettings(ctx fiber.Ctx, user *models.User) (*dto.Settings, // @Bind user local key(__ctx_user) // @Bind form body func (c *Creator) UpdateSettings(ctx fiber.Ctx, user *models.User, form *dto.Settings) error { - return services.Creator.UpdateSettings(ctx, user.ID, form) + tenantID := getTenantID(ctx) + return services.Creator.UpdateSettings(ctx, tenantID, user.ID, form) } // List payout accounts // -// @Router /v1/creator/payout-accounts [get] +// @Router /t/:tenantCode/v1/creator/payout-accounts [get] // @Summary List payout accounts // @Description List payout accounts // @Tags CreatorCenter @@ -210,12 +221,13 @@ func (c *Creator) UpdateSettings(ctx fiber.Ctx, user *models.User, form *dto.Set // @Success 200 {array} dto.PayoutAccount // @Bind user local key(__ctx_user) func (c *Creator) ListPayoutAccounts(ctx fiber.Ctx, user *models.User) ([]dto.PayoutAccount, error) { - return services.Creator.ListPayoutAccounts(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Creator.ListPayoutAccounts(ctx, tenantID, user.ID) } // Add payout account // -// @Router /v1/creator/payout-accounts [post] +// @Router /t/:tenantCode/v1/creator/payout-accounts [post] // @Summary Add payout account // @Description Add payout account // @Tags CreatorCenter @@ -226,12 +238,13 @@ func (c *Creator) ListPayoutAccounts(ctx fiber.Ctx, user *models.User) ([]dto.Pa // @Bind user local key(__ctx_user) // @Bind form body func (c *Creator) AddPayoutAccount(ctx fiber.Ctx, user *models.User, form *dto.PayoutAccount) error { - return services.Creator.AddPayoutAccount(ctx, user.ID, form) + tenantID := getTenantID(ctx) + return services.Creator.AddPayoutAccount(ctx, tenantID, user.ID, form) } // Remove payout account // -// @Router /v1/creator/payout-accounts [delete] +// @Router /t/:tenantCode/v1/creator/payout-accounts [delete] // @Summary Remove payout account // @Description Remove payout account // @Tags CreatorCenter @@ -242,12 +255,13 @@ func (c *Creator) AddPayoutAccount(ctx fiber.Ctx, user *models.User, form *dto.P // @Bind user local key(__ctx_user) // @Bind id query func (c *Creator) RemovePayoutAccount(ctx fiber.Ctx, user *models.User, id int64) error { - return services.Creator.RemovePayoutAccount(ctx, user.ID, id) + tenantID := getTenantID(ctx) + return services.Creator.RemovePayoutAccount(ctx, tenantID, user.ID, id) } // Request withdrawal // -// @Router /v1/creator/withdraw [post] +// @Router /t/:tenantCode/v1/creator/withdraw [post] // @Summary Request withdrawal // @Description Request withdrawal // @Tags CreatorCenter @@ -258,5 +272,6 @@ func (c *Creator) RemovePayoutAccount(ctx fiber.Ctx, user *models.User, id int64 // @Bind user local key(__ctx_user) // @Bind form body func (c *Creator) Withdraw(ctx fiber.Ctx, user *models.User, form *dto.WithdrawForm) error { - return services.Creator.Withdraw(ctx, user.ID, form) + tenantID := getTenantID(ctx) + return services.Creator.Withdraw(ctx, tenantID, user.ID, form) } diff --git a/backend/app/http/v1/helpers.go b/backend/app/http/v1/helpers.go new file mode 100644 index 0000000..0930cf1 --- /dev/null +++ b/backend/app/http/v1/helpers.go @@ -0,0 +1,26 @@ +package v1 + +import ( + "quyun/v2/database/models" + "quyun/v2/pkg/consts" + + "github.com/gofiber/fiber/v3" +) + +func getUserID(ctx fiber.Ctx) int64 { + if u := ctx.Locals(consts.CtxKeyUser); u != nil { + if user, ok := u.(*models.User); ok { + return user.ID + } + } + return 0 +} + +func getTenantID(ctx fiber.Ctx) int64 { + if t := ctx.Locals(consts.CtxKeyTenant); t != nil { + if tenant, ok := t.(*models.Tenant); ok { + return tenant.ID + } + } + return 0 +} diff --git a/backend/app/http/v1/routes.gen.go b/backend/app/http/v1/routes.gen.go index 0a5aaac..11922ba 100644 --- a/backend/app/http/v1/routes.gen.go +++ b/backend/app/http/v1/routes.gen.go @@ -50,368 +50,368 @@ func (r *Routes) Name() string { // Each route is registered with its corresponding controller action and parameter bindings. func (r *Routes) Register(router fiber.Router) { // Register routes for controller: Common - r.log.Debugf("Registering route: Delete /v1/media-assets/:id -> common.DeleteMediaAsset") - router.Delete("/v1/media-assets/:id"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/media-assets/:id -> common.DeleteMediaAsset") + router.Delete("/t/:tenantCode/v1/media-assets/:id"[len(r.Path()):], Func2( r.common.DeleteMediaAsset, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Delete /v1/upload/:uploadId -> common.AbortUpload") - router.Delete("/v1/upload/:uploadId"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/upload/:uploadId -> common.AbortUpload") + router.Delete("/t/:tenantCode/v1/upload/:uploadId"[len(r.Path()):], Func2( r.common.AbortUpload, Local[*models.User]("__ctx_user"), PathParam[string]("uploadId"), )) - r.log.Debugf("Registering route: Get /v1/common/options -> common.GetOptions") - router.Get("/v1/common/options"[len(r.Path()):], DataFunc0( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/common/options -> common.GetOptions") + router.Get("/t/:tenantCode/v1/common/options"[len(r.Path()):], DataFunc0( r.common.GetOptions, )) - r.log.Debugf("Registering route: Get /v1/upload/check -> common.CheckHash") - router.Get("/v1/upload/check"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/upload/check -> common.CheckHash") + router.Get("/t/:tenantCode/v1/upload/check"[len(r.Path()):], DataFunc2( r.common.CheckHash, Local[*models.User]("__ctx_user"), QueryParam[string]("hash"), )) - r.log.Debugf("Registering route: Post /v1/upload -> common.Upload") - router.Post("/v1/upload"[len(r.Path()):], DataFunc3( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/upload -> common.Upload") + router.Post("/t/:tenantCode/v1/upload"[len(r.Path()):], DataFunc3( r.common.Upload, Local[*models.User]("__ctx_user"), File[multipart.FileHeader]("file"), Body[dto.UploadForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/upload/complete -> common.CompleteUpload") - router.Post("/v1/upload/complete"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/upload/complete -> common.CompleteUpload") + router.Post("/t/:tenantCode/v1/upload/complete"[len(r.Path()):], DataFunc2( r.common.CompleteUpload, Local[*models.User]("__ctx_user"), Body[dto.UploadCompleteForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/upload/init -> common.InitUpload") - router.Post("/v1/upload/init"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/upload/init -> common.InitUpload") + router.Post("/t/:tenantCode/v1/upload/init"[len(r.Path()):], DataFunc2( r.common.InitUpload, Local[*models.User]("__ctx_user"), Body[dto.UploadInitForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/upload/part -> common.UploadPart") - router.Post("/v1/upload/part"[len(r.Path()):], Func3( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/upload/part -> common.UploadPart") + router.Post("/t/:tenantCode/v1/upload/part"[len(r.Path()):], Func3( r.common.UploadPart, Local[*models.User]("__ctx_user"), File[multipart.FileHeader]("file"), Body[dto.UploadPartForm]("form"), )) // Register routes for controller: Content - r.log.Debugf("Registering route: Delete /v1/contents/:id/favorite -> content.RemoveFavorite") - router.Delete("/v1/contents/:id/favorite"[len(r.Path()):], Func1( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/contents/:id/favorite -> content.RemoveFavorite") + router.Delete("/t/:tenantCode/v1/contents/:id/favorite"[len(r.Path()):], Func1( r.content.RemoveFavorite, PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Delete /v1/contents/:id/like -> content.RemoveLike") - router.Delete("/v1/contents/:id/like"[len(r.Path()):], Func1( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/contents/:id/like -> content.RemoveLike") + router.Delete("/t/:tenantCode/v1/contents/:id/like"[len(r.Path()):], Func1( r.content.RemoveLike, PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Get /v1/contents -> content.List") - router.Get("/v1/contents"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/contents -> content.List") + router.Get("/t/:tenantCode/v1/contents"[len(r.Path()):], DataFunc1( r.content.List, Query[dto.ContentListFilter]("filter"), )) - r.log.Debugf("Registering route: Get /v1/contents/:id -> content.Get") - router.Get("/v1/contents/:id"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/contents/:id -> content.Get") + router.Get("/t/:tenantCode/v1/contents/:id"[len(r.Path()):], DataFunc1( r.content.Get, PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Get /v1/contents/:id/comments -> content.ListComments") - router.Get("/v1/contents/:id/comments"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/contents/:id/comments -> content.ListComments") + router.Get("/t/:tenantCode/v1/contents/:id/comments"[len(r.Path()):], DataFunc2( r.content.ListComments, PathParam[int64]("id"), QueryParam[int]("page"), )) - r.log.Debugf("Registering route: Get /v1/topics -> content.ListTopics") - router.Get("/v1/topics"[len(r.Path()):], DataFunc0( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/topics -> content.ListTopics") + router.Get("/t/:tenantCode/v1/topics"[len(r.Path()):], DataFunc0( r.content.ListTopics, )) - r.log.Debugf("Registering route: Post /v1/comments/:id/like -> content.LikeComment") - router.Post("/v1/comments/:id/like"[len(r.Path()):], Func1( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/comments/:id/like -> content.LikeComment") + router.Post("/t/:tenantCode/v1/comments/:id/like"[len(r.Path()):], Func1( r.content.LikeComment, PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Post /v1/contents/:id/comments -> content.CreateComment") - router.Post("/v1/contents/:id/comments"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/contents/:id/comments -> content.CreateComment") + router.Post("/t/:tenantCode/v1/contents/:id/comments"[len(r.Path()):], Func2( r.content.CreateComment, PathParam[int64]("id"), Body[dto.CommentCreateForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/contents/:id/favorite -> content.AddFavorite") - router.Post("/v1/contents/:id/favorite"[len(r.Path()):], Func1( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/contents/:id/favorite -> content.AddFavorite") + router.Post("/t/:tenantCode/v1/contents/:id/favorite"[len(r.Path()):], Func1( r.content.AddFavorite, PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Post /v1/contents/:id/like -> content.AddLike") - router.Post("/v1/contents/:id/like"[len(r.Path()):], Func1( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/contents/:id/like -> content.AddLike") + router.Post("/t/:tenantCode/v1/contents/:id/like"[len(r.Path()):], Func1( r.content.AddLike, PathParam[int64]("id"), )) // Register routes for controller: Creator - r.log.Debugf("Registering route: Delete /v1/creator/contents/:id -> creator.DeleteContent") - router.Delete("/v1/creator/contents/:id"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/creator/contents/:id -> creator.DeleteContent") + router.Delete("/t/:tenantCode/v1/creator/contents/:id"[len(r.Path()):], Func2( r.creator.DeleteContent, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Delete /v1/creator/payout-accounts -> creator.RemovePayoutAccount") - router.Delete("/v1/creator/payout-accounts"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/creator/payout-accounts -> creator.RemovePayoutAccount") + router.Delete("/t/:tenantCode/v1/creator/payout-accounts"[len(r.Path()):], Func2( r.creator.RemovePayoutAccount, Local[*models.User]("__ctx_user"), QueryParam[int64]("id"), )) - r.log.Debugf("Registering route: Get /v1/creator/contents -> creator.ListContents") - router.Get("/v1/creator/contents"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/creator/contents -> creator.ListContents") + router.Get("/t/:tenantCode/v1/creator/contents"[len(r.Path()):], DataFunc2( r.creator.ListContents, Local[*models.User]("__ctx_user"), Query[dto.CreatorContentListFilter]("filter"), )) - r.log.Debugf("Registering route: Get /v1/creator/contents/:id -> creator.GetContent") - router.Get("/v1/creator/contents/:id"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/creator/contents/:id -> creator.GetContent") + router.Get("/t/:tenantCode/v1/creator/contents/:id"[len(r.Path()):], DataFunc2( r.creator.GetContent, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Get /v1/creator/dashboard -> creator.Dashboard") - router.Get("/v1/creator/dashboard"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/creator/dashboard -> creator.Dashboard") + router.Get("/t/:tenantCode/v1/creator/dashboard"[len(r.Path()):], DataFunc1( r.creator.Dashboard, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Get /v1/creator/orders -> creator.ListOrders") - router.Get("/v1/creator/orders"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/creator/orders -> creator.ListOrders") + router.Get("/t/:tenantCode/v1/creator/orders"[len(r.Path()):], DataFunc2( r.creator.ListOrders, Local[*models.User]("__ctx_user"), Query[dto.CreatorOrderListFilter]("filter"), )) - r.log.Debugf("Registering route: Get /v1/creator/payout-accounts -> creator.ListPayoutAccounts") - router.Get("/v1/creator/payout-accounts"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/creator/payout-accounts -> creator.ListPayoutAccounts") + router.Get("/t/:tenantCode/v1/creator/payout-accounts"[len(r.Path()):], DataFunc1( r.creator.ListPayoutAccounts, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Get /v1/creator/settings -> creator.GetSettings") - router.Get("/v1/creator/settings"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/creator/settings -> creator.GetSettings") + router.Get("/t/:tenantCode/v1/creator/settings"[len(r.Path()):], DataFunc1( r.creator.GetSettings, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Post /v1/creator/apply -> creator.Apply") - router.Post("/v1/creator/apply"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/creator/apply -> creator.Apply") + router.Post("/t/:tenantCode/v1/creator/apply"[len(r.Path()):], Func2( r.creator.Apply, Local[*models.User]("__ctx_user"), Body[dto.ApplyForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/creator/contents -> creator.CreateContent") - router.Post("/v1/creator/contents"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/creator/contents -> creator.CreateContent") + router.Post("/t/:tenantCode/v1/creator/contents"[len(r.Path()):], Func2( r.creator.CreateContent, Local[*models.User]("__ctx_user"), Body[dto.ContentCreateForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/creator/orders/:id/refund -> creator.Refund") - router.Post("/v1/creator/orders/:id/refund"[len(r.Path()):], Func3( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/creator/orders/:id/refund -> creator.Refund") + router.Post("/t/:tenantCode/v1/creator/orders/:id/refund"[len(r.Path()):], Func3( r.creator.Refund, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), Body[dto.RefundForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/creator/payout-accounts -> creator.AddPayoutAccount") - router.Post("/v1/creator/payout-accounts"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/creator/payout-accounts -> creator.AddPayoutAccount") + router.Post("/t/:tenantCode/v1/creator/payout-accounts"[len(r.Path()):], Func2( r.creator.AddPayoutAccount, Local[*models.User]("__ctx_user"), Body[dto.PayoutAccount]("form"), )) - r.log.Debugf("Registering route: Post /v1/creator/withdraw -> creator.Withdraw") - router.Post("/v1/creator/withdraw"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/creator/withdraw -> creator.Withdraw") + router.Post("/t/:tenantCode/v1/creator/withdraw"[len(r.Path()):], Func2( r.creator.Withdraw, Local[*models.User]("__ctx_user"), Body[dto.WithdrawForm]("form"), )) - r.log.Debugf("Registering route: Put /v1/creator/contents/:id -> creator.UpdateContent") - router.Put("/v1/creator/contents/:id"[len(r.Path()):], Func3( + r.log.Debugf("Registering route: Put /t/:tenantCode/v1/creator/contents/:id -> creator.UpdateContent") + router.Put("/t/:tenantCode/v1/creator/contents/:id"[len(r.Path()):], Func3( r.creator.UpdateContent, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), Body[dto.ContentUpdateForm]("form"), )) - r.log.Debugf("Registering route: Put /v1/creator/settings -> creator.UpdateSettings") - router.Put("/v1/creator/settings"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Put /t/:tenantCode/v1/creator/settings -> creator.UpdateSettings") + router.Put("/t/:tenantCode/v1/creator/settings"[len(r.Path()):], Func2( r.creator.UpdateSettings, Local[*models.User]("__ctx_user"), Body[dto.Settings]("form"), )) // Register routes for controller: Storage - r.log.Debugf("Registering route: Get /v1/storage/* -> storage.Download") - router.Get("/v1/storage/*"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/storage/* -> storage.Download") + router.Get("/t/:tenantCode/v1/storage/*"[len(r.Path()):], Func2( r.storage.Download, QueryParam[string]("expires"), QueryParam[string]("sign"), )) - r.log.Debugf("Registering route: Put /v1/storage/* -> storage.Upload") - router.Put("/v1/storage/*"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Put /t/:tenantCode/v1/storage/* -> storage.Upload") + router.Put("/t/:tenantCode/v1/storage/*"[len(r.Path()):], DataFunc2( r.storage.Upload, QueryParam[string]("expires"), QueryParam[string]("sign"), )) // Register routes for controller: Tenant - r.log.Debugf("Registering route: Delete /v1/tenants/:id/follow -> tenant.Unfollow") - router.Delete("/v1/tenants/:id/follow"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/tenants/:id/follow -> tenant.Unfollow") + router.Delete("/t/:tenantCode/v1/tenants/:id/follow"[len(r.Path()):], Func2( r.tenant.Unfollow, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Get /v1/creators/:id/contents -> tenant.ListContents") - router.Get("/v1/creators/:id/contents"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/creators/:id/contents -> tenant.ListContents") + router.Get("/t/:tenantCode/v1/creators/:id/contents"[len(r.Path()):], DataFunc2( r.tenant.ListContents, PathParam[int64]("id"), Query[dto.ContentListFilter]("filter"), )) - r.log.Debugf("Registering route: Get /v1/tenants -> tenant.List") - router.Get("/v1/tenants"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/tenants -> tenant.List") + router.Get("/t/:tenantCode/v1/tenants"[len(r.Path()):], DataFunc1( r.tenant.List, Query[dto.TenantListFilter]("filter"), )) - r.log.Debugf("Registering route: Get /v1/tenants/:id -> tenant.Get") - router.Get("/v1/tenants/:id"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/tenants/:id -> tenant.Get") + router.Get("/t/:tenantCode/v1/tenants/:id"[len(r.Path()):], DataFunc2( r.tenant.Get, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Post /v1/tenants/:id/follow -> tenant.Follow") - router.Post("/v1/tenants/:id/follow"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/tenants/:id/follow -> tenant.Follow") + router.Post("/t/:tenantCode/v1/tenants/:id/follow"[len(r.Path()):], Func2( r.tenant.Follow, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) // Register routes for controller: Transaction - r.log.Debugf("Registering route: Get /v1/orders/:id/status -> transaction.Status") - router.Get("/v1/orders/:id/status"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/orders/:id/status -> transaction.Status") + router.Get("/t/:tenantCode/v1/orders/:id/status"[len(r.Path()):], DataFunc1( r.transaction.Status, PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Post /v1/orders -> transaction.Create") - router.Post("/v1/orders"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/orders -> transaction.Create") + router.Post("/t/:tenantCode/v1/orders"[len(r.Path()):], DataFunc2( r.transaction.Create, Local[*models.User]("__ctx_user"), Body[dto.OrderCreateForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/orders/:id/pay -> transaction.Pay") - router.Post("/v1/orders/:id/pay"[len(r.Path()):], DataFunc3( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/orders/:id/pay -> transaction.Pay") + router.Post("/t/:tenantCode/v1/orders/:id/pay"[len(r.Path()):], DataFunc3( r.transaction.Pay, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), Body[dto.OrderPayForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/webhook/payment/notify -> transaction.Webhook") - router.Post("/v1/webhook/payment/notify"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/webhook/payment/notify -> transaction.Webhook") + router.Post("/t/:tenantCode/v1/webhook/payment/notify"[len(r.Path()):], DataFunc1( r.transaction.Webhook, Body[WebhookForm]("form"), )) // Register routes for controller: User - r.log.Debugf("Registering route: Delete /v1/me/favorites/:contentId -> user.RemoveFavorite") - router.Delete("/v1/me/favorites/:contentId"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/me/favorites/:contentId -> user.RemoveFavorite") + router.Delete("/t/:tenantCode/v1/me/favorites/:contentId"[len(r.Path()):], Func2( r.user.RemoveFavorite, Local[*models.User]("__ctx_user"), PathParam[int64]("contentId"), )) - r.log.Debugf("Registering route: Delete /v1/me/likes/:contentId -> user.RemoveLike") - router.Delete("/v1/me/likes/:contentId"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Delete /t/:tenantCode/v1/me/likes/:contentId -> user.RemoveLike") + router.Delete("/t/:tenantCode/v1/me/likes/:contentId"[len(r.Path()):], Func2( r.user.RemoveLike, Local[*models.User]("__ctx_user"), PathParam[int64]("contentId"), )) - r.log.Debugf("Registering route: Get /v1/me -> user.Me") - router.Get("/v1/me"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me -> user.Me") + router.Get("/t/:tenantCode/v1/me"[len(r.Path()):], DataFunc1( r.user.Me, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Get /v1/me/coupons -> user.MyCoupons") - router.Get("/v1/me/coupons"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/coupons -> user.MyCoupons") + router.Get("/t/:tenantCode/v1/me/coupons"[len(r.Path()):], DataFunc2( r.user.MyCoupons, Local[*models.User]("__ctx_user"), QueryParam[string]("status"), )) - r.log.Debugf("Registering route: Get /v1/me/favorites -> user.Favorites") - router.Get("/v1/me/favorites"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/favorites -> user.Favorites") + router.Get("/t/:tenantCode/v1/me/favorites"[len(r.Path()):], DataFunc1( r.user.Favorites, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Get /v1/me/following -> user.Following") - router.Get("/v1/me/following"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/following -> user.Following") + router.Get("/t/:tenantCode/v1/me/following"[len(r.Path()):], DataFunc1( r.user.Following, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Get /v1/me/library -> user.Library") - router.Get("/v1/me/library"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/library -> user.Library") + router.Get("/t/:tenantCode/v1/me/library"[len(r.Path()):], DataFunc1( r.user.Library, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Get /v1/me/likes -> user.Likes") - router.Get("/v1/me/likes"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/likes -> user.Likes") + router.Get("/t/:tenantCode/v1/me/likes"[len(r.Path()):], DataFunc1( r.user.Likes, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Get /v1/me/notifications -> user.Notifications") - router.Get("/v1/me/notifications"[len(r.Path()):], DataFunc3( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/notifications -> user.Notifications") + router.Get("/t/:tenantCode/v1/me/notifications"[len(r.Path()):], DataFunc3( r.user.Notifications, Local[*models.User]("__ctx_user"), QueryParam[string]("type"), QueryParam[int]("page"), )) - r.log.Debugf("Registering route: Get /v1/me/orders -> user.ListOrders") - router.Get("/v1/me/orders"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/orders -> user.ListOrders") + router.Get("/t/:tenantCode/v1/me/orders"[len(r.Path()):], DataFunc2( r.user.ListOrders, Local[*models.User]("__ctx_user"), QueryParam[string]("status"), )) - r.log.Debugf("Registering route: Get /v1/me/orders/:id -> user.GetOrder") - router.Get("/v1/me/orders/:id"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/orders/:id -> user.GetOrder") + router.Get("/t/:tenantCode/v1/me/orders/:id"[len(r.Path()):], DataFunc2( r.user.GetOrder, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Get /v1/me/wallet -> user.Wallet") - router.Get("/v1/me/wallet"[len(r.Path()):], DataFunc1( + r.log.Debugf("Registering route: Get /t/:tenantCode/v1/me/wallet -> user.Wallet") + router.Get("/t/:tenantCode/v1/me/wallet"[len(r.Path()):], DataFunc1( r.user.Wallet, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Post /v1/me/favorites -> user.AddFavorite") - router.Post("/v1/me/favorites"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/me/favorites -> user.AddFavorite") + router.Post("/t/:tenantCode/v1/me/favorites"[len(r.Path()):], Func2( r.user.AddFavorite, Local[*models.User]("__ctx_user"), QueryParam[int64]("contentId"), )) - r.log.Debugf("Registering route: Post /v1/me/likes -> user.AddLike") - router.Post("/v1/me/likes"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/me/likes -> user.AddLike") + router.Post("/t/:tenantCode/v1/me/likes"[len(r.Path()):], Func2( r.user.AddLike, Local[*models.User]("__ctx_user"), QueryParam[int64]("contentId"), )) - r.log.Debugf("Registering route: Post /v1/me/notifications/:id/read -> user.MarkNotificationRead") - router.Post("/v1/me/notifications/:id/read"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/me/notifications/:id/read -> user.MarkNotificationRead") + router.Post("/t/:tenantCode/v1/me/notifications/:id/read"[len(r.Path()):], Func2( r.user.MarkNotificationRead, Local[*models.User]("__ctx_user"), PathParam[int64]("id"), )) - r.log.Debugf("Registering route: Post /v1/me/notifications/read-all -> user.MarkAllNotificationsRead") - router.Post("/v1/me/notifications/read-all"[len(r.Path()):], Func1( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/me/notifications/read-all -> user.MarkAllNotificationsRead") + router.Post("/t/:tenantCode/v1/me/notifications/read-all"[len(r.Path()):], Func1( r.user.MarkAllNotificationsRead, Local[*models.User]("__ctx_user"), )) - r.log.Debugf("Registering route: Post /v1/me/realname -> user.RealName") - router.Post("/v1/me/realname"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/me/realname -> user.RealName") + router.Post("/t/:tenantCode/v1/me/realname"[len(r.Path()):], Func2( r.user.RealName, Local[*models.User]("__ctx_user"), Body[dto.RealNameForm]("form"), )) - r.log.Debugf("Registering route: Post /v1/me/wallet/recharge -> user.Recharge") - router.Post("/v1/me/wallet/recharge"[len(r.Path()):], DataFunc2( + r.log.Debugf("Registering route: Post /t/:tenantCode/v1/me/wallet/recharge -> user.Recharge") + router.Post("/t/:tenantCode/v1/me/wallet/recharge"[len(r.Path()):], DataFunc2( r.user.Recharge, Local[*models.User]("__ctx_user"), Body[dto.RechargeForm]("form"), )) - r.log.Debugf("Registering route: Put /v1/me -> user.Update") - router.Put("/v1/me"[len(r.Path()):], Func2( + r.log.Debugf("Registering route: Put /t/:tenantCode/v1/me -> user.Update") + router.Put("/t/:tenantCode/v1/me"[len(r.Path()):], Func2( r.user.Update, Local[*models.User]("__ctx_user"), Body[dto.UserUpdate]("form"), diff --git a/backend/app/http/v1/routes.manual.go b/backend/app/http/v1/routes.manual.go index 195c778..7f3e7a3 100644 --- a/backend/app/http/v1/routes.manual.go +++ b/backend/app/http/v1/routes.manual.go @@ -1,11 +1,12 @@ package v1 func (r *Routes) Path() string { - return "/v1" + return "/t/:tenantCode/v1" } func (r *Routes) Middlewares() []any { return []any{ + r.middlewares.TenantResolver, r.middlewares.Auth, } } diff --git a/backend/app/http/v1/storage.go b/backend/app/http/v1/storage.go index ba7f070..646d340 100644 --- a/backend/app/http/v1/storage.go +++ b/backend/app/http/v1/storage.go @@ -17,7 +17,7 @@ type Storage struct { // Upload file // -// @Router /v1/storage/* [put] +// @Router /t/:tenantCode/v1/storage/* [put] // @Summary Upload file // @Tags Storage // @Accept octet-stream @@ -58,7 +58,7 @@ func (s *Storage) Upload(ctx fiber.Ctx, expires, sign string) (string, error) { // Download file // -// @Router /v1/storage/* [get] +// @Router /t/:tenantCode/v1/storage/* [get] // @Summary Download file // @Tags Storage // @Accept json diff --git a/backend/app/http/v1/tenant.go b/backend/app/http/v1/tenant.go index 095bf11..37deb71 100644 --- a/backend/app/http/v1/tenant.go +++ b/backend/app/http/v1/tenant.go @@ -1,6 +1,7 @@ package v1 import ( + "quyun/v2/app/errorx" "quyun/v2/app/http/v1/dto" "quyun/v2/app/requests" "quyun/v2/app/services" @@ -14,7 +15,7 @@ type Tenant struct{} // List creator contents // -// @Router /v1/creators/:id/contents [get] +// @Router /t/:tenantCode/v1/creators/:id/contents [get] // @Summary List creator contents // @Description List contents of a specific creator // @Tags TenantPublic @@ -27,16 +28,20 @@ type Tenant struct{} // @Bind id path // @Bind filter query func (t *Tenant) ListContents(ctx fiber.Ctx, id int64, filter *dto.ContentListFilter) (*requests.Pager, error) { + tenantID := getTenantID(ctx) + if tenantID > 0 && id != tenantID { + return nil, errorx.ErrForbidden.WithMsg("租户不匹配") + } if filter == nil { filter = &dto.ContentListFilter{} } - filter.TenantID = &id - return services.Content.List(ctx, filter) + filter.TenantID = &tenantID + return services.Content.List(ctx, tenantID, filter) } // List tenants (search) // -// @Router /v1/tenants [get] +// @Router /t/:tenantCode/v1/tenants [get] // @Summary List tenants // @Description Search tenants // @Tags TenantPublic @@ -48,12 +53,13 @@ func (t *Tenant) ListContents(ctx fiber.Ctx, id int64, filter *dto.ContentListFi // @Success 200 {object} requests.Pager // @Bind filter query func (t *Tenant) List(ctx fiber.Ctx, filter *dto.TenantListFilter) (*requests.Pager, error) { - return services.Tenant.List(ctx, filter) + tenantID := getTenantID(ctx) + return services.Tenant.List(ctx, tenantID, filter) } // Get tenant public profile // -// @Router /v1/tenants/:id [get] +// @Router /t/:tenantCode/v1/tenants/:id [get] // @Summary Get tenant profile // @Description Get tenant public profile // @Tags TenantPublic @@ -68,12 +74,16 @@ func (t *Tenant) Get(ctx fiber.Ctx, user *models.User, id int64) (*dto.TenantPro if user != nil { uid = user.ID } - return services.Tenant.GetPublicProfile(ctx, uid, id) + tenantID := getTenantID(ctx) + if tenantID > 0 && id != tenantID { + return nil, errorx.ErrForbidden.WithMsg("租户不匹配") + } + return services.Tenant.GetPublicProfile(ctx, tenantID, uid) } // Follow a tenant // -// @Router /v1/tenants/:id/follow [post] +// @Router /t/:tenantCode/v1/tenants/:id/follow [post] // @Summary Follow tenant // @Description Follow a tenant // @Tags TenantPublic @@ -84,12 +94,16 @@ func (t *Tenant) Get(ctx fiber.Ctx, user *models.User, id int64) (*dto.TenantPro // @Bind user local key(__ctx_user) // @Bind id path func (t *Tenant) Follow(ctx fiber.Ctx, user *models.User, id int64) error { - return services.Tenant.Follow(ctx, user.ID, id) + tenantID := getTenantID(ctx) + if tenantID > 0 && id != tenantID { + return errorx.ErrForbidden.WithMsg("租户不匹配") + } + return services.Tenant.Follow(ctx, tenantID, user.ID) } // Unfollow a tenant // -// @Router /v1/tenants/:id/follow [delete] +// @Router /t/:tenantCode/v1/tenants/:id/follow [delete] // @Summary Unfollow tenant // @Description Unfollow a tenant // @Tags TenantPublic @@ -100,5 +114,9 @@ func (t *Tenant) Follow(ctx fiber.Ctx, user *models.User, id int64) error { // @Bind user local key(__ctx_user) // @Bind id path func (t *Tenant) Unfollow(ctx fiber.Ctx, user *models.User, id int64) error { - return services.Tenant.Unfollow(ctx, user.ID, id) + tenantID := getTenantID(ctx) + if tenantID > 0 && id != tenantID { + return errorx.ErrForbidden.WithMsg("租户不匹配") + } + return services.Tenant.Unfollow(ctx, tenantID, user.ID) } diff --git a/backend/app/http/v1/transaction.go b/backend/app/http/v1/transaction.go index 1e7051b..457177a 100644 --- a/backend/app/http/v1/transaction.go +++ b/backend/app/http/v1/transaction.go @@ -13,7 +13,7 @@ type Transaction struct{} // Create Order // -// @Router /v1/orders [post] +// @Router /t/:tenantCode/v1/orders [post] // @Summary Create Order // @Description Create Order // @Tags Transaction @@ -28,12 +28,13 @@ func (t *Transaction) Create( user *models.User, form *dto.OrderCreateForm, ) (*dto.OrderCreateResponse, error) { - return services.Order.Create(ctx, user.ID, form) + tenantID := getTenantID(ctx) + return services.Order.Create(ctx, tenantID, user.ID, form) } // Pay for order // -// @Router /v1/orders/:id/pay [post] +// @Router /t/:tenantCode/v1/orders/:id/pay [post] // @Summary Pay for order // @Description Pay for order // @Tags Transaction @@ -51,12 +52,13 @@ func (t *Transaction) Pay( id int64, form *dto.OrderPayForm, ) (*dto.OrderPayResponse, error) { - return services.Order.Pay(ctx, user.ID, id, form) + tenantID := getTenantID(ctx) + return services.Order.Pay(ctx, tenantID, user.ID, id, form) } // Check order payment status // -// @Router /v1/orders/:id/status [get] +// @Router /t/:tenantCode/v1/orders/:id/status [get] // @Summary Check order status // @Description Check order payment status // @Tags Transaction @@ -66,7 +68,8 @@ func (t *Transaction) Pay( // @Success 200 {object} dto.OrderStatusResponse // @Bind id path func (t *Transaction) Status(ctx fiber.Ctx, id int64) (*dto.OrderStatusResponse, error) { - return services.Order.Status(ctx, id) + tenantID := getTenantID(ctx) + return services.Order.Status(ctx, tenantID, id) } type WebhookForm struct { @@ -76,7 +79,7 @@ type WebhookForm struct { // Payment Webhook // -// @Router /v1/webhook/payment/notify [post] +// @Router /t/:tenantCode/v1/webhook/payment/notify [post] // @Summary Payment Webhook // @Description Payment Webhook // @Tags Transaction @@ -86,7 +89,8 @@ type WebhookForm struct { // @Success 200 {string} string "success" // @Bind form body func (t *Transaction) Webhook(ctx fiber.Ctx, form *WebhookForm) (string, error) { - err := services.Order.ProcessExternalPayment(ctx, form.OrderID, form.ExternalID) + tenantID := getTenantID(ctx) + err := services.Order.ProcessExternalPayment(ctx, tenantID, form.OrderID, form.ExternalID) if err != nil { return "fail", err } diff --git a/backend/app/http/v1/user.go b/backend/app/http/v1/user.go index 66359ed..4e1370e 100644 --- a/backend/app/http/v1/user.go +++ b/backend/app/http/v1/user.go @@ -15,7 +15,7 @@ type User struct{} // Get current user profile // -// @Router /v1/me [get] +// @Router /t/:tenantCode/v1/me [get] // @Summary Get user profile // @Description Get current user profile // @Tags UserCenter @@ -30,7 +30,7 @@ func (u *User) Me(ctx fiber.Ctx, user *models.User) (*auth_dto.User, error) { // Update user profile // -// @Router /v1/me [put] +// @Router /t/:tenantCode/v1/me [put] // @Summary Update user profile // @Description Update user profile // @Tags UserCenter @@ -46,7 +46,7 @@ func (u *User) Update(ctx fiber.Ctx, user *models.User, form *dto.UserUpdate) er // Submit real-name authentication // -// @Router /v1/me/realname [post] +// @Router /t/:tenantCode/v1/me/realname [post] // @Summary Realname auth // @Description Submit real-name authentication // @Tags UserCenter @@ -62,7 +62,7 @@ func (u *User) RealName(ctx fiber.Ctx, user *models.User, form *dto.RealNameForm // Get wallet balance and transactions // -// @Router /v1/me/wallet [get] +// @Router /t/:tenantCode/v1/me/wallet [get] // @Summary Get wallet // @Description Get wallet balance and transactions // @Tags UserCenter @@ -71,12 +71,13 @@ func (u *User) RealName(ctx fiber.Ctx, user *models.User, form *dto.RealNameForm // @Success 200 {object} dto.WalletResponse // @Bind user local key(__ctx_user) func (u *User) Wallet(ctx fiber.Ctx, user *models.User) (*dto.WalletResponse, error) { - return services.Wallet.GetWallet(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Wallet.GetWallet(ctx, tenantID, user.ID) } // Recharge wallet // -// @Router /v1/me/wallet/recharge [post] +// @Router /t/:tenantCode/v1/me/wallet/recharge [post] // @Summary Recharge wallet // @Description Recharge wallet // @Tags UserCenter @@ -87,12 +88,13 @@ func (u *User) Wallet(ctx fiber.Ctx, user *models.User) (*dto.WalletResponse, er // @Bind user local key(__ctx_user) // @Bind form body func (u *User) Recharge(ctx fiber.Ctx, user *models.User, form *dto.RechargeForm) (*dto.RechargeResponse, error) { - return services.Wallet.Recharge(ctx, user.ID, form) + tenantID := getTenantID(ctx) + return services.Wallet.Recharge(ctx, tenantID, user.ID, form) } // List user orders // -// @Router /v1/me/orders [get] +// @Router /t/:tenantCode/v1/me/orders [get] // @Summary List orders // @Description List user orders // @Tags UserCenter @@ -103,12 +105,13 @@ func (u *User) Recharge(ctx fiber.Ctx, user *models.User, form *dto.RechargeForm // @Bind user local key(__ctx_user) // @Bind status query func (u *User) ListOrders(ctx fiber.Ctx, user *models.User, status string) ([]dto.Order, error) { - return services.Order.ListUserOrders(ctx, user.ID, status) + tenantID := getTenantID(ctx) + return services.Order.ListUserOrders(ctx, tenantID, user.ID, status) } // Get user order detail // -// @Router /v1/me/orders/:id [get] +// @Router /t/:tenantCode/v1/me/orders/:id [get] // @Summary Get order detail // @Description Get user order detail // @Tags UserCenter @@ -119,12 +122,13 @@ func (u *User) ListOrders(ctx fiber.Ctx, user *models.User, status string) ([]dt // @Bind user local key(__ctx_user) // @Bind id path func (u *User) GetOrder(ctx fiber.Ctx, user *models.User, id int64) (*dto.Order, error) { - return services.Order.GetUserOrder(ctx, user.ID, id) + tenantID := getTenantID(ctx) + return services.Order.GetUserOrder(ctx, tenantID, user.ID, id) } // Get purchased content // -// @Router /v1/me/library [get] +// @Router /t/:tenantCode/v1/me/library [get] // @Summary Get library // @Description Get purchased content // @Tags UserCenter @@ -133,12 +137,13 @@ func (u *User) GetOrder(ctx fiber.Ctx, user *models.User, id int64) (*dto.Order, // @Success 200 {array} dto.ContentItem // @Bind user local key(__ctx_user) func (u *User) Library(ctx fiber.Ctx, user *models.User) ([]dto.ContentItem, error) { - return services.Content.GetLibrary(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Content.GetLibrary(ctx, tenantID, user.ID) } // Get favorites // -// @Router /v1/me/favorites [get] +// @Router /t/:tenantCode/v1/me/favorites [get] // @Summary Get favorites // @Description Get favorites // @Tags UserCenter @@ -147,12 +152,13 @@ func (u *User) Library(ctx fiber.Ctx, user *models.User) ([]dto.ContentItem, err // @Success 200 {array} dto.ContentItem // @Bind user local key(__ctx_user) func (u *User) Favorites(ctx fiber.Ctx, user *models.User) ([]dto.ContentItem, error) { - return services.Content.GetFavorites(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Content.GetFavorites(ctx, tenantID, user.ID) } // Add to favorites // -// @Router /v1/me/favorites [post] +// @Router /t/:tenantCode/v1/me/favorites [post] // @Summary Add favorite // @Description Add to favorites // @Tags UserCenter @@ -163,12 +169,13 @@ func (u *User) Favorites(ctx fiber.Ctx, user *models.User) ([]dto.ContentItem, e // @Bind user local key(__ctx_user) // @Bind contentId query func (u *User) AddFavorite(ctx fiber.Ctx, user *models.User, contentId int64) error { - return services.Content.AddFavorite(ctx, user.ID, contentId) + tenantID := getTenantID(ctx) + return services.Content.AddFavorite(ctx, tenantID, user.ID, contentId) } // Remove from favorites // -// @Router /v1/me/favorites/:contentId [delete] +// @Router /t/:tenantCode/v1/me/favorites/:contentId [delete] // @Summary Remove favorite // @Description Remove from favorites // @Tags UserCenter @@ -179,12 +186,13 @@ func (u *User) AddFavorite(ctx fiber.Ctx, user *models.User, contentId int64) er // @Bind user local key(__ctx_user) // @Bind contentId path func (u *User) RemoveFavorite(ctx fiber.Ctx, user *models.User, contentId int64) error { - return services.Content.RemoveFavorite(ctx, user.ID, contentId) + tenantID := getTenantID(ctx) + return services.Content.RemoveFavorite(ctx, tenantID, user.ID, contentId) } // Get liked contents // -// @Router /v1/me/likes [get] +// @Router /t/:tenantCode/v1/me/likes [get] // @Summary Get likes // @Description Get liked contents // @Tags UserCenter @@ -193,12 +201,13 @@ func (u *User) RemoveFavorite(ctx fiber.Ctx, user *models.User, contentId int64) // @Success 200 {array} dto.ContentItem // @Bind user local key(__ctx_user) func (u *User) Likes(ctx fiber.Ctx, user *models.User) ([]dto.ContentItem, error) { - return services.Content.GetLikes(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Content.GetLikes(ctx, tenantID, user.ID) } // Like content // -// @Router /v1/me/likes [post] +// @Router /t/:tenantCode/v1/me/likes [post] // @Summary Like content // @Description Like content // @Tags UserCenter @@ -209,12 +218,13 @@ func (u *User) Likes(ctx fiber.Ctx, user *models.User) ([]dto.ContentItem, error // @Bind user local key(__ctx_user) // @Bind contentId query func (u *User) AddLike(ctx fiber.Ctx, user *models.User, contentId int64) error { - return services.Content.AddLike(ctx, user.ID, contentId) + tenantID := getTenantID(ctx) + return services.Content.AddLike(ctx, tenantID, user.ID, contentId) } // Unlike content // -// @Router /v1/me/likes/:contentId [delete] +// @Router /t/:tenantCode/v1/me/likes/:contentId [delete] // @Summary Unlike content // @Description Unlike content // @Tags UserCenter @@ -225,12 +235,13 @@ func (u *User) AddLike(ctx fiber.Ctx, user *models.User, contentId int64) error // @Bind user local key(__ctx_user) // @Bind contentId path func (u *User) RemoveLike(ctx fiber.Ctx, user *models.User, contentId int64) error { - return services.Content.RemoveLike(ctx, user.ID, contentId) + tenantID := getTenantID(ctx) + return services.Content.RemoveLike(ctx, tenantID, user.ID, contentId) } // Get following tenants // -// @Router /v1/me/following [get] +// @Router /t/:tenantCode/v1/me/following [get] // @Summary Get following // @Description Get following tenants // @Tags UserCenter @@ -239,12 +250,13 @@ func (u *User) RemoveLike(ctx fiber.Ctx, user *models.User, contentId int64) err // @Success 200 {array} dto.TenantProfile // @Bind user local key(__ctx_user) func (u *User) Following(ctx fiber.Ctx, user *models.User) ([]dto.TenantProfile, error) { - return services.Tenant.ListFollowed(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Tenant.ListFollowed(ctx, tenantID, user.ID) } // Get notifications // -// @Router /v1/me/notifications [get] +// @Router /t/:tenantCode/v1/me/notifications [get] // @Summary Get notifications // @Description Get notifications // @Tags UserCenter @@ -262,7 +274,7 @@ func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, p // Mark notification as read // -// @Router /v1/me/notifications/:id/read [post] +// @Router /t/:tenantCode/v1/me/notifications/:id/read [post] // @Summary Mark as read // @Tags UserCenter // @Accept json @@ -277,7 +289,7 @@ func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64) // Mark all notifications as read // -// @Router /v1/me/notifications/read-all [post] +// @Router /t/:tenantCode/v1/me/notifications/read-all [post] // @Summary Mark all as read // @Tags UserCenter // @Accept json @@ -290,7 +302,7 @@ func (u *User) MarkAllNotificationsRead(ctx fiber.Ctx, user *models.User) error // List my coupons // -// @Router /v1/me/coupons [get] +// @Router /t/:tenantCode/v1/me/coupons [get] // @Summary List coupons // @Description List my coupons // @Tags UserCenter diff --git a/backend/app/middlewares/middlewares.go b/backend/app/middlewares/middlewares.go index b206c27..854e88b 100644 --- a/backend/app/middlewares/middlewares.go +++ b/backend/app/middlewares/middlewares.go @@ -1,16 +1,19 @@ package middlewares import ( + "errors" "strings" "quyun/v2/app/errorx" "quyun/v2/app/services" + "quyun/v2/database/models" "quyun/v2/pkg/consts" "quyun/v2/providers/jwt" "github.com/gofiber/fiber/v3" log "github.com/sirupsen/logrus" "go.ipao.vip/gen/types" + "gorm.io/gorm" ) // Middlewares provides reusable Fiber middlewares shared across modules. @@ -49,13 +52,18 @@ func (m *Middlewares) Auth(ctx fiber.Ctx) error { } // Set Context - ctx.Locals("__ctx_user", user) - if claims.TenantID > 0 { - tenant, err := services.Tenant.GetModelByID(ctx, claims.TenantID) + ctx.Locals(consts.CtxKeyUser, user) + + if tenant := ctx.Locals(consts.CtxKeyTenant); tenant != nil { + if model, ok := tenant.(*models.Tenant); ok && claims.TenantID > 0 && model.ID != claims.TenantID { + return errorx.ErrForbidden.WithMsg("租户不匹配") + } + } else if claims.TenantID > 0 { + tenantModel, err := services.Tenant.GetModelByID(ctx, claims.TenantID) if err != nil { return errorx.ErrUnauthorized.WithCause(err).WithMsg("TenantNotFound") } - ctx.Locals("__ctx_tenant", tenant) + ctx.Locals(consts.CtxKeyTenant, tenantModel) } return ctx.Next() @@ -84,7 +92,26 @@ func (m *Middlewares) SuperAuth(ctx fiber.Ctx) error { return errorx.ErrForbidden.WithMsg("无权限访问") } - ctx.Locals("__ctx_user", user) + ctx.Locals(consts.CtxKeyUser, user) + return ctx.Next() +} + +func (m *Middlewares) TenantResolver(ctx fiber.Ctx) error { + tenantCode := strings.TrimSpace(ctx.Params("tenantCode")) + if tenantCode == "" { + return errorx.ErrMissingParameter.WithMsg("缺少租户编码") + } + + tbl, q := models.TenantQuery.QueryContext(ctx) + tenant, err := q.Where(tbl.Code.Eq(tenantCode)).First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errorx.ErrRecordNotFound.WithMsg("租户不存在") + } + return errorx.ErrDatabaseError.WithCause(err) + } + + ctx.Locals(consts.CtxKeyTenant, tenant) return ctx.Next() } @@ -98,7 +125,7 @@ func hasRole(roles types.Array[consts.Role], role consts.Role) bool { } func isPublicRoute(ctx fiber.Ctx) bool { - path := ctx.Path() + path := normalizeTenantPath(ctx.Path()) method := ctx.Method() if method == fiber.MethodGet { @@ -127,6 +154,13 @@ func isPublicRoute(ctx fiber.Ctx) bool { return true } + if method == fiber.MethodPost { + switch path { + case "/v1/auth/otp", "/v1/auth/login": + return true + } + } + if method == fiber.MethodPut && strings.HasPrefix(path, "/v1/storage/") { return true } @@ -144,3 +178,19 @@ func isSuperPublicRoute(ctx fiber.Ctx) bool { return false } + +func normalizeTenantPath(path string) string { + if !strings.HasPrefix(path, "/t/") { + return path + } + rest := strings.TrimPrefix(path, "/t/") + slash := strings.Index(rest, "/") + if slash == -1 { + return path + } + rest = rest[slash:] + if strings.HasPrefix(rest, "/v1") { + return rest + } + return path +} diff --git a/backend/app/services/common.go b/backend/app/services/common.go index 87a31ac..7f0c567 100644 --- a/backend/app/services/common.go +++ b/backend/app/services/common.go @@ -27,6 +27,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgconn" "go.ipao.vip/gen/types" + "gorm.io/gorm" ) // @provider @@ -49,7 +50,7 @@ func (s *common) Options(ctx context.Context) (*common_dto.OptionsResponse, erro }, nil } -func (s *common) CheckHash(ctx context.Context, userID int64, hash string) (*common_dto.UploadResult, error) { +func (s *common) CheckHash(ctx context.Context, tenantID, userID int64, hash string) (*common_dto.UploadResult, error) { existing, err := models.MediaAssetQuery.WithContext(ctx).Where(models.MediaAssetQuery.Hash.Eq(hash)).First() if err != nil { return nil, nil // Not found, proceed to upload @@ -58,18 +59,25 @@ func (s *common) CheckHash(ctx context.Context, userID int64, hash string) (*com // Found existing file (Global deduplication hit) // Check if user already has it (Logic deduplication hit) - myExisting, err := models.MediaAssetQuery.WithContext(ctx). - Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)). - First() + myQuery := models.MediaAssetQuery.WithContext(ctx). + Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)) + if tenantID > 0 { + myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID)) + } + myExisting, err := myQuery.First() if err == nil { return s.composeUploadResult(myExisting), nil } // Create new record for this user reusing existing ObjectKey - t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First() - var tid int64 = 0 - if err == nil { - tid = t.ID + // 优先使用路径租户,避免跨租户写入。 + tenant, err := s.resolveTenant(ctx, tenantID, userID) + if err != nil { + return nil, err + } + var tid int64 + if tenant != nil { + tid = tenant.ID } asset := &models.MediaAsset{ @@ -107,13 +115,47 @@ func (s *common) buildObjectKey(tenant *models.Tenant, hash, filename string) st return path.Join("quyun", tenantUUID, hash+ext) } -func (s *common) InitUpload(ctx context.Context, userID int64, form *common_dto.UploadInitForm) (*common_dto.UploadInitResponse, error) { +func (s *common) resolveTenant(ctx context.Context, tenantID, userID int64) (*models.Tenant, error) { + if tenantID > 0 { + tbl, q := models.TenantQuery.QueryContext(ctx) + tenant, err := q.Where(tbl.ID.Eq(tenantID)).First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errorx.ErrRecordNotFound.WithMsg("租户不存在") + } + return nil, errorx.ErrDatabaseError.WithCause(err) + } + return tenant, nil + } + if userID == 0 { + return nil, nil + } + tbl, q := models.TenantQuery.QueryContext(ctx) + tenant, err := q.Where(tbl.UserID.Eq(userID)).First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, errorx.ErrDatabaseError.WithCause(err) + } + return tenant, nil +} + +func (s *common) uploadTempDir(localPath string, tenantID int64, uploadID string) string { + tenantKey := "public" + if tenantID > 0 { + tenantKey = strconv.FormatInt(tenantID, 10) + } + return filepath.Join(localPath, "temp", tenantKey, uploadID) +} + +func (s *common) InitUpload(ctx context.Context, tenantID, userID int64, form *common_dto.UploadInitForm) (*common_dto.UploadInitResponse, error) { uploadID := uuid.NewString() localPath := s.storage.Config.LocalPath if localPath == "" { localPath = "./storage" } - tempDir := filepath.Join(localPath, "temp", uploadID) + tempDir := s.uploadTempDir(localPath, tenantID, uploadID) if err := os.MkdirAll(tempDir, 0o755); err != nil { return nil, errorx.ErrInternalError.WithCause(err) } @@ -134,12 +176,12 @@ func (s *common) InitUpload(ctx context.Context, userID int64, form *common_dto. }, nil } -func (s *common) UploadPart(ctx context.Context, userID int64, file *multipart.FileHeader, form *common_dto.UploadPartForm) error { +func (s *common) UploadPart(ctx context.Context, tenantID, userID int64, file *multipart.FileHeader, form *common_dto.UploadPartForm) error { localPath := s.storage.Config.LocalPath if localPath == "" { localPath = "./storage" } - partPath := filepath.Join(localPath, "temp", form.UploadID, strconv.Itoa(form.PartNumber)) + partPath := filepath.Join(s.uploadTempDir(localPath, tenantID, form.UploadID), strconv.Itoa(form.PartNumber)) src, err := file.Open() if err != nil { @@ -159,12 +201,12 @@ func (s *common) UploadPart(ctx context.Context, userID int64, file *multipart.F return nil } -func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_dto.UploadCompleteForm) (*common_dto.UploadResult, error) { +func (s *common) CompleteUpload(ctx context.Context, tenantID, userID int64, form *common_dto.UploadCompleteForm) (*common_dto.UploadResult, error) { localPath := s.storage.Config.LocalPath if localPath == "" { localPath = "./storage" } - tempDir := filepath.Join(localPath, "temp", form.UploadID) + tempDir := s.uploadTempDir(localPath, tenantID, form.UploadID) // Read Meta var meta UploadMeta @@ -220,21 +262,27 @@ func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_ dst.Close() // Ensure flush before potential removal // Deduplication Logic (Similar to Upload) - t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First() - var tid int64 = 0 - if err == nil { - tid = t.ID + tenant, err := s.resolveTenant(ctx, tenantID, userID) + if err != nil { + return nil, err + } + var tid int64 + if tenant != nil { + tid = tenant.ID } - objectKey := s.buildObjectKey(t, hash, meta.Filename) + objectKey := s.buildObjectKey(tenant, hash, meta.Filename) existing, err := models.MediaAssetQuery.WithContext(ctx).Where(models.MediaAssetQuery.Hash.Eq(hash)).First() var asset *models.MediaAsset if err == nil { os.Remove(mergedPath) // Delete duplicate - myExisting, err := models.MediaAssetQuery.WithContext(ctx). - Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)). - First() + myQuery := models.MediaAssetQuery.WithContext(ctx). + Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)) + if tenantID > 0 { + myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID)) + } + myExisting, err := myQuery.First() if err == nil { os.RemoveAll(tempDir) return s.composeUploadResult(myExisting), nil @@ -282,10 +330,13 @@ func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_ return s.composeUploadResult(asset), nil } -func (s *common) DeleteMediaAsset(ctx context.Context, userID, id int64) error { - asset, err := models.MediaAssetQuery.WithContext(ctx). - Where(models.MediaAssetQuery.ID.Eq(id), models.MediaAssetQuery.UserID.Eq(userID)). - First() +func (s *common) DeleteMediaAsset(ctx context.Context, tenantID, userID, id int64) error { + query := models.MediaAssetQuery.WithContext(ctx). + Where(models.MediaAssetQuery.ID.Eq(id), models.MediaAssetQuery.UserID.Eq(userID)) + if tenantID > 0 { + query = query.Where(models.MediaAssetQuery.TenantID.Eq(tenantID)) + } + asset, err := query.First() if err != nil { return errorx.ErrRecordNotFound } @@ -308,17 +359,18 @@ func (s *common) DeleteMediaAsset(ctx context.Context, userID, id int64) error { return nil } -func (s *common) AbortUpload(ctx context.Context, userID int64, uploadId string) error { +func (s *common) AbortUpload(ctx context.Context, tenantID, userID int64, uploadId string) error { localPath := s.storage.Config.LocalPath if localPath == "" { localPath = "./storage" } - tempDir := filepath.Join(localPath, "temp", uploadId) + tempDir := s.uploadTempDir(localPath, tenantID, uploadId) return os.RemoveAll(tempDir) } func (s *common) Upload( ctx context.Context, + tenantID int64, userID int64, file *multipart.FileHeader, typeArg string, @@ -357,13 +409,16 @@ func (s *common) Upload( hash := hex.EncodeToString(hasher.Sum(nil)) - t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First() - var tid int64 = 0 - if err == nil { - tid = t.ID + tenant, err := s.resolveTenant(ctx, tenantID, userID) + if err != nil { + return nil, err + } + var tid int64 + if tenant != nil { + tid = tenant.ID } - objectKey := s.buildObjectKey(t, hash, file.Filename) + objectKey := s.buildObjectKey(tenant, hash, file.Filename) var asset *models.MediaAsset // Deduplication Check @@ -374,9 +429,12 @@ func (s *common) Upload( os.RemoveAll(tmpDir) // Check if user already has it (Logic Deduplication) - myExisting, err := models.MediaAssetQuery.WithContext(ctx). - Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)). - First() + myQuery := models.MediaAssetQuery.WithContext(ctx). + Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)) + if tenantID > 0 { + myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID)) + } + myExisting, err := myQuery.First() if err == nil { return s.composeUploadResult(myExisting), nil } diff --git a/backend/app/services/content.go b/backend/app/services/content.go index b7c5314..faa14e3 100644 --- a/backend/app/services/content.go +++ b/backend/app/services/content.go @@ -18,11 +18,14 @@ import ( // @provider type content struct{} -func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilter) (*requests.Pager, error) { +func (s *content) List(ctx context.Context, tenantID int64, filter *content_dto.ContentListFilter) (*requests.Pager, error) { tbl, q := models.ContentQuery.QueryContext(ctx) // Filters q = q.Where(tbl.Status.Eq(consts.ContentStatusPublished)) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } if filter.Keyword != nil && *filter.Keyword != "" { keyword := "%" + *filter.Keyword + "%" q = q.Where(tbl.Title.Like(keyword)).Or(tbl.Description.Like(keyword)) @@ -31,6 +34,9 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte q = q.Where(tbl.Genre.Eq(*filter.Genre)) } if filter.TenantID != nil && *filter.TenantID > 0 { + if tenantID > 0 && *filter.TenantID != tenantID { + return nil, errorx.ErrForbidden.WithMsg("租户不匹配") + } q = q.Where(tbl.TenantID.Eq(*filter.TenantID)) } if filter.IsPinned != nil { @@ -128,16 +134,22 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte }, nil } -func (s *content) Get(ctx context.Context, userID, id int64) (*content_dto.ContentDetail, error) { +func (s *content) Get(ctx context.Context, tenantID, userID, id int64) (*content_dto.ContentDetail, error) { // Increment Views - _, _ = models.ContentQuery.WithContext(ctx). - Where(models.ContentQuery.ID.Eq(id)). - UpdateSimple(models.ContentQuery.Views.Add(1)) + update := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id)) + if tenantID > 0 { + update = update.Where(models.ContentQuery.TenantID.Eq(tenantID)) + } + _, _ = update.UpdateSimple(models.ContentQuery.Views.Add(1)) _, q := models.ContentQuery.QueryContext(ctx) var item models.Content - err := q.UnderlyingDB(). + db := q.UnderlyingDB() + if tenantID > 0 { + db = db.Where("tenant_id = ?", tenantID) + } + err := db. Preload("Author"). Preload("ContentAssets", func(db *gorm.DB) *gorm.DB { return db.Order("sort ASC") @@ -232,10 +244,25 @@ func (s *content) Get(ctx context.Context, userID, id int64) (*content_dto.Conte return detail, nil } -func (s *content) ListComments(ctx context.Context, userID, id int64, page int) (*requests.Pager, error) { +func (s *content) ListComments(ctx context.Context, tenantID, userID, id int64, page int) (*requests.Pager, error) { + if tenantID > 0 { + _, err := models.ContentQuery.WithContext(ctx). + Where(models.ContentQuery.ID.Eq(id), models.ContentQuery.TenantID.Eq(tenantID)). + First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errorx.ErrRecordNotFound + } + return nil, errorx.ErrDatabaseError.WithCause(err) + } + } + tbl, q := models.CommentQuery.QueryContext(ctx) q = q.Where(tbl.ContentID.Eq(id)).Preload(tbl.User) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } q = q.Order(tbl.CreatedAt.Desc()) p := requests.Pagination{Page: int64(page), Limit: 10} @@ -291,6 +318,7 @@ func (s *content) ListComments(ctx context.Context, userID, id int64, page int) func (s *content) CreateComment( ctx context.Context, + tenantID int64, userID int64, id int64, form *content_dto.CommentCreateForm, @@ -300,7 +328,11 @@ func (s *content) CreateComment( } uid := userID - c, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id)).First() + query := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id)) + if tenantID > 0 { + query = query.Where(models.ContentQuery.TenantID.Eq(tenantID)) + } + c, err := query.First() if err != nil { return errorx.ErrRecordNotFound } @@ -319,14 +351,18 @@ func (s *content) CreateComment( return nil } -func (s *content) LikeComment(ctx context.Context, userID, id int64) error { +func (s *content) LikeComment(ctx context.Context, tenantID, userID, id int64) error { if userID == 0 { return errorx.ErrUnauthorized } uid := userID // Fetch comment for author - cm, err := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ID.Eq(id)).First() + query := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ID.Eq(id)) + if tenantID > 0 { + query = query.Where(models.CommentQuery.TenantID.Eq(tenantID)) + } + cm, err := query.First() if err != nil { return errorx.ErrRecordNotFound } @@ -357,14 +393,18 @@ func (s *content) LikeComment(ctx context.Context, userID, id int64) error { return nil } -func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) { +func (s *content) GetLibrary(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) { if userID == 0 { return nil, errorx.ErrUnauthorized } uid := userID tbl, q := models.ContentAccessQuery.QueryContext(ctx) - accessList, err := q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive)).Find() + q = q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive)) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } + accessList, err := q.Find() if err != nil { return nil, errorx.ErrDatabaseError.WithCause(err) } @@ -380,7 +420,11 @@ func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.Cont ctbl, cq := models.ContentQuery.QueryContext(ctx) var list []*models.Content - err = cq.Where(ctbl.ID.In(contentIDs...)). + cq = cq.Where(ctbl.ID.In(contentIDs...)) + if tenantID > 0 { + cq = cq.Where(ctbl.TenantID.Eq(tenantID)) + } + err = cq. UnderlyingDB(). Preload("Author"). Preload("ContentAssets.Asset"). @@ -398,36 +442,40 @@ func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.Cont return data, nil } -func (s *content) GetFavorites(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) { - return s.getInteractList(ctx, userID, "favorite") +func (s *content) GetFavorites(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) { + return s.getInteractList(ctx, tenantID, userID, "favorite") } -func (s *content) AddFavorite(ctx context.Context, userID, contentId int64) error { - return s.addInteract(ctx, userID, contentId, "favorite") +func (s *content) AddFavorite(ctx context.Context, tenantID, userID, contentId int64) error { + return s.addInteract(ctx, tenantID, userID, contentId, "favorite") } -func (s *content) RemoveFavorite(ctx context.Context, userID, contentId int64) error { - return s.removeInteract(ctx, userID, contentId, "favorite") +func (s *content) RemoveFavorite(ctx context.Context, tenantID, userID, contentId int64) error { + return s.removeInteract(ctx, tenantID, userID, contentId, "favorite") } -func (s *content) GetLikes(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) { - return s.getInteractList(ctx, userID, "like") +func (s *content) GetLikes(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) { + return s.getInteractList(ctx, tenantID, userID, "like") } -func (s *content) AddLike(ctx context.Context, userID, contentId int64) error { - return s.addInteract(ctx, userID, contentId, "like") +func (s *content) AddLike(ctx context.Context, tenantID, userID, contentId int64) error { + return s.addInteract(ctx, tenantID, userID, contentId, "like") } -func (s *content) RemoveLike(ctx context.Context, userID, contentId int64) error { - return s.removeInteract(ctx, userID, contentId, "like") +func (s *content) RemoveLike(ctx context.Context, tenantID, userID, contentId int64) error { + return s.removeInteract(ctx, tenantID, userID, contentId, "like") } -func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) { +func (s *content) ListTopics(ctx context.Context, tenantID int64) ([]content_dto.Topic, error) { var results []struct { Genre string Count int } - err := models.ContentQuery.WithContext(ctx).UnderlyingDB(). + db := models.ContentQuery.WithContext(ctx).UnderlyingDB() + if tenantID > 0 { + db = db.Where("tenant_id = ?", tenantID) + } + err := db. Model(&models.Content{}). Where("status = ?", consts.ContentStatusPublished). Select("genre, count(*) as count"). @@ -445,8 +493,12 @@ func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) { // Fetch latest content in this genre to get a cover var c models.Content - models.ContentQuery.WithContext(ctx). - Where(models.ContentQuery.Genre.Eq(r.Genre), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)). + query := models.ContentQuery.WithContext(ctx). + Where(models.ContentQuery.Genre.Eq(r.Genre), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)) + if tenantID > 0 { + query = query.Where(models.ContentQuery.TenantID.Eq(tenantID)) + } + query. Order(models.ContentQuery.PublishedAt.Desc()). UnderlyingDB(). Preload("ContentAssets"). @@ -554,15 +606,19 @@ func (s *content) toMediaURLs(assets []*models.ContentAsset) []content_dto.Media return urls } -func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ string) error { +func (s *content) addInteract(ctx context.Context, tenantID, userID, contentId int64, typ string) error { if userID == 0 { return errorx.ErrUnauthorized } uid := userID // Fetch content for author - c, err := models.ContentQuery.WithContext(ctx). - Where(models.ContentQuery.ID.Eq(contentId)). + query := models.ContentQuery.WithContext(ctx). + Where(models.ContentQuery.ID.Eq(contentId)) + if tenantID > 0 { + query = query.Where(models.ContentQuery.TenantID.Eq(tenantID)) + } + c, err := query. Select(models.ContentQuery.UserID, models.ContentQuery.Title). First() if err != nil { @@ -583,7 +639,11 @@ func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ } if typ == "like" { - _, err := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)).UpdateSimple(tx.Content.Likes.Add(1)) + contentQuery := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)) + if tenantID > 0 { + contentQuery = contentQuery.Where(tx.Content.TenantID.Eq(tenantID)) + } + _, err := contentQuery.UpdateSimple(tx.Content.Likes.Add(1)) return err } return nil @@ -605,7 +665,7 @@ func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ return nil } -func (s *content) removeInteract(ctx context.Context, userID, contentId int64, typ string) error { +func (s *content) removeInteract(ctx context.Context, tenantID, userID, contentId int64, typ string) error { if userID == 0 { return errorx.ErrUnauthorized } @@ -623,14 +683,18 @@ func (s *content) removeInteract(ctx context.Context, userID, contentId int64, t } if typ == "like" { - _, err := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)).UpdateSimple(tx.Content.Likes.Sub(1)) + contentQuery := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)) + if tenantID > 0 { + contentQuery = contentQuery.Where(tx.Content.TenantID.Eq(tenantID)) + } + _, err := contentQuery.UpdateSimple(tx.Content.Likes.Sub(1)) return err } return nil }) } -func (s *content) getInteractList(ctx context.Context, userID int64, typ string) ([]user_dto.ContentItem, error) { +func (s *content) getInteractList(ctx context.Context, tenantID, userID int64, typ string) ([]user_dto.ContentItem, error) { if userID == 0 { return nil, errorx.ErrUnauthorized } @@ -653,7 +717,11 @@ func (s *content) getInteractList(ctx context.Context, userID int64, typ string) ctbl, cq := models.ContentQuery.QueryContext(ctx) var list []*models.Content - err = cq.Where(ctbl.ID.In(contentIDs...)). + cq = cq.Where(ctbl.ID.In(contentIDs...)) + if tenantID > 0 { + cq = cq.Where(ctbl.TenantID.Eq(tenantID)) + } + err = cq. UnderlyingDB(). Preload("Author"). Preload("ContentAssets.Asset"). diff --git a/backend/app/services/content_test.go b/backend/app/services/content_test.go index db0f7bc..eccd99d 100644 --- a/backend/app/services/content_test.go +++ b/backend/app/services/content_test.go @@ -41,6 +41,7 @@ func Test_Content(t *testing.T) { func (s *ContentTestSuite) Test_List() { Convey("List", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser) // Create Author @@ -73,7 +74,7 @@ func (s *ContentTestSuite) Test_List() { Limit: 10, }, } - res, err := Content.List(ctx, filter) + res, err := Content.List(ctx, tenantID, filter) So(err, ShouldBeNil) So(res.Total, ShouldEqual, 1) items := res.Items.([]content_dto.ContentItem) @@ -86,6 +87,7 @@ func (s *ContentTestSuite) Test_List() { func (s *ContentTestSuite) Test_Get() { Convey("Get", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameMediaAsset, models.TableNameContentAsset, models.TableNameUser) // Author @@ -125,7 +127,7 @@ func (s *ContentTestSuite) Test_Get() { ctx = context.WithValue(ctx, consts.CtxKeyUser, author.ID) Convey("should get detail with assets", func() { - detail, err := Content.Get(ctx, author.ID, content.ID) + detail, err := Content.Get(ctx, tenantID, author.ID, content.ID) So(err, ShouldBeNil) So(detail.Title, ShouldEqual, "Detail Content") So(detail.AuthorName, ShouldEqual, "Author1") @@ -138,6 +140,7 @@ func (s *ContentTestSuite) Test_Get() { func (s *ContentTestSuite) Test_CreateComment() { Convey("CreateComment", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameComment, models.TableNameUser) // User & Content @@ -153,7 +156,7 @@ func (s *ContentTestSuite) Test_CreateComment() { form := &content_dto.CommentCreateForm{ Content: "Nice!", } - err := Content.CreateComment(ctx, u.ID, c.ID, form) + err := Content.CreateComment(ctx, tenantID, u.ID, c.ID, form) So(err, ShouldBeNil) count, _ := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ContentID.Eq(c.ID)).Count() @@ -165,6 +168,7 @@ func (s *ContentTestSuite) Test_CreateComment() { func (s *ContentTestSuite) Test_Library() { Convey("Library", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameContentAccess, models.TableNameUser, models.TableNameContentAsset, models.TableNameMediaAsset) // User @@ -192,7 +196,7 @@ func (s *ContentTestSuite) Test_Library() { }) Convey("should get library content with details", func() { - list, err := Content.GetLibrary(ctx, u.ID) + list, err := Content.GetLibrary(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(len(list), ShouldEqual, 1) So(list[0].Title, ShouldEqual, "Paid Content") @@ -206,6 +210,7 @@ func (s *ContentTestSuite) Test_Library() { func (s *ContentTestSuite) Test_Interact() { Convey("Interact", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUserContentAction, models.TableNameUser) // User & Content @@ -218,7 +223,7 @@ func (s *ContentTestSuite) Test_Interact() { Convey("Like flow", func() { // Add Like - err := Content.AddLike(ctx, u.ID, c.ID) + err := Content.AddLike(ctx, tenantID, u.ID, c.ID) So(err, ShouldBeNil) // Verify count @@ -226,13 +231,13 @@ func (s *ContentTestSuite) Test_Interact() { So(cReload.Likes, ShouldEqual, 1) // Get Likes - likes, err := Content.GetLikes(ctx, u.ID) + likes, err := Content.GetLikes(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(len(likes), ShouldEqual, 1) So(likes[0].ID, ShouldEqual, c.ID) // Remove Like - err = Content.RemoveLike(ctx, u.ID, c.ID) + err = Content.RemoveLike(ctx, tenantID, u.ID, c.ID) So(err, ShouldBeNil) // Verify count @@ -242,21 +247,21 @@ func (s *ContentTestSuite) Test_Interact() { Convey("Favorite flow", func() { // Add Favorite - err := Content.AddFavorite(ctx, u.ID, c.ID) + err := Content.AddFavorite(ctx, tenantID, u.ID, c.ID) So(err, ShouldBeNil) // Get Favorites - favs, err := Content.GetFavorites(ctx, u.ID) + favs, err := Content.GetFavorites(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(len(favs), ShouldEqual, 1) So(favs[0].ID, ShouldEqual, c.ID) // Remove Favorite - err = Content.RemoveFavorite(ctx, u.ID, c.ID) + err = Content.RemoveFavorite(ctx, tenantID, u.ID, c.ID) So(err, ShouldBeNil) // Get Favorites - favs, err = Content.GetFavorites(ctx, u.ID) + favs, err = Content.GetFavorites(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(len(favs), ShouldEqual, 0) }) @@ -266,6 +271,7 @@ func (s *ContentTestSuite) Test_Interact() { func (s *ContentTestSuite) Test_ListTopics() { Convey("ListTopics", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser) u := &models.User{Username: "user_t", Phone: "13900000005"} @@ -280,7 +286,7 @@ func (s *ContentTestSuite) Test_ListTopics() { ) Convey("should aggregate topics", func() { - topics, err := Content.ListTopics(ctx) + topics, err := Content.ListTopics(ctx, tenantID) So(err, ShouldBeNil) So(len(topics), ShouldBeGreaterThanOrEqualTo, 2) @@ -302,6 +308,7 @@ func (s *ContentTestSuite) Test_ListTopics() { func (s *ContentTestSuite) Test_PreviewLogic() { Convey("Preview Logic", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameContentAsset, models.TableNameContentAccess, models.TableNameUser, models.TableNameMediaAsset) author := &models.User{Username: "author_p", Phone: "13900000006"} @@ -324,7 +331,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() { models.UserQuery.WithContext(ctx).Create(guest) guestCtx := context.WithValue(ctx, consts.CtxKeyUser, guest.ID) - detail, err := Content.Get(guestCtx, 0, c.ID) + detail, err := Content.Get(guestCtx, tenantID, 0, c.ID) So(err, ShouldBeNil) So(len(detail.MediaUrls), ShouldEqual, 1) So(detail.MediaUrls[0].URL, ShouldContainSubstring, "preview.mp4") @@ -333,7 +340,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() { Convey("owner should see all", func() { ownerCtx := context.WithValue(ctx, consts.CtxKeyUser, author.ID) - detail, err := Content.Get(ownerCtx, author.ID, c.ID) + detail, err := Content.Get(ownerCtx, tenantID, author.ID, c.ID) So(err, ShouldBeNil) So(len(detail.MediaUrls), ShouldEqual, 2) So(detail.IsPurchased, ShouldBeTrue) @@ -348,7 +355,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() { UserID: buyer.ID, ContentID: c.ID, Status: consts.ContentAccessStatusActive, }) - detail, err := Content.Get(buyerCtx, buyer.ID, c.ID) + detail, err := Content.Get(buyerCtx, tenantID, buyer.ID, c.ID) So(err, ShouldBeNil) So(len(detail.MediaUrls), ShouldEqual, 2) So(detail.IsPurchased, ShouldBeTrue) @@ -359,6 +366,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() { func (s *ContentTestSuite) Test_ViewCounting() { Convey("ViewCounting", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser) author := &models.User{Username: "author_v", Phone: "13900000009"} @@ -368,7 +376,7 @@ func (s *ContentTestSuite) Test_ViewCounting() { models.ContentQuery.WithContext(ctx).Create(c) Convey("should increment views", func() { - _, err := Content.Get(ctx, 0, c.ID) + _, err := Content.Get(ctx, tenantID, 0, c.ID) So(err, ShouldBeNil) cReload, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(c.ID)).First() diff --git a/backend/app/services/coupon_test.go b/backend/app/services/coupon_test.go index 955d1a0..84121b1 100644 --- a/backend/app/services/coupon_test.go +++ b/backend/app/services/coupon_test.go @@ -39,6 +39,7 @@ func Test_Coupon(t *testing.T) { func (s *CouponTestSuite) Test_CouponFlow() { Convey("Coupon Flow", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate( ctx, s.DB, @@ -83,9 +84,10 @@ func (s *CouponTestSuite) Test_CouponFlow() { Convey("should apply in Order.Create", func() { // Setup Content - c := &models.Content{UserID: 99, Title: "Test", Status: consts.ContentStatusPublished} + c := &models.Content{TenantID: tenantID, UserID: 99, Title: "Test", Status: consts.ContentStatusPublished} models.ContentQuery.WithContext(ctx).Create(c) models.ContentPriceQuery.WithContext(ctx).Create(&models.ContentPrice{ + TenantID: tenantID, ContentID: c.ID, PriceAmount: 2000, // 20.00 CNY Currency: "CNY", @@ -96,7 +98,7 @@ func (s *CouponTestSuite) Test_CouponFlow() { UserCouponID: uc.ID, } // Simulate Auth context for Order service - res, err := Order.Create(ctx, user.ID, form) + res, err := Order.Create(ctx, tenantID, user.ID, form) So(err, ShouldBeNil) // Verify Order diff --git a/backend/app/services/creator.go b/backend/app/services/creator.go index ecca1f0..65d6fde 100644 --- a/backend/app/services/creator.go +++ b/backend/app/services/creator.go @@ -31,7 +31,7 @@ var genreMap = map[string]string{ "Qinqiang": "秦腔", } -func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.ApplyForm) error { +func (s *creator) Apply(ctx context.Context, tenantID, userID int64, form *creator_dto.ApplyForm) error { if userID == 0 { return errorx.ErrUnauthorized } @@ -72,8 +72,8 @@ func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.App return nil } -func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.DashboardStats, error) { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) Dashboard(ctx context.Context, tenantID, userID int64) (*creator_dto.DashboardStats, error) { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return nil, err } @@ -107,10 +107,11 @@ func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.Das func (s *creator) ListContents( ctx context.Context, + tenantID int64, userID int64, filter *creator_dto.CreatorContentListFilter, ) (*requests.Pager, error) { - tid, err := s.getTenantID(ctx, userID) + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return nil, err } @@ -248,8 +249,8 @@ func (s *creator) ListContents( }, nil } -func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator_dto.ContentCreateForm) error { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) CreateContent(ctx context.Context, tenantID, userID int64, form *creator_dto.ContentCreateForm) error { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -321,11 +322,12 @@ func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator func (s *creator) UpdateContent( ctx context.Context, + tenantID int64, userID int64, id int64, form *creator_dto.ContentUpdateForm, ) error { - tid, err := s.getTenantID(ctx, userID) + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -451,8 +453,8 @@ func (s *creator) UpdateContent( }) } -func (s *creator) DeleteContent(ctx context.Context, userID, id int64) error { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) DeleteContent(ctx context.Context, tenantID, userID, id int64) error { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -472,8 +474,8 @@ func (s *creator) DeleteContent(ctx context.Context, userID, id int64) error { return nil } -func (s *creator) GetContent(ctx context.Context, userID, id int64) (*creator_dto.ContentEditDTO, error) { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) GetContent(ctx context.Context, tenantID, userID, id int64) (*creator_dto.ContentEditDTO, error) { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return nil, err } @@ -548,10 +550,11 @@ func (s *creator) GetContent(ctx context.Context, userID, id int64) (*creator_dt func (s *creator) ListOrders( ctx context.Context, + tenantID int64, userID int64, filter *creator_dto.CreatorOrderListFilter, ) ([]creator_dto.Order, error) { - tid, err := s.getTenantID(ctx, userID) + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return nil, err } @@ -634,8 +637,8 @@ func (s *creator) ListOrders( return data, nil } -func (s *creator) ProcessRefund(ctx context.Context, userID, id int64, form *creator_dto.RefundForm) error { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) ProcessRefund(ctx context.Context, tenantID, userID, id int64, form *creator_dto.RefundForm) error { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -738,8 +741,8 @@ func (s *creator) ProcessRefund(ctx context.Context, userID, id int64, form *cre return errorx.ErrBadRequest.WithMsg("无效的操作") } -func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.Settings, error) { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) GetSettings(ctx context.Context, tenantID, userID int64) (*creator_dto.Settings, error) { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return nil, err } @@ -758,8 +761,8 @@ func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.S }, nil } -func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creator_dto.Settings) error { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) UpdateSettings(ctx context.Context, tenantID, userID int64, form *creator_dto.Settings) error { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -782,8 +785,8 @@ func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creato return err } -func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creator_dto.PayoutAccount, error) { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) ListPayoutAccounts(ctx context.Context, tenantID, userID int64) ([]creator_dto.PayoutAccount, error) { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return nil, err } @@ -806,8 +809,8 @@ func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creat return data, nil } -func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *creator_dto.PayoutAccount) error { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) AddPayoutAccount(ctx context.Context, tenantID, userID int64, form *creator_dto.PayoutAccount) error { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -827,8 +830,8 @@ func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *crea return nil } -func (s *creator) RemovePayoutAccount(ctx context.Context, userID, id int64) error { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) RemovePayoutAccount(ctx context.Context, tenantID, userID, id int64) error { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -842,8 +845,8 @@ func (s *creator) RemovePayoutAccount(ctx context.Context, userID, id int64) err return nil } -func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto.WithdrawForm) error { - tid, err := s.getTenantID(ctx, userID) +func (s *creator) Withdraw(ctx context.Context, tenantID, userID int64, form *creator_dto.WithdrawForm) error { + tid, err := s.getTenantID(ctx, tenantID, userID) if err != nil { return err } @@ -920,7 +923,7 @@ func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto. // Helpers -func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error) { +func (s *creator) getTenantID(ctx context.Context, tenantID, userID int64) (int64, error) { if userID == 0 { return 0, errorx.ErrUnauthorized } @@ -934,5 +937,8 @@ func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error) } return 0, errorx.ErrDatabaseError.WithCause(err) } + if tenantID > 0 && t.ID != tenantID { + return 0, errorx.ErrPermissionDenied.WithMsg("无权限访问该租户") + } return t.ID, nil } diff --git a/backend/app/services/creator_test.go b/backend/app/services/creator_test.go index b82d938..11f1237 100644 --- a/backend/app/services/creator_test.go +++ b/backend/app/services/creator_test.go @@ -40,6 +40,7 @@ func Test_Creator(t *testing.T) { func (s *CreatorTestSuite) Test_Apply() { Convey("Apply", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameTenantUser, models.TableNameUser) u := &models.User{Username: "creator1", Phone: "13700000001"} @@ -50,7 +51,7 @@ func (s *CreatorTestSuite) Test_Apply() { form := &creator_dto.ApplyForm{ Name: "My Channel", } - err := Creator.Apply(ctx, u.ID, form) + err := Creator.Apply(ctx, tenantID, u.ID, form) So(err, ShouldBeNil) t, _ := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(u.ID)).First() @@ -72,6 +73,7 @@ func (s *CreatorTestSuite) Test_Apply() { func (s *CreatorTestSuite) Test_CreateContent() { Convey("CreateContent", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate( ctx, s.DB, @@ -89,6 +91,7 @@ func (s *CreatorTestSuite) Test_CreateContent() { // Create Tenant manually t := &models.Tenant{UserID: u.ID, Name: "Channel 2", Code: "123", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID Convey("should create content and assets", func() { form := &creator_dto.ContentCreateForm{ @@ -97,7 +100,7 @@ func (s *CreatorTestSuite) Test_CreateContent() { Price: 9.99, // MediaIDs: ... need media asset } - err := Creator.CreateContent(ctx, u.ID, form) + err := Creator.CreateContent(ctx, tenantID, u.ID, form) So(err, ShouldBeNil) c, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.Title.Eq("New Song")).First() @@ -116,6 +119,7 @@ func (s *CreatorTestSuite) Test_CreateContent() { func (s *CreatorTestSuite) Test_UpdateContent() { Convey("UpdateContent", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate( ctx, s.DB, @@ -132,6 +136,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() { t := &models.Tenant{UserID: u.ID, Name: "Channel 3", Code: "124", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID c := &models.Content{TenantID: t.ID, UserID: u.ID, Title: "Old Title", Genre: "audio"} models.ContentQuery.WithContext(ctx).Create(c) @@ -145,7 +150,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() { Genre: "video", Price: &price, } - err := Creator.UpdateContent(ctx, u.ID, c.ID, form) + err := Creator.UpdateContent(ctx, tenantID, u.ID, c.ID, form) So(err, ShouldBeNil) // Verify @@ -162,6 +167,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() { func (s *CreatorTestSuite) Test_Dashboard() { Convey("Dashboard", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate( ctx, s.DB, @@ -178,6 +184,7 @@ func (s *CreatorTestSuite) Test_Dashboard() { t := &models.Tenant{UserID: u.ID, Name: "Channel 4", Code: "125", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID // Mock Data // 1. Followers @@ -198,7 +205,7 @@ func (s *CreatorTestSuite) Test_Dashboard() { ) Convey("should get stats", func() { - stats, err := Creator.Dashboard(ctx, u.ID) + stats, err := Creator.Dashboard(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(stats.TotalFollowers.Value, ShouldEqual, 2) // Implementation sums 'debit_purchase' only based on my code @@ -210,6 +217,7 @@ func (s *CreatorTestSuite) Test_Dashboard() { func (s *CreatorTestSuite) Test_PayoutAccount() { Convey("PayoutAccount", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNamePayoutAccount, models.TableNameUser) u := &models.User{Username: "creator5", Phone: "13700000005"} @@ -218,6 +226,7 @@ func (s *CreatorTestSuite) Test_PayoutAccount() { t := &models.Tenant{UserID: u.ID, Name: "Channel 5", Code: "126", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID Convey("should CRUD payout account", func() { // Add @@ -227,21 +236,21 @@ func (s *CreatorTestSuite) Test_PayoutAccount() { Account: "user@example.com", Realname: "John Doe", } - err := Creator.AddPayoutAccount(ctx, u.ID, form) + err := Creator.AddPayoutAccount(ctx, tenantID, u.ID, form) So(err, ShouldBeNil) // List - list, err := Creator.ListPayoutAccounts(ctx, u.ID) + list, err := Creator.ListPayoutAccounts(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(len(list), ShouldEqual, 1) So(list[0].Account, ShouldEqual, "user@example.com") // Remove - err = Creator.RemovePayoutAccount(ctx, u.ID, list[0].ID) + err = Creator.RemovePayoutAccount(ctx, tenantID, u.ID, list[0].ID) So(err, ShouldBeNil) // Verify Empty - list, err = Creator.ListPayoutAccounts(ctx, u.ID) + list, err = Creator.ListPayoutAccounts(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(len(list), ShouldEqual, 0) }) @@ -251,6 +260,7 @@ func (s *CreatorTestSuite) Test_PayoutAccount() { func (s *CreatorTestSuite) Test_Withdraw() { Convey("Withdraw", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate( ctx, s.DB, @@ -267,6 +277,7 @@ func (s *CreatorTestSuite) Test_Withdraw() { t := &models.Tenant{UserID: u.ID, Name: "Channel 6", Code: "127", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID pa := &models.PayoutAccount{ TenantID: t.ID, @@ -283,7 +294,7 @@ func (s *CreatorTestSuite) Test_Withdraw() { Amount: 20.00, AccountID: pa.ID, } - err := Creator.Withdraw(ctx, u.ID, form) + err := Creator.Withdraw(ctx, tenantID, u.ID, form) So(err, ShouldBeNil) // Verify Balance Deducted @@ -308,7 +319,7 @@ func (s *CreatorTestSuite) Test_Withdraw() { Amount: 100.00, AccountID: pa.ID, } - err := Creator.Withdraw(ctx, u.ID, form) + err := Creator.Withdraw(ctx, tenantID, u.ID, form) So(err, ShouldNotBeNil) }) }) @@ -317,6 +328,7 @@ func (s *CreatorTestSuite) Test_Withdraw() { func (s *CreatorTestSuite) Test_Refund() { Convey("Refund", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameUser, models.TableNameOrder, models.TableNameOrderItem, models.TableNameContentAccess, models.TableNameTenantLedger, @@ -330,6 +342,7 @@ func (s *CreatorTestSuite) Test_Refund() { // Tenant t := &models.Tenant{UserID: creator.ID, Name: "Channel 7", Code: "128", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID // Buyer buyer := &models.User{Username: "buyer7", Phone: "13900000007", Balance: 0} @@ -349,7 +362,7 @@ func (s *CreatorTestSuite) Test_Refund() { Convey("should accept refund", func() { form := &creator_dto.RefundForm{Action: "accept", Reason: "Defective"} - err := Creator.ProcessRefund(ctx, creator.ID, o.ID, form) + err := Creator.ProcessRefund(ctx, tenantID, creator.ID, o.ID, form) So(err, ShouldBeNil) // Verify Order diff --git a/backend/app/services/order.go b/backend/app/services/order.go index 722471c..02099e7 100644 --- a/backend/app/services/order.go +++ b/backend/app/services/order.go @@ -21,14 +21,19 @@ import ( // @provider type order struct{} -func (s *order) ListUserOrders(ctx context.Context, userID int64, status string) ([]user_dto.Order, error) { +func (s *order) ListUserOrders(ctx context.Context, tenantID, userID int64, status string) ([]user_dto.Order, error) { if userID == 0 { return nil, errorx.ErrUnauthorized } uid := userID tbl, q := models.OrderQuery.QueryContext(ctx) - q = q.Where(tbl.UserID.Eq(uid)) + if tenantID > 0 { + q = q.Where(tbl.UserID.Eq(uid), tbl.TenantID.Eq(tenantID)). + Or(tbl.UserID.Eq(uid), tbl.Type.Eq(consts.OrderTypeRecharge)) + } else { + q = q.Where(tbl.UserID.Eq(uid)) + } if status != "" && status != "all" { q = q.Where(tbl.Status.Eq(consts.OrderStatus(status))) @@ -46,14 +51,21 @@ func (s *order) ListUserOrders(ctx context.Context, userID int64, status string) return data, nil } -func (s *order) GetUserOrder(ctx context.Context, userID, id int64) (*user_dto.Order, error) { +func (s *order) GetUserOrder(ctx context.Context, tenantID, userID, id int64) (*user_dto.Order, error) { if userID == 0 { return nil, errorx.ErrUnauthorized } uid := userID tbl, q := models.OrderQuery.QueryContext(ctx) - item, err := q.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid)).First() + itemQuery := q + if tenantID > 0 { + itemQuery = itemQuery.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid), tbl.TenantID.Eq(tenantID)). + Or(tbl.ID.Eq(id), tbl.UserID.Eq(uid), tbl.Type.Eq(consts.OrderTypeRecharge)) + } else { + itemQuery = itemQuery.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid)) + } + item, err := itemQuery.First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errorx.ErrRecordNotFound @@ -70,6 +82,7 @@ func (s *order) GetUserOrder(ctx context.Context, userID, id int64) (*user_dto.O func (s *order) Create( ctx context.Context, + tenantID int64, userID int64, form *transaction_dto.OrderCreateForm, ) (*transaction_dto.OrderCreateResponse, error) { @@ -86,7 +99,11 @@ func (s *order) Create( } if idempotencyKey != "" { tbl, q := models.OrderQuery.QueryContext(ctx) - existing, err := q.Where(tbl.UserID.Eq(uid), tbl.IdempotencyKey.Eq(idempotencyKey)).First() + q = q.Where(tbl.UserID.Eq(uid), tbl.IdempotencyKey.Eq(idempotencyKey)) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } + existing, err := q.First() if err == nil { return &transaction_dto.OrderCreateResponse{OrderID: existing.ID}, nil } @@ -96,7 +113,11 @@ func (s *order) Create( } // 1. Fetch Content & Price - content, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid)).First() + contentQuery := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid)) + if tenantID > 0 { + contentQuery = contentQuery.Where(models.ContentQuery.TenantID.Eq(tenantID)) + } + content, err := contentQuery.First() if err != nil { return nil, errorx.ErrRecordNotFound.WithMsg("内容不存在") } @@ -188,6 +209,7 @@ func (s *order) Create( func (s *order) Pay( ctx context.Context, + tenantID int64, userID int64, id int64, form *transaction_dto.OrderPayForm, @@ -204,6 +226,9 @@ func (s *order) Pay( if err != nil { return nil, errorx.ErrRecordNotFound } + if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID { + return nil, errorx.ErrForbidden.WithMsg("租户不匹配") + } if o.Status != consts.OrderStatusCreated { return nil, errorx.ErrStatusConflict.WithMsg("订单状态不可支付") } @@ -219,11 +244,14 @@ func (s *order) Pay( } // ProcessExternalPayment handles callback from payment gateway -func (s *order) ProcessExternalPayment(ctx context.Context, orderID int64, externalID string) error { +func (s *order) ProcessExternalPayment(ctx context.Context, tenantID, orderID int64, externalID string) error { o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(orderID)).First() if err != nil { return errorx.ErrRecordNotFound } + if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID { + return errorx.ErrForbidden.WithMsg("租户不匹配") + } if o.Status != consts.OrderStatusCreated { return nil // Already processed idempotency } @@ -365,7 +393,7 @@ func (s *order) settleOrder(ctx context.Context, o *models.Order, method, extern return nil } -func (s *order) Status(ctx context.Context, id int64) (*transaction_dto.OrderStatusResponse, error) { +func (s *order) Status(ctx context.Context, tenantID, id int64) (*transaction_dto.OrderStatusResponse, error) { o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(id)).First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -373,6 +401,9 @@ func (s *order) Status(ctx context.Context, id int64) (*transaction_dto.OrderSta } return nil, errorx.ErrDatabaseError.WithCause(err) } + if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID { + return nil, errorx.ErrForbidden.WithMsg("租户不匹配") + } return &transaction_dto.OrderStatusResponse{ Status: string(o.Status), diff --git a/backend/app/services/order_test.go b/backend/app/services/order_test.go index d01adfb..40f66ee 100644 --- a/backend/app/services/order_test.go +++ b/backend/app/services/order_test.go @@ -39,6 +39,7 @@ func Test_Order(t *testing.T) { func (s *OrderTestSuite) Test_PurchaseFlow() { Convey("Purchase Flow", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate(ctx, s.DB, models.TableNameOrder, models.TableNameOrderItem, models.TableNameUser, models.TableNameContent, models.TableNameContentPrice, models.TableNameTenant, @@ -57,6 +58,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() { Status: consts.TenantStatusVerified, } models.TenantQuery.WithContext(ctx).Create(tenant) + tenantID = tenant.ID // Content content := &models.Content{ TenantID: tenant.ID, @@ -83,7 +85,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() { Convey("should create and pay order successfully", func() { // Step 1: Create Order form := &order_dto.OrderCreateForm{ContentID: content.ID} - createRes, err := Order.Create(ctx, buyer.ID, form) + createRes, err := Order.Create(ctx, tenantID, buyer.ID, form) So(err, ShouldBeNil) So(createRes.OrderID, ShouldNotBeEmpty) @@ -95,7 +97,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() { // Step 2: Pay Order payForm := &order_dto.OrderPayForm{Method: "balance"} - _, err = Order.Pay(ctx, buyer.ID, createRes.OrderID, payForm) + _, err = Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, payForm) So(err, ShouldBeNil) // Verify Order Paid @@ -130,11 +132,11 @@ func (s *OrderTestSuite) Test_PurchaseFlow() { Update(models.UserQuery.Balance, 500) form := &order_dto.OrderCreateForm{ContentID: content.ID} - createRes, err := Order.Create(ctx, buyer.ID, form) + createRes, err := Order.Create(ctx, tenantID, buyer.ID, form) So(err, ShouldBeNil) payForm := &order_dto.OrderPayForm{Method: "balance"} - _, err = Order.Pay(ctx, buyer.ID, createRes.OrderID, payForm) + _, err = Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, payForm) So(err, ShouldNotBeNil) // Error should be QuotaExceeded or similar }) @@ -144,6 +146,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() { func (s *OrderTestSuite) Test_OrderDetails() { Convey("Order Details", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate( ctx, s.DB, @@ -164,6 +167,7 @@ func (s *OrderTestSuite) Test_OrderDetails() { models.UserQuery.WithContext(ctx).Create(creator) tenant := &models.Tenant{UserID: creator.ID, Name: "Best Shop", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(tenant) + tenantID = tenant.ID content := &models.Content{ TenantID: tenant.ID, UserID: creator.ID, @@ -199,13 +203,14 @@ func (s *OrderTestSuite) Test_OrderDetails() { // Create & Pay createRes, _ := Order.Create( ctx, + tenantID, buyer.ID, &order_dto.OrderCreateForm{ContentID: content.ID}, ) - Order.Pay(ctx, buyer.ID, createRes.OrderID, &order_dto.OrderPayForm{Method: "balance"}) + Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, &order_dto.OrderPayForm{Method: "balance"}) // Get Detail - detail, err := Order.GetUserOrder(ctx, buyer.ID, createRes.OrderID) + detail, err := Order.GetUserOrder(ctx, tenantID, buyer.ID, createRes.OrderID) So(err, ShouldBeNil) So(detail.TenantName, ShouldEqual, "Best Shop") So(len(detail.Items), ShouldEqual, 1) @@ -219,6 +224,7 @@ func (s *OrderTestSuite) Test_OrderDetails() { func (s *OrderTestSuite) Test_PlatformCommission() { Convey("Platform Commission", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate( ctx, s.DB, @@ -236,6 +242,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() { // Tenant t := &models.Tenant{UserID: creator.ID, Name: "Shop C", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID // Buyer buyer := &models.User{Username: "buyer_c", Balance: 2000} models.UserQuery.WithContext(ctx).Create(buyer) @@ -253,7 +260,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() { Convey("should deduct 10% fee", func() { payForm := &order_dto.OrderPayForm{Method: "balance"} - _, err := Order.Pay(ctx, buyer.ID, o.ID, payForm) + _, err := Order.Pay(ctx, tenantID, buyer.ID, o.ID, payForm) So(err, ShouldBeNil) // Verify Creator Balance (1000 - 10% = 900) @@ -270,6 +277,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() { func (s *OrderTestSuite) Test_ExternalPayment() { Convey("External Payment", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate( ctx, s.DB, @@ -287,6 +295,7 @@ func (s *OrderTestSuite) Test_ExternalPayment() { // Tenant t := &models.Tenant{UserID: creator.ID, Name: "Shop Ext", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID // Buyer (Balance 0) buyer := &models.User{Username: "buyer_ext", Balance: 0} models.UserQuery.WithContext(ctx).Create(buyer) @@ -302,7 +311,7 @@ func (s *OrderTestSuite) Test_ExternalPayment() { models.OrderItemQuery.WithContext(ctx).Create(&models.OrderItem{OrderID: o.ID, ContentID: 999}) Convey("should process external payment callback", func() { - err := Order.ProcessExternalPayment(ctx, o.ID, "ext_tx_id_123") + err := Order.ProcessExternalPayment(ctx, tenantID, o.ID, "ext_tx_id_123") So(err, ShouldBeNil) // Verify Status diff --git a/backend/app/services/super.go b/backend/app/services/super.go index 4bf902c..33c91c9 100644 --- a/backend/app/services/super.go +++ b/backend/app/services/super.go @@ -704,7 +704,7 @@ func (s *super) RefundOrder(ctx context.Context, id int64, form *super_dto.Super return errorx.ErrRecordNotFound.WithMsg("租户不存在") } - return Creator.ProcessRefund(ctx, t.UserID, id, &v1_dto.RefundForm{ + return Creator.ProcessRefund(ctx, t.ID, t.UserID, id, &v1_dto.RefundForm{ Action: "accept", Reason: form.Reason, }) diff --git a/backend/app/services/tenant.go b/backend/app/services/tenant.go index 563910c..5c58e60 100644 --- a/backend/app/services/tenant.go +++ b/backend/app/services/tenant.go @@ -17,9 +17,12 @@ import ( // @provider type tenant struct{} -func (s *tenant) List(ctx context.Context, filter *dto.TenantListFilter) (*requests.Pager, error) { +func (s *tenant) List(ctx context.Context, tenantID int64, filter *dto.TenantListFilter) (*requests.Pager, error) { tbl, q := models.TenantQuery.QueryContext(ctx) q = q.Where(tbl.Status.Eq(consts.TenantStatusVerified)) + if tenantID > 0 { + q = q.Where(tbl.ID.Eq(tenantID)) + } if filter.Keyword != nil && *filter.Keyword != "" { q = q.Where(tbl.Name.Like("%" + *filter.Keyword + "%")) @@ -73,8 +76,8 @@ func (s *tenant) List(ctx context.Context, filter *dto.TenantListFilter) (*reque }, nil } -func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.TenantProfile, error) { - t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(id)).First() +func (s *tenant) GetPublicProfile(ctx context.Context, tenantID, userID int64) (*dto.TenantProfile, error) { + t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tenantID)).First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errorx.ErrRecordNotFound @@ -83,9 +86,9 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T } // Stats - followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(id)).Count() + followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(tenantID)).Count() contents, _ := models.ContentQuery.WithContext(ctx). - Where(models.ContentQuery.TenantID.Eq(id), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)). + Where(models.ContentQuery.TenantID.Eq(tenantID), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)). Count() // Following status @@ -93,7 +96,7 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T if userID > 0 { uid := userID isFollowing, _ = models.TenantUserQuery.WithContext(ctx). - Where(models.TenantUserQuery.TenantID.Eq(id), models.TenantUserQuery.UserID.Eq(uid)). + Where(models.TenantUserQuery.TenantID.Eq(tenantID), models.TenantUserQuery.UserID.Eq(uid)). Exists() } @@ -113,20 +116,20 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T }, nil } -func (s *tenant) Follow(ctx context.Context, userID, id int64) error { +func (s *tenant) Follow(ctx context.Context, tenantID, userID int64) error { if userID == 0 { return errorx.ErrUnauthorized } uid := userID // Check if tenant exists - t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(id)).First() + t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tenantID)).First() if err != nil { return errorx.ErrRecordNotFound } tu := &models.TenantUser{ - TenantID: id, + TenantID: tenantID, UserID: uid, Role: types.Array[consts.TenantUserRole]{consts.TenantUserRoleMember}, Status: consts.UserStatusVerified, @@ -142,14 +145,14 @@ func (s *tenant) Follow(ctx context.Context, userID, id int64) error { return nil } -func (s *tenant) Unfollow(ctx context.Context, userID, id int64) error { +func (s *tenant) Unfollow(ctx context.Context, tenantID, userID int64) error { if userID == 0 { return errorx.ErrUnauthorized } uid := userID _, err := models.TenantUserQuery.WithContext(ctx). - Where(models.TenantUserQuery.TenantID.Eq(id), models.TenantUserQuery.UserID.Eq(uid)). + Where(models.TenantUserQuery.TenantID.Eq(tenantID), models.TenantUserQuery.UserID.Eq(uid)). Delete() if err != nil { return errorx.ErrDatabaseError.WithCause(err) @@ -157,14 +160,18 @@ func (s *tenant) Unfollow(ctx context.Context, userID, id int64) error { return nil } -func (s *tenant) ListFollowed(ctx context.Context, userID int64) ([]dto.TenantProfile, error) { +func (s *tenant) ListFollowed(ctx context.Context, tenantID, userID int64) ([]dto.TenantProfile, error) { if userID == 0 { return nil, errorx.ErrUnauthorized } uid := userID tbl, q := models.TenantUserQuery.QueryContext(ctx) - list, err := q.Where(tbl.UserID.Eq(uid)).Find() + q = q.Where(tbl.UserID.Eq(uid)) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } + list, err := q.Find() if err != nil { return nil, errorx.ErrDatabaseError.WithCause(err) } diff --git a/backend/app/services/tenant_test.go b/backend/app/services/tenant_test.go index 9087712..26423aa 100644 --- a/backend/app/services/tenant_test.go +++ b/backend/app/services/tenant_test.go @@ -39,6 +39,7 @@ func Test_Tenant(t *testing.T) { func (s *TenantTestSuite) Test_Follow() { Convey("Follow Flow", s.T(), func() { ctx := s.T().Context() + tenantID := int64(0) database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameTenantUser, models.TableNameUser) // User @@ -49,29 +50,30 @@ func (s *TenantTestSuite) Test_Follow() { // Tenant t := &models.Tenant{Name: "Tenant A", Status: consts.TenantStatusVerified} models.TenantQuery.WithContext(ctx).Create(t) + tenantID = t.ID Convey("should follow tenant", func() { - err := Tenant.Follow(ctx, u.ID, t.ID) + err := Tenant.Follow(ctx, tenantID, u.ID) So(err, ShouldBeNil) // Verify stats - profile, err := Tenant.GetPublicProfile(ctx, u.ID, t.ID) + profile, err := Tenant.GetPublicProfile(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(profile.IsFollowing, ShouldBeTrue) So(profile.Stats.Followers, ShouldEqual, 1) // List Followed - list, err := Tenant.ListFollowed(ctx, u.ID) + list, err := Tenant.ListFollowed(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(len(list), ShouldEqual, 1) So(list[0].Name, ShouldEqual, "Tenant A") // Unfollow - err = Tenant.Unfollow(ctx, u.ID, t.ID) + err = Tenant.Unfollow(ctx, tenantID, u.ID) So(err, ShouldBeNil) // Verify - profile, err = Tenant.GetPublicProfile(ctx, u.ID, t.ID) + profile, err = Tenant.GetPublicProfile(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(profile.IsFollowing, ShouldBeFalse) So(profile.Stats.Followers, ShouldEqual, 0) diff --git a/backend/app/services/user.go b/backend/app/services/user.go index 1b5146e..a763f97 100644 --- a/backend/app/services/user.go +++ b/backend/app/services/user.go @@ -31,7 +31,7 @@ func (s *user) SendOTP(ctx context.Context, phone string) error { } // LoginWithOTP 手机号验证码登录/注册 -func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.LoginResponse, error) { +func (s *user) LoginWithOTP(ctx context.Context, tenantID int64, phone, otp string) (*auth_dto.LoginResponse, error) { // 1. 校验验证码 (模拟:固定 123456) if otp != "1234" { return nil, errorx.ErrInvalidCredentials.WithMsg("验证码错误") @@ -67,8 +67,8 @@ func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.L // 4. 生成 Token token, err := s.jwt.CreateToken(s.jwt.CreateClaims(jwt.BaseClaims{ - UserID: u.ID, - // TenantID: 0, // 初始登录无租户上下文 + UserID: u.ID, + TenantID: tenantID, })) if err != nil { return nil, errorx.ErrInternalError.WithMsg("生成令牌失败") diff --git a/backend/app/services/user_test.go b/backend/app/services/user_test.go index f45172e..60d1511 100644 --- a/backend/app/services/user_test.go +++ b/backend/app/services/user_test.go @@ -40,11 +40,12 @@ func Test_User(t *testing.T) { func (s *UserTestSuite) Test_LoginWithOTP() { Convey("LoginWithOTP", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameUser) Convey("should create user and login success with correct OTP", func() { phone := "13800138000" - resp, err := User.LoginWithOTP(ctx, phone, "1234") + resp, err := User.LoginWithOTP(ctx, tenantID, phone, "1234") So(err, ShouldBeNil) So(resp, ShouldNotBeNil) So(resp.Token, ShouldNotBeEmpty) @@ -55,17 +56,17 @@ func (s *UserTestSuite) Test_LoginWithOTP() { Convey("should login existing user", func() { phone := "13800138001" // Pre-create user - _, err := User.LoginWithOTP(ctx, phone, "1234") + _, err := User.LoginWithOTP(ctx, tenantID, phone, "1234") So(err, ShouldBeNil) // Login again - resp, err := User.LoginWithOTP(ctx, phone, "1234") + resp, err := User.LoginWithOTP(ctx, tenantID, phone, "1234") So(err, ShouldBeNil) So(resp.User.Phone, ShouldEqual, phone) }) Convey("should fail with incorrect OTP", func() { - resp, err := User.LoginWithOTP(ctx, "13800138002", "000000") + resp, err := User.LoginWithOTP(ctx, tenantID, "13800138002", "000000") So(err, ShouldNotBeNil) So(resp, ShouldBeNil) }) @@ -75,11 +76,12 @@ func (s *UserTestSuite) Test_LoginWithOTP() { func (s *UserTestSuite) Test_Me() { Convey("Me", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameUser) // Create user phone := "13800138003" - resp, _ := User.LoginWithOTP(ctx, phone, "1234") + resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234") userID := resp.User.ID Convey("should return user profile", func() { @@ -104,10 +106,11 @@ func (s *UserTestSuite) Test_Me() { func (s *UserTestSuite) Test_Update() { Convey("Update", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameUser) phone := "13800138004" - resp, _ := User.LoginWithOTP(ctx, phone, "1234") + resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234") userID := resp.User.ID ctx = context.WithValue(ctx, consts.CtxKeyUser, userID) @@ -132,10 +135,11 @@ func (s *UserTestSuite) Test_Update() { func (s *UserTestSuite) Test_RealName() { Convey("RealName", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameUser) phone := "13800138005" - resp, _ := User.LoginWithOTP(ctx, phone, "1234") + resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234") userID := resp.User.ID ctx = context.WithValue(ctx, consts.CtxKeyUser, userID) @@ -157,10 +161,11 @@ func (s *UserTestSuite) Test_RealName() { func (s *UserTestSuite) Test_GetNotifications() { Convey("GetNotifications", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameNotification) phone := "13800138006" - resp, _ := User.LoginWithOTP(ctx, phone, "1234") + resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234") userID := resp.User.ID ctx = context.WithValue(ctx, consts.CtxKeyUser, userID) diff --git a/backend/app/services/wallet.go b/backend/app/services/wallet.go index bdcb3ba..a61da93 100644 --- a/backend/app/services/wallet.go +++ b/backend/app/services/wallet.go @@ -20,7 +20,7 @@ import ( // @provider type wallet struct{} -func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletResponse, error) { +func (s *wallet) GetWallet(ctx context.Context, tenantID, userID int64) (*user_dto.WalletResponse, error) { // Get Balance u, err := models.UserQuery.WithContext(ctx).Where(models.UserQuery.ID.Eq(userID)).First() if err != nil { @@ -33,7 +33,13 @@ func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletR // Get Transactions (Orders) // Both purchase (expense) and recharge (income - if paid) tbl, q := models.OrderQuery.QueryContext(ctx) - orders, err := q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid)). + if tenantID > 0 { + q = q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid), tbl.TenantID.Eq(tenantID)). + Or(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid), tbl.Type.Eq(consts.OrderTypeRecharge)) + } else { + q = q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid)) + } + orders, err := q. Order(tbl.CreatedAt.Desc()). Limit(20). // Limit to recent 20 Find() @@ -71,6 +77,7 @@ func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletR func (s *wallet) Recharge( ctx context.Context, + tenantID int64, userID int64, form *user_dto.RechargeForm, ) (*user_dto.RechargeResponse, error) { @@ -98,7 +105,7 @@ func (s *wallet) Recharge( // MOCK: Automatically pay for recharge order to close the loop // In production, this would be a callback from payment gateway - if err := Order.ProcessExternalPayment(ctx, order.ID, "mock_auto_pay"); err != nil { + if err := Order.ProcessExternalPayment(ctx, tenantID, order.ID, "mock_auto_pay"); err != nil { return nil, err } diff --git a/backend/app/services/wallet_test.go b/backend/app/services/wallet_test.go index 90d450f..b5d3ec5 100644 --- a/backend/app/services/wallet_test.go +++ b/backend/app/services/wallet_test.go @@ -40,6 +40,7 @@ func Test_Wallet(t *testing.T) { func (s *WalletTestSuite) Test_GetWallet() { Convey("GetWallet", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameOrder) u := &models.User{Username: "wallet_user", Balance: 5000} // 50.00 @@ -58,7 +59,7 @@ func (s *WalletTestSuite) Test_GetWallet() { models.OrderQuery.WithContext(ctx).Create(o1, o2) Convey("should return balance and transactions", func() { - res, err := Wallet.GetWallet(ctx, u.ID) + res, err := Wallet.GetWallet(ctx, tenantID, u.ID) So(err, ShouldBeNil) So(res.Balance, ShouldEqual, 50.0) So(len(res.Transactions), ShouldEqual, 2) @@ -74,6 +75,7 @@ func (s *WalletTestSuite) Test_GetWallet() { func (s *WalletTestSuite) Test_Recharge() { Convey("Recharge", s.T(), func() { ctx := s.T().Context() + tenantID := int64(1) database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameOrder) u := &models.User{Username: "recharge_user"} @@ -82,7 +84,7 @@ func (s *WalletTestSuite) Test_Recharge() { Convey("should create recharge order", func() { form := &user_dto.RechargeForm{Amount: 100.0} - res, err := Wallet.Recharge(ctx, u.ID, form) + res, err := Wallet.Recharge(ctx, tenantID, u.ID, form) So(err, ShouldBeNil) So(res.OrderID, ShouldNotBeEmpty) diff --git a/frontend/portal/src/components/TopNavbar.vue b/frontend/portal/src/components/TopNavbar.vue index 3e7d211..518023c 100644 --- a/frontend/portal/src/components/TopNavbar.vue +++ b/frontend/portal/src/components/TopNavbar.vue @@ -2,16 +2,16 @@