fix: validate creator assets by tenant
This commit is contained in:
@@ -261,6 +261,11 @@ func (s *creator) CreateContent(ctx context.Context, tenantID, userID int64, for
|
|||||||
if form.Status != "" {
|
if form.Status != "" {
|
||||||
status = consts.ContentStatus(form.Status)
|
status = consts.ContentStatus(form.Status)
|
||||||
}
|
}
|
||||||
|
// 校验素材归属,避免跨租户引用。
|
||||||
|
if err := s.validateContentAssets(ctx, tx, tid, uid, form.CoverIDs, form.MediaIDs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// 1. Create Content
|
// 1. Create Content
|
||||||
content := &models.Content{
|
content := &models.Content{
|
||||||
TenantID: tid,
|
TenantID: tid,
|
||||||
@@ -415,6 +420,10 @@ func (s *creator) UpdateContent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4. Update Assets (Full replacement strategy)
|
// 4. Update Assets (Full replacement strategy)
|
||||||
|
if err := s.validateContentAssets(ctx, tx, tid, uid, form.CoverIDs, form.MediaIDs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
_, err = tx.ContentAsset.WithContext(ctx).Where(tx.ContentAsset.ContentID.Eq(id)).Delete()
|
_, err = tx.ContentAsset.WithContext(ctx).Where(tx.ContentAsset.ContentID.Eq(id)).Delete()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1005,3 +1014,52 @@ func (s *creator) getTenantID(ctx context.Context, tenantID, userID int64) (int6
|
|||||||
}
|
}
|
||||||
return t.ID, nil
|
return t.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *creator) validateContentAssets(
|
||||||
|
ctx context.Context,
|
||||||
|
tx *models.Query,
|
||||||
|
tenantID int64,
|
||||||
|
userID int64,
|
||||||
|
coverIDs []int64,
|
||||||
|
mediaIDs []int64,
|
||||||
|
) error {
|
||||||
|
if len(coverIDs) == 0 && len(mediaIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make(map[int64]struct{}, len(coverIDs)+len(mediaIDs))
|
||||||
|
for _, id := range coverIDs {
|
||||||
|
ids[id] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, id := range mediaIDs {
|
||||||
|
ids[id] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
assetIDs := make([]int64, 0, len(ids))
|
||||||
|
for id := range ids {
|
||||||
|
assetIDs = append(assetIDs, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := tx.MediaAsset.WithContext(ctx).Where(tx.MediaAsset.ID.In(assetIDs...)).Find()
|
||||||
|
if err != nil {
|
||||||
|
return errorx.ErrDatabaseError.WithCause(err)
|
||||||
|
}
|
||||||
|
if len(list) != len(assetIDs) {
|
||||||
|
return errorx.ErrRecordNotFound.WithMsg("素材不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, asset := range list {
|
||||||
|
if asset.TenantID == tenantID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if asset.TenantID == 0 && asset.UserID == userID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return errorx.ErrForbidden.WithMsg("素材不属于当前租户")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user