From 4d90f547e6ede1a4656dec3917fae457a31407b3 Mon Sep 17 00:00:00 2001 From: Rogee Date: Fri, 9 Jan 2026 15:49:49 +0800 Subject: [PATCH] fix: scope notifications by tenant --- backend/app/commands/testx/testing.go | 11 ++--- backend/app/http/v1/user.go | 9 +++-- backend/app/jobs/args/notification.go | 9 +++-- backend/app/jobs/notification_job.go | 11 ++--- backend/app/services/content.go | 4 +- backend/app/services/notification.go | 49 ++++++++++++++--------- backend/app/services/notification_test.go | 7 ++-- backend/app/services/order.go | 4 +- backend/app/services/tenant.go | 2 +- 9 files changed, 62 insertions(+), 44 deletions(-) diff --git a/backend/app/commands/testx/testing.go b/backend/app/commands/testx/testing.go index 056b614..eddb521 100644 --- a/backend/app/commands/testx/testing.go +++ b/backend/app/commands/testx/testing.go @@ -58,11 +58,12 @@ type notificationTestWorker struct { func (w *notificationTestWorker) Work(ctx context.Context, job *river.Job[jobs_args.NotificationArgs]) error { arg := job.Args n := &models.Notification{ - UserID: arg.UserID, - Type: arg.Type, - Title: arg.Title, - Content: arg.Content, - IsRead: false, + TenantID: arg.TenantID, + UserID: arg.UserID, + Type: arg.Type, + Title: arg.Title, + Content: arg.Content, + IsRead: false, } return models.NotificationQuery.WithContext(ctx).Create(n) } diff --git a/backend/app/http/v1/user.go b/backend/app/http/v1/user.go index 4e1370e..ecde09c 100644 --- a/backend/app/http/v1/user.go +++ b/backend/app/http/v1/user.go @@ -269,7 +269,8 @@ func (u *User) Following(ctx fiber.Ctx, user *models.User) ([]dto.TenantProfile, // @Bind typeArg query key(type) // @Bind page query func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, page int) (*requests.Pager, error) { - return services.Notification.List(ctx, user.ID, page, typeArg) + tenantID := getTenantID(ctx) + return services.Notification.List(ctx, tenantID, user.ID, page, typeArg) } // Mark notification as read @@ -284,7 +285,8 @@ func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, p // @Bind user local key(__ctx_user) // @Bind id path func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64) error { - return services.Notification.MarkRead(ctx, user.ID, id) + tenantID := getTenantID(ctx) + return services.Notification.MarkRead(ctx, tenantID, user.ID, id) } // Mark all notifications as read @@ -297,7 +299,8 @@ func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64) // @Success 200 {string} string "OK" // @Bind user local key(__ctx_user) func (u *User) MarkAllNotificationsRead(ctx fiber.Ctx, user *models.User) error { - return services.Notification.MarkAllRead(ctx, user.ID) + tenantID := getTenantID(ctx) + return services.Notification.MarkAllRead(ctx, tenantID, user.ID) } // List my coupons diff --git a/backend/app/jobs/args/notification.go b/backend/app/jobs/args/notification.go index efcc377..f7ddf8a 100644 --- a/backend/app/jobs/args/notification.go +++ b/backend/app/jobs/args/notification.go @@ -3,10 +3,11 @@ package args import "github.com/riverqueue/river" type NotificationArgs struct { - UserID int64 `json:"user_id"` - Type string `json:"type"` - Title string `json:"title"` - Content string `json:"content"` + TenantID int64 `json:"tenant_id"` + UserID int64 `json:"user_id"` + Type string `json:"type"` + Title string `json:"title"` + Content string `json:"content"` } func (NotificationArgs) Kind() string { diff --git a/backend/app/jobs/notification_job.go b/backend/app/jobs/notification_job.go index 2b20e0d..36e44c9 100644 --- a/backend/app/jobs/notification_job.go +++ b/backend/app/jobs/notification_job.go @@ -17,11 +17,12 @@ type NotificationWorker struct { func (j *NotificationWorker) Work(ctx context.Context, job *river.Job[args.NotificationArgs]) error { arg := job.Args n := &models.Notification{ - UserID: arg.UserID, - Type: arg.Type, - Title: arg.Title, - Content: arg.Content, - IsRead: false, + TenantID: arg.TenantID, + UserID: arg.UserID, + Type: arg.Type, + Title: arg.Title, + Content: arg.Content, + IsRead: false, } return models.NotificationQuery.WithContext(ctx).Create(n) } diff --git a/backend/app/services/content.go b/backend/app/services/content.go index faa14e3..f64998b 100644 --- a/backend/app/services/content.go +++ b/backend/app/services/content.go @@ -388,7 +388,7 @@ func (s *content) LikeComment(ctx context.Context, tenantID, userID, id int64) e } if Notification != nil { - _ = Notification.Send(ctx, cm.UserID, "interaction", "评论点赞", "有人点赞了您的评论") + _ = Notification.Send(ctx, tenantID, cm.UserID, "interaction", "评论点赞", "有人点赞了您的评论") } return nil } @@ -660,7 +660,7 @@ func (s *content) addInteract(ctx context.Context, tenantID, userID, contentId i case "favorite": actionName = "收藏" } - _ = Notification.Send(ctx, c.UserID, "interaction", "新的"+actionName, "有人"+actionName+"了您的作品: "+c.Title) + _ = Notification.Send(ctx, tenantID, c.UserID, "interaction", "新的"+actionName, "有人"+actionName+"了您的作品: "+c.Title) } return nil } diff --git a/backend/app/services/notification.go b/backend/app/services/notification.go index 01541f9..8a2652e 100644 --- a/backend/app/services/notification.go +++ b/backend/app/services/notification.go @@ -18,9 +18,12 @@ type notification struct { job *job.Job } -func (s *notification) List(ctx context.Context, userID int64, page int, typeArg string) (*requests.Pager, error) { +func (s *notification) List(ctx context.Context, tenantID, userID int64, page int, typeArg string) (*requests.Pager, error) { tbl, q := models.NotificationQuery.QueryContext(ctx) q = q.Where(tbl.UserID.Eq(userID)) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } if typeArg != "" && typeArg != "all" { q = q.Where(tbl.Type.Eq(typeArg)) @@ -58,41 +61,49 @@ func (s *notification) List(ctx context.Context, userID int64, page int, typeArg }, nil } -func (s *notification) MarkRead(ctx context.Context, userID, id int64) error { - _, err := models.NotificationQuery.WithContext(ctx). - Where(models.NotificationQuery.ID.Eq(id), models.NotificationQuery.UserID.Eq(userID)). - UpdateSimple(models.NotificationQuery.IsRead.Value(true)) +func (s *notification) MarkRead(ctx context.Context, tenantID, userID, id int64) error { + tbl, q := models.NotificationQuery.QueryContext(ctx) + q = q.Where(tbl.ID.Eq(id), tbl.UserID.Eq(userID)) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } + _, err := q.UpdateSimple(tbl.IsRead.Value(true)) if err != nil { return errorx.ErrDatabaseError.WithCause(err) } return nil } -func (s *notification) MarkAllRead(ctx context.Context, userID int64) error { - _, err := models.NotificationQuery.WithContext(ctx). - Where(models.NotificationQuery.UserID.Eq(userID), models.NotificationQuery.IsRead.Is(false)). - UpdateSimple(models.NotificationQuery.IsRead.Value(true)) +func (s *notification) MarkAllRead(ctx context.Context, tenantID, userID int64) error { + tbl, q := models.NotificationQuery.QueryContext(ctx) + q = q.Where(tbl.UserID.Eq(userID), tbl.IsRead.Is(false)) + if tenantID > 0 { + q = q.Where(tbl.TenantID.Eq(tenantID)) + } + _, err := q.UpdateSimple(tbl.IsRead.Value(true)) if err != nil { return errorx.ErrDatabaseError.WithCause(err) } return nil } -func (s *notification) Send(ctx context.Context, userID int64, typ, title, content string) error { +func (s *notification) Send(ctx context.Context, tenantID, userID int64, typ, title, content string) error { arg := args.NotificationArgs{ - UserID: userID, - Type: typ, - Title: title, - Content: content, + TenantID: tenantID, + UserID: userID, + Type: typ, + Title: title, + Content: content, } // 测试环境下同步写入,避免异步任务未启动导致结果不确定。 if os.Getenv("JOB_INLINE") == "1" { n := &models.Notification{ - UserID: userID, - Type: typ, - Title: title, - Content: content, - IsRead: false, + TenantID: tenantID, + UserID: userID, + Type: typ, + Title: title, + Content: content, + IsRead: false, } if err := models.NotificationQuery.WithContext(ctx).Create(n); err != nil { return errorx.ErrDatabaseError.WithCause(err) diff --git a/backend/app/services/notification_test.go b/backend/app/services/notification_test.go index 95b7d37..a028794 100644 --- a/backend/app/services/notification_test.go +++ b/backend/app/services/notification_test.go @@ -44,16 +44,17 @@ func (s *NotificationTestSuite) Test_CRUD() { ctx := s.T().Context() database.Truncate(ctx, s.DB, models.TableNameNotification) + tenantID := int64(1) uID := int64(100) ctx = context.WithValue(ctx, consts.CtxKeyUser, uID) Convey("should send notification", func() { - err := Notification.Send(ctx, uID, "system", "Welcome", "Hello World") + err := Notification.Send(ctx, tenantID, uID, "system", "Welcome", "Hello World") So(err, ShouldBeNil) var list *requests.Pager for i := 0; i < 5; i++ { - list, err = Notification.List(ctx, uID, 1, "") + list, err = Notification.List(ctx, tenantID, uID, 1, "") So(err, ShouldBeNil) if list.Total > 0 { break @@ -69,7 +70,7 @@ func (s *NotificationTestSuite) Test_CRUD() { // Mark Read // Need ID n, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.UserID.Eq(uID)).First() - err = Notification.MarkRead(ctx, uID, n.ID) + err = Notification.MarkRead(ctx, tenantID, uID, n.ID) So(err, ShouldBeNil) nReload, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.ID.Eq(n.ID)).First() diff --git a/backend/app/services/order.go b/backend/app/services/order.go index 02099e7..45f3739 100644 --- a/backend/app/services/order.go +++ b/backend/app/services/order.go @@ -385,9 +385,9 @@ func (s *order) settleOrder(ctx context.Context, o *models.Order, method, extern } if Notification != nil { - _ = Notification.Send(ctx, o.UserID, "order", "支付成功", "订单已支付,您可以查看已购内容。") + _ = Notification.Send(ctx, o.TenantID, o.UserID, "order", "支付成功", "订单已支付,您可以查看已购内容。") if tenantOwnerID > 0 { - _ = Notification.Send(ctx, tenantOwnerID, "order", "新的订单", "您的店铺有新的订单,收入已入账。") + _ = Notification.Send(ctx, o.TenantID, tenantOwnerID, "order", "新的订单", "您的店铺有新的订单,收入已入账。") } } return nil diff --git a/backend/app/services/tenant.go b/backend/app/services/tenant.go index 5c58e60..3c72a49 100644 --- a/backend/app/services/tenant.go +++ b/backend/app/services/tenant.go @@ -140,7 +140,7 @@ func (s *tenant) Follow(ctx context.Context, tenantID, userID int64) error { } if Notification != nil { - _ = Notification.Send(ctx, t.UserID, "interaction", "新增粉丝", "有人关注了您的店铺: "+t.Name) + _ = Notification.Send(ctx, tenantID, t.UserID, "interaction", "新增粉丝", "有人关注了您的店铺: "+t.Name) } return nil }