diff --git a/backend/app/http/v1/user.go b/backend/app/http/v1/user.go index ecde09c..f38b9cf 100644 --- a/backend/app/http/v1/user.go +++ b/backend/app/http/v1/user.go @@ -316,5 +316,6 @@ func (u *User) MarkAllNotificationsRead(ctx fiber.Ctx, user *models.User) error // @Bind user local key(__ctx_user) // @Bind status query func (u *User) MyCoupons(ctx fiber.Ctx, user *models.User, status string) ([]dto.UserCouponItem, error) { - return services.Coupon.ListUserCoupons(ctx, user.ID, status) + tenantID := getTenantID(ctx) + return services.Coupon.ListUserCoupons(ctx, tenantID, user.ID, status) } diff --git a/backend/app/services/coupon.go b/backend/app/services/coupon.go index aacc1e2..bb23fb8 100644 --- a/backend/app/services/coupon.go +++ b/backend/app/services/coupon.go @@ -2,11 +2,14 @@ package services import ( "context" + "errors" "time" "quyun/v2/app/errorx" coupon_dto "quyun/v2/app/http/v1/dto" "quyun/v2/database/models" + + "gorm.io/gorm" ) // @provider @@ -14,6 +17,7 @@ type coupon struct{} func (s *coupon) ListUserCoupons( ctx context.Context, + tenantID int64, userID int64, status string, ) ([]coupon_dto.UserCouponItem, error) { @@ -32,11 +36,41 @@ func (s *coupon) ListUserCoupons( if err != nil { return nil, errorx.ErrDatabaseError.WithCause(err) } + if len(list) == 0 { + return []coupon_dto.UserCouponItem{}, nil + } - var res []coupon_dto.UserCouponItem + couponIDSet := make(map[int64]struct{}, len(list)) + couponIDs := make([]int64, 0, len(list)) for _, v := range list { - c, _ := models.CouponQuery.WithContext(ctx).Where(models.CouponQuery.ID.Eq(v.CouponID)).First() + if _, ok := couponIDSet[v.CouponID]; ok { + continue + } + couponIDs = append(couponIDs, v.CouponID) + couponIDSet[v.CouponID] = struct{}{} + } + cTbl, cQ := models.CouponQuery.QueryContext(ctx) + cQ = cQ.Where(cTbl.ID.In(couponIDs...)) + if tenantID > 0 { + cQ = cQ.Where(cTbl.TenantID.Eq(tenantID)) + } + coupons, err := cQ.Find() + if err != nil { + return nil, errorx.ErrDatabaseError.WithCause(err) + } + + couponMap := make(map[int64]*models.Coupon, len(coupons)) + for _, c := range coupons { + couponMap[c.ID] = c + } + + res := make([]coupon_dto.UserCouponItem, 0, len(list)) + for _, v := range list { + c, ok := couponMap[v.CouponID] + if !ok { + continue + } item := coupon_dto.UserCouponItem{ ID: v.ID, CouponID: v.CouponID, @@ -61,7 +95,7 @@ func (s *coupon) ListUserCoupons( } // Validate checks if a coupon can be used for an order and returns the discount amount -func (s *coupon) Validate(ctx context.Context, userID, userCouponID, amount int64) (int64, error) { +func (s *coupon) Validate(ctx context.Context, tenantID, userID, userCouponID, amount int64) (int64, error) { uc, err := models.UserCouponQuery.WithContext(ctx).Where(models.UserCouponQuery.ID.Eq(userCouponID)).First() if err != nil { return 0, errorx.ErrRecordNotFound.WithMsg("优惠券不存在") @@ -77,6 +111,9 @@ func (s *coupon) Validate(ctx context.Context, userID, userCouponID, amount int6 if err != nil { return 0, errorx.ErrRecordNotFound.WithMsg("优惠券信息缺失") } + if tenantID > 0 && c.TenantID != tenantID { + return 0, errorx.ErrForbidden.WithMsg("优惠券租户不匹配") + } now := time.Now() if !c.StartAt.IsZero() && now.Before(c.StartAt) { @@ -109,7 +146,29 @@ func (s *coupon) Validate(ctx context.Context, userID, userCouponID, amount int6 } // MarkUsed marks a user coupon as used (intended to be called inside a transaction) -func (s *coupon) MarkUsed(ctx context.Context, tx *models.Query, userCouponID, orderID int64) error { +func (s *coupon) MarkUsed(ctx context.Context, tx *models.Query, tenantID, userCouponID, orderID int64) error { + uc, err := tx.UserCoupon.WithContext(ctx).Where(tx.UserCoupon.ID.Eq(userCouponID)).First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errorx.ErrRecordNotFound.WithMsg("优惠券不存在") + } + return errorx.ErrDatabaseError.WithCause(err) + } + if uc.Status != "unused" { + return errorx.ErrBusinessLogic.WithMsg("优惠券核销失败") + } + + c, err := tx.Coupon.WithContext(ctx).Where(tx.Coupon.ID.Eq(uc.CouponID)).First() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errorx.ErrRecordNotFound.WithMsg("优惠券信息缺失") + } + return errorx.ErrDatabaseError.WithCause(err) + } + if tenantID > 0 && c.TenantID != tenantID { + return errorx.ErrForbidden.WithMsg("优惠券租户不匹配") + } + now := time.Now() // Update User Coupon info, err := tx.UserCoupon.WithContext(ctx). @@ -127,8 +186,6 @@ func (s *coupon) MarkUsed(ctx context.Context, tx *models.Query, userCouponID, o } // Update Coupon used quantity (Optional, but good for stats) - // We need CouponID from uc - uc, _ := tx.UserCoupon.WithContext(ctx).Where(tx.UserCoupon.ID.Eq(userCouponID)).First() _, _ = tx.Coupon.WithContext(ctx).Where(tx.Coupon.ID.Eq(uc.CouponID)).UpdateSimple(tx.Coupon.UsedQuantity.Add(1)) return nil diff --git a/backend/app/services/coupon_test.go b/backend/app/services/coupon_test.go index 84121b1..8f5f71b 100644 --- a/backend/app/services/coupon_test.go +++ b/backend/app/services/coupon_test.go @@ -56,6 +56,7 @@ func (s *CouponTestSuite) Test_CouponFlow() { // 1. Create Coupon (Fixed 5.00 CNY, Min 10.00 CNY) cp := &models.Coupon{ + TenantID: tenantID, Title: "Save 5", Type: "fix_amount", Value: 500, @@ -72,13 +73,13 @@ func (s *CouponTestSuite) Test_CouponFlow() { models.UserCouponQuery.WithContext(ctx).Create(uc) Convey("should validate coupon successfully", func() { - discount, err := Coupon.Validate(ctx, user.ID, uc.ID, 1500) + discount, err := Coupon.Validate(ctx, tenantID, user.ID, uc.ID, 1500) So(err, ShouldBeNil) So(discount, ShouldEqual, 500) }) Convey("should fail if below min amount", func() { - _, err := Coupon.Validate(ctx, user.ID, uc.ID, 800) + _, err := Coupon.Validate(ctx, tenantID, user.ID, uc.ID, 800) So(err, ShouldNotBeNil) }) diff --git a/backend/app/services/order.go b/backend/app/services/order.go index 45f3739..1322bb1 100644 --- a/backend/app/services/order.go +++ b/backend/app/services/order.go @@ -136,7 +136,7 @@ func (s *order) Create( // Validate Coupon if form.UserCouponID > 0 { - discount, err := Coupon.Validate(ctx, uid, form.UserCouponID, amountOriginal) + discount, err := Coupon.Validate(ctx, tenantID, uid, form.UserCouponID, amountOriginal) if err != nil { return nil, err } @@ -188,7 +188,7 @@ func (s *order) Create( // Mark Coupon Used if form.UserCouponID > 0 { - if err := Coupon.MarkUsed(ctx, tx, form.UserCouponID, order.ID); err != nil { + if err := Coupon.MarkUsed(ctx, tx, tenantID, form.UserCouponID, order.ID); err != nil { return err } } diff --git a/backend/app/services/user.go b/backend/app/services/user.go index a763f97..74a4934 100644 --- a/backend/app/services/user.go +++ b/backend/app/services/user.go @@ -160,9 +160,12 @@ func (s *user) RealName(ctx context.Context, userID int64, form *user_dto.RealNa } // GetNotifications 获取通知 -func (s *user) GetNotifications(ctx context.Context, userID int64, typeArg string) ([]user_dto.Notification, error) { +func (s *user) GetNotifications(ctx context.Context, tenantID, userID int64, typeArg string) ([]user_dto.Notification, error) { tbl, query := models.NotificationQuery.QueryContext(ctx) query = query.Where(tbl.UserID.Eq(userID)) + if tenantID > 0 { + query = query.Where(tbl.TenantID.Eq(tenantID)) + } if typeArg != "" && typeArg != "all" { query = query.Where(tbl.Type.Eq(typeArg)) } diff --git a/backend/app/services/user_test.go b/backend/app/services/user_test.go index 60d1511..ec4e3be 100644 --- a/backend/app/services/user_test.go +++ b/backend/app/services/user_test.go @@ -171,15 +171,16 @@ func (s *UserTestSuite) Test_GetNotifications() { // Mock notifications _ = models.Q.Notification.WithContext(ctx).Create(&models.Notification{ - UserID: userID, - Type: "system", - Title: "Welcome", - Content: "Hello World", - IsRead: false, + TenantID: tenantID, + UserID: userID, + Type: "system", + Title: "Welcome", + Content: "Hello World", + IsRead: false, }) Convey("should return notifications", func() { - list, err := User.GetNotifications(ctx, userID, "all") + list, err := User.GetNotifications(ctx, tenantID, userID, "all") So(err, ShouldBeNil) So(len(list), ShouldEqual, 1) So(list[0].Title, ShouldEqual, "Welcome")