fix: scope notifications by tenant

This commit is contained in:
2026-01-09 15:49:49 +08:00
parent c446dec5c6
commit 4d90f547e6
9 changed files with 62 additions and 44 deletions

View File

@@ -58,11 +58,12 @@ type notificationTestWorker struct {
func (w *notificationTestWorker) Work(ctx context.Context, job *river.Job[jobs_args.NotificationArgs]) error {
arg := job.Args
n := &models.Notification{
UserID: arg.UserID,
Type: arg.Type,
Title: arg.Title,
Content: arg.Content,
IsRead: false,
TenantID: arg.TenantID,
UserID: arg.UserID,
Type: arg.Type,
Title: arg.Title,
Content: arg.Content,
IsRead: false,
}
return models.NotificationQuery.WithContext(ctx).Create(n)
}

View File

@@ -269,7 +269,8 @@ func (u *User) Following(ctx fiber.Ctx, user *models.User) ([]dto.TenantProfile,
// @Bind typeArg query key(type)
// @Bind page query
func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, page int) (*requests.Pager, error) {
return services.Notification.List(ctx, user.ID, page, typeArg)
tenantID := getTenantID(ctx)
return services.Notification.List(ctx, tenantID, user.ID, page, typeArg)
}
// Mark notification as read
@@ -284,7 +285,8 @@ func (u *User) Notifications(ctx fiber.Ctx, user *models.User, typeArg string, p
// @Bind user local key(__ctx_user)
// @Bind id path
func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64) error {
return services.Notification.MarkRead(ctx, user.ID, id)
tenantID := getTenantID(ctx)
return services.Notification.MarkRead(ctx, tenantID, user.ID, id)
}
// Mark all notifications as read
@@ -297,7 +299,8 @@ func (u *User) MarkNotificationRead(ctx fiber.Ctx, user *models.User, id int64)
// @Success 200 {string} string "OK"
// @Bind user local key(__ctx_user)
func (u *User) MarkAllNotificationsRead(ctx fiber.Ctx, user *models.User) error {
return services.Notification.MarkAllRead(ctx, user.ID)
tenantID := getTenantID(ctx)
return services.Notification.MarkAllRead(ctx, tenantID, user.ID)
}
// List my coupons

View File

@@ -3,10 +3,11 @@ package args
import "github.com/riverqueue/river"
type NotificationArgs struct {
UserID int64 `json:"user_id"`
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
TenantID int64 `json:"tenant_id"`
UserID int64 `json:"user_id"`
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
}
func (NotificationArgs) Kind() string {

View File

@@ -17,11 +17,12 @@ type NotificationWorker struct {
func (j *NotificationWorker) Work(ctx context.Context, job *river.Job[args.NotificationArgs]) error {
arg := job.Args
n := &models.Notification{
UserID: arg.UserID,
Type: arg.Type,
Title: arg.Title,
Content: arg.Content,
IsRead: false,
TenantID: arg.TenantID,
UserID: arg.UserID,
Type: arg.Type,
Title: arg.Title,
Content: arg.Content,
IsRead: false,
}
return models.NotificationQuery.WithContext(ctx).Create(n)
}

View File

@@ -388,7 +388,7 @@ func (s *content) LikeComment(ctx context.Context, tenantID, userID, id int64) e
}
if Notification != nil {
_ = Notification.Send(ctx, cm.UserID, "interaction", "评论点赞", "有人点赞了您的评论")
_ = Notification.Send(ctx, tenantID, cm.UserID, "interaction", "评论点赞", "有人点赞了您的评论")
}
return nil
}
@@ -660,7 +660,7 @@ func (s *content) addInteract(ctx context.Context, tenantID, userID, contentId i
case "favorite":
actionName = "收藏"
}
_ = Notification.Send(ctx, c.UserID, "interaction", "新的"+actionName, "有人"+actionName+"了您的作品: "+c.Title)
_ = Notification.Send(ctx, tenantID, c.UserID, "interaction", "新的"+actionName, "有人"+actionName+"了您的作品: "+c.Title)
}
return nil
}

View File

@@ -18,9 +18,12 @@ type notification struct {
job *job.Job
}
func (s *notification) List(ctx context.Context, userID int64, page int, typeArg string) (*requests.Pager, error) {
func (s *notification) List(ctx context.Context, tenantID, userID int64, page int, typeArg string) (*requests.Pager, error) {
tbl, q := models.NotificationQuery.QueryContext(ctx)
q = q.Where(tbl.UserID.Eq(userID))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
if typeArg != "" && typeArg != "all" {
q = q.Where(tbl.Type.Eq(typeArg))
@@ -58,41 +61,49 @@ func (s *notification) List(ctx context.Context, userID int64, page int, typeArg
}, nil
}
func (s *notification) MarkRead(ctx context.Context, userID, id int64) error {
_, err := models.NotificationQuery.WithContext(ctx).
Where(models.NotificationQuery.ID.Eq(id), models.NotificationQuery.UserID.Eq(userID)).
UpdateSimple(models.NotificationQuery.IsRead.Value(true))
func (s *notification) MarkRead(ctx context.Context, tenantID, userID, id int64) error {
tbl, q := models.NotificationQuery.QueryContext(ctx)
q = q.Where(tbl.ID.Eq(id), tbl.UserID.Eq(userID))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
_, err := q.UpdateSimple(tbl.IsRead.Value(true))
if err != nil {
return errorx.ErrDatabaseError.WithCause(err)
}
return nil
}
func (s *notification) MarkAllRead(ctx context.Context, userID int64) error {
_, err := models.NotificationQuery.WithContext(ctx).
Where(models.NotificationQuery.UserID.Eq(userID), models.NotificationQuery.IsRead.Is(false)).
UpdateSimple(models.NotificationQuery.IsRead.Value(true))
func (s *notification) MarkAllRead(ctx context.Context, tenantID, userID int64) error {
tbl, q := models.NotificationQuery.QueryContext(ctx)
q = q.Where(tbl.UserID.Eq(userID), tbl.IsRead.Is(false))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
_, err := q.UpdateSimple(tbl.IsRead.Value(true))
if err != nil {
return errorx.ErrDatabaseError.WithCause(err)
}
return nil
}
func (s *notification) Send(ctx context.Context, userID int64, typ, title, content string) error {
func (s *notification) Send(ctx context.Context, tenantID, userID int64, typ, title, content string) error {
arg := args.NotificationArgs{
UserID: userID,
Type: typ,
Title: title,
Content: content,
TenantID: tenantID,
UserID: userID,
Type: typ,
Title: title,
Content: content,
}
// 测试环境下同步写入,避免异步任务未启动导致结果不确定。
if os.Getenv("JOB_INLINE") == "1" {
n := &models.Notification{
UserID: userID,
Type: typ,
Title: title,
Content: content,
IsRead: false,
TenantID: tenantID,
UserID: userID,
Type: typ,
Title: title,
Content: content,
IsRead: false,
}
if err := models.NotificationQuery.WithContext(ctx).Create(n); err != nil {
return errorx.ErrDatabaseError.WithCause(err)

View File

@@ -44,16 +44,17 @@ func (s *NotificationTestSuite) Test_CRUD() {
ctx := s.T().Context()
database.Truncate(ctx, s.DB, models.TableNameNotification)
tenantID := int64(1)
uID := int64(100)
ctx = context.WithValue(ctx, consts.CtxKeyUser, uID)
Convey("should send notification", func() {
err := Notification.Send(ctx, uID, "system", "Welcome", "Hello World")
err := Notification.Send(ctx, tenantID, uID, "system", "Welcome", "Hello World")
So(err, ShouldBeNil)
var list *requests.Pager
for i := 0; i < 5; i++ {
list, err = Notification.List(ctx, uID, 1, "")
list, err = Notification.List(ctx, tenantID, uID, 1, "")
So(err, ShouldBeNil)
if list.Total > 0 {
break
@@ -69,7 +70,7 @@ func (s *NotificationTestSuite) Test_CRUD() {
// Mark Read
// Need ID
n, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.UserID.Eq(uID)).First()
err = Notification.MarkRead(ctx, uID, n.ID)
err = Notification.MarkRead(ctx, tenantID, uID, n.ID)
So(err, ShouldBeNil)
nReload, _ := models.NotificationQuery.WithContext(ctx).Where(models.NotificationQuery.ID.Eq(n.ID)).First()

View File

@@ -385,9 +385,9 @@ func (s *order) settleOrder(ctx context.Context, o *models.Order, method, extern
}
if Notification != nil {
_ = Notification.Send(ctx, o.UserID, "order", "支付成功", "订单已支付,您可以查看已购内容。")
_ = Notification.Send(ctx, o.TenantID, o.UserID, "order", "支付成功", "订单已支付,您可以查看已购内容。")
if tenantOwnerID > 0 {
_ = Notification.Send(ctx, tenantOwnerID, "order", "新的订单", "您的店铺有新的订单,收入已入账。")
_ = Notification.Send(ctx, o.TenantID, tenantOwnerID, "order", "新的订单", "您的店铺有新的订单,收入已入账。")
}
}
return nil

View File

@@ -140,7 +140,7 @@ func (s *tenant) Follow(ctx context.Context, tenantID, userID int64) error {
}
if Notification != nil {
_ = Notification.Send(ctx, t.UserID, "interaction", "新增粉丝", "有人关注了您的店铺: "+t.Name)
_ = Notification.Send(ctx, tenantID, t.UserID, "interaction", "新增粉丝", "有人关注了您的店铺: "+t.Name)
}
return nil
}