fix: scope notifications by tenant

This commit is contained in:
2026-01-09 15:49:49 +08:00
parent c446dec5c6
commit 4d90f547e6
9 changed files with 62 additions and 44 deletions

View File

@@ -58,11 +58,12 @@ type notificationTestWorker struct {
func (w *notificationTestWorker) Work(ctx context.Context, job *river.Job[jobs_args.NotificationArgs]) error { func (w *notificationTestWorker) Work(ctx context.Context, job *river.Job[jobs_args.NotificationArgs]) error {
arg := job.Args arg := job.Args
n := &models.Notification{ n := &models.Notification{
UserID: arg.UserID, TenantID: arg.TenantID,
Type: arg.Type, UserID: arg.UserID,
Title: arg.Title, Type: arg.Type,
Content: arg.Content, Title: arg.Title,
IsRead: false, Content: arg.Content,
IsRead: false,
} }
return models.NotificationQuery.WithContext(ctx).Create(n) return models.NotificationQuery.WithContext(ctx).Create(n)
} }

View File

@@ -269,7 +269,8 @@ func (u *User) Following(ctx fiber.Ctx, user *models.User) ([]dto.TenantProfile,
// @Bind typeArg query key(type) // @Bind typeArg query key(type)
// @Bind page query // @Bind page query
func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, page int) (*requests.Pager, error) { func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, page int) (*requests.Pager, error) {
return services.Notification.List(ctx, user.ID, page, typeArg) tenantID := getTenantID(ctx)
return services.Notification.List(ctx, tenantID, user.ID, page, typeArg)
} }
// Mark notification as read // Mark notification as read
@@ -284,7 +285,8 @@ func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, p
// @Bind user local key(__ctx_user) // @Bind user local key(__ctx_user)
// @Bind id path // @Bind id path
func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64) error { func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64) error {
return services.Notification.MarkRead(ctx, user.ID, id) tenantID := getTenantID(ctx)
return services.Notification.MarkRead(ctx, tenantID, user.ID, id)
} }
// Mark all notifications as read // Mark all notifications as read
@@ -297,7 +299,8 @@ func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64)
// @Success 200 {string} string "OK" // @Success 200 {string} string "OK"
// @Bind user local key(__ctx_user) // @Bind user local key(__ctx_user)
func (u *User) MarkAllNotificationsRead(ctx fiber.Ctx, user *models.User) error { func (u *User) MarkAllNotificationsRead(ctx fiber.Ctx, user *models.User) error {
return services.Notification.MarkAllRead(ctx, user.ID) tenantID := getTenantID(ctx)
return services.Notification.MarkAllRead(ctx, tenantID, user.ID)
} }
// List my coupons // List my coupons

View File

@@ -3,10 +3,11 @@ package args
import "github.com/riverqueue/river" import "github.com/riverqueue/river"
type NotificationArgs struct { type NotificationArgs struct {
UserID int64 `json:"user_id"` TenantID int64 `json:"tenant_id"`
Type string `json:"type"` UserID int64 `json:"user_id"`
Title string `json:"title"` Type string `json:"type"`
Content string `json:"content"` Title string `json:"title"`
Content string `json:"content"`
} }
func (NotificationArgs) Kind() string { func (NotificationArgs) Kind() string {

View File

@@ -17,11 +17,12 @@ type NotificationWorker struct {
func (j *NotificationWorker) Work(ctx context.Context, job *river.Job[args.NotificationArgs]) error { func (j *NotificationWorker) Work(ctx context.Context, job *river.Job[args.NotificationArgs]) error {
arg := job.Args arg := job.Args
n := &models.Notification{ n := &models.Notification{
UserID: arg.UserID, TenantID: arg.TenantID,
Type: arg.Type, UserID: arg.UserID,
Title: arg.Title, Type: arg.Type,
Content: arg.Content, Title: arg.Title,
IsRead: false, Content: arg.Content,
IsRead: false,
} }
return models.NotificationQuery.WithContext(ctx).Create(n) return models.NotificationQuery.WithContext(ctx).Create(n)
} }

View File

@@ -388,7 +388,7 @@ func (s *content) LikeComment(ctx context.Context, tenantID, userID, id int64) e
} }
if Notification != nil { if Notification != nil {
_ = Notification.Send(ctx, cm.UserID, "interaction", "评论点赞", "有人点赞了您的评论") _ = Notification.Send(ctx, tenantID, cm.UserID, "interaction", "评论点赞", "有人点赞了您的评论")
} }
return nil return nil
} }
@@ -660,7 +660,7 @@ func (s *content) addInteract(ctx context.Context, tenantID, userID, contentId i
case "favorite": case "favorite":
actionName = "收藏" actionName = "收藏"
} }
_ = Notification.Send(ctx, c.UserID, "interaction", "新的"+actionName, "有人"+actionName+"了您的作品: "+c.Title) _ = Notification.Send(ctx, tenantID, c.UserID, "interaction", "新的"+actionName, "有人"+actionName+"了您的作品: "+c.Title)
} }
return nil return nil
} }

View File

@@ -18,9 +18,12 @@ type notification struct {
job *job.Job job *job.Job
} }
func (s *notification) List(ctx context.Context, userID int64, page int, typeArg string) (*requests.Pager, error) { func (s *notification) List(ctx context.Context, tenantID, userID int64, page int, typeArg string) (*requests.Pager, error) {
tbl, q := models.NotificationQuery.QueryContext(ctx) tbl, q := models.NotificationQuery.QueryContext(ctx)
q = q.Where(tbl.UserID.Eq(userID)) q = q.Where(tbl.UserID.Eq(userID))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
if typeArg != "" && typeArg != "all" { if typeArg != "" && typeArg != "all" {
q = q.Where(tbl.Type.Eq(typeArg)) q = q.Where(tbl.Type.Eq(typeArg))
@@ -58,41 +61,49 @@ func (s *notification) List(ctx context.Context, userID int64, page int, typeArg
}, nil }, nil
} }
func (s *notification) MarkRead(ctx context.Context, userID, id int64) error { func (s *notification) MarkRead(ctx context.Context, tenantID, userID, id int64) error {
_, err := models.NotificationQuery.WithContext(ctx). tbl, q := models.NotificationQuery.QueryContext(ctx)
Where(models.NotificationQuery.ID.Eq(id), models.NotificationQuery.UserID.Eq(userID)). q = q.Where(tbl.ID.Eq(id), tbl.UserID.Eq(userID))
UpdateSimple(models.NotificationQuery.IsRead.Value(true)) if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
_, err := q.UpdateSimple(tbl.IsRead.Value(true))
if err != nil { if err != nil {
return errorx.ErrDatabaseError.WithCause(err) return errorx.ErrDatabaseError.WithCause(err)
} }
return nil return nil
} }
func (s *notification) MarkAllRead(ctx context.Context, userID int64) error { func (s *notification) MarkAllRead(ctx context.Context, tenantID, userID int64) error {
_, err := models.NotificationQuery.WithContext(ctx). tbl, q := models.NotificationQuery.QueryContext(ctx)
Where(models.NotificationQuery.UserID.Eq(userID), models.NotificationQuery.IsRead.Is(false)). q = q.Where(tbl.UserID.Eq(userID), tbl.IsRead.Is(false))
UpdateSimple(models.NotificationQuery.IsRead.Value(true)) if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
_, err := q.UpdateSimple(tbl.IsRead.Value(true))
if err != nil { if err != nil {
return errorx.ErrDatabaseError.WithCause(err) return errorx.ErrDatabaseError.WithCause(err)
} }
return nil return nil
} }
func (s *notification) Send(ctx context.Context, userID int64, typ, title, content string) error { func (s *notification) Send(ctx context.Context, tenantID, userID int64, typ, title, content string) error {
arg := args.NotificationArgs{ arg := args.NotificationArgs{
UserID: userID, TenantID: tenantID,
Type: typ, UserID: userID,
Title: title, Type: typ,
Content: content, Title: title,
Content: content,
} }
// 测试环境下同步写入,避免异步任务未启动导致结果不确定。 // 测试环境下同步写入,避免异步任务未启动导致结果不确定。
if os.Getenv("JOB_INLINE") == "1" { if os.Getenv("JOB_INLINE") == "1" {
n := &models.Notification{ n := &models.Notification{
UserID: userID, TenantID: tenantID,
Type: typ, UserID: userID,
Title: title, Type: typ,
Content: content, Title: title,
IsRead: false, Content: content,
IsRead: false,
} }
if err := models.NotificationQuery.WithContext(ctx).Create(n); err != nil { if err := models.NotificationQuery.WithContext(ctx).Create(n); err != nil {
return errorx.ErrDatabaseError.WithCause(err) return errorx.ErrDatabaseError.WithCause(err)

View File

@@ -44,16 +44,17 @@ func (s *NotificationTestSuite) Test_CRUD() {
ctx := s.T().Context() ctx := s.T().Context()
database.Truncate(ctx, s.DB, models.TableNameNotification) database.Truncate(ctx, s.DB, models.TableNameNotification)
tenantID := int64(1)
uID := int64(100) uID := int64(100)
ctx = context.WithValue(ctx, consts.CtxKeyUser, uID) ctx = context.WithValue(ctx, consts.CtxKeyUser, uID)
Convey("should send notification", func() { Convey("should send notification", func() {
err := Notification.Send(ctx, uID, "system", "Welcome", "Hello World") err := Notification.Send(ctx, tenantID, uID, "system", "Welcome", "Hello World")
So(err, ShouldBeNil) So(err, ShouldBeNil)
var list *requests.Pager var list *requests.Pager
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
list, err = Notification.List(ctx, uID, 1, "") list, err = Notification.List(ctx, tenantID, uID, 1, "")
So(err, ShouldBeNil) So(err, ShouldBeNil)
if list.Total > 0 { if list.Total > 0 {
break break
@@ -69,7 +70,7 @@ func (s *NotificationTestSuite) Test_CRUD() {
// Mark Read // Mark Read
// Need ID // Need ID
n, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.UserID.Eq(uID)).First() n, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.UserID.Eq(uID)).First()
err = Notification.MarkRead(ctx, uID, n.ID) err = Notification.MarkRead(ctx, tenantID, uID, n.ID)
So(err, ShouldBeNil) So(err, ShouldBeNil)
nReload, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.ID.Eq(n.ID)).First() nReload, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.ID.Eq(n.ID)).First()

View File

@@ -385,9 +385,9 @@ func (s *order) settleOrder(ctx context.Context, o *models.Order, method, extern
} }
if Notification != nil { if Notification != nil {
_ = Notification.Send(ctx, o.UserID, "order", "支付成功", "订单已支付,您可以查看已购内容。") _ = Notification.Send(ctx, o.TenantID, o.UserID, "order", "支付成功", "订单已支付,您可以查看已购内容。")
if tenantOwnerID > 0 { if tenantOwnerID > 0 {
_ = Notification.Send(ctx, tenantOwnerID, "order", "新的订单", "您的店铺有新的订单,收入已入账。") _ = Notification.Send(ctx, o.TenantID, tenantOwnerID, "order", "新的订单", "您的店铺有新的订单,收入已入账。")
} }
} }
return nil return nil

View File

@@ -140,7 +140,7 @@ func (s *tenant) Follow(ctx context.Context, tenantID, userID int64) error {
} }
if Notification != nil { if Notification != nil {
_ = Notification.Send(ctx, t.UserID, "interaction", "新增粉丝", "有人关注了您的店铺: "+t.Name) _ = Notification.Send(ctx, tenantID, t.UserID, "interaction", "新增粉丝", "有人关注了您的店铺: "+t.Name)
} }
return nil return nil
} }