feat: tenant-scoped routing and portal navigation

This commit is contained in:
2026-01-08 21:30:46 +08:00
parent f3aa92078a
commit 3e095c57f3
52 changed files with 1111 additions and 670 deletions

View File

@@ -27,6 +27,7 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgconn"
"go.ipao.vip/gen/types"
"gorm.io/gorm"
)
// @provider
@@ -49,7 +50,7 @@ func (s *common) Options(ctx context.Context) (*common_dto.OptionsResponse, erro
}, nil
}
func (s *common) CheckHash(ctx context.Context, userID int64, hash string) (*common_dto.UploadResult, error) {
func (s *common) CheckHash(ctx context.Context, tenantID, userID int64, hash string) (*common_dto.UploadResult, error) {
existing, err := models.MediaAssetQuery.WithContext(ctx).Where(models.MediaAssetQuery.Hash.Eq(hash)).First()
if err != nil {
return nil, nil // Not found, proceed to upload
@@ -58,18 +59,25 @@ func (s *common) CheckHash(ctx context.Context, userID int64, hash string) (*com
// Found existing file (Global deduplication hit)
// Check if user already has it (Logic deduplication hit)
myExisting, err := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)).
First()
myQuery := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID))
if tenantID > 0 {
myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
}
myExisting, err := myQuery.First()
if err == nil {
return s.composeUploadResult(myExisting), nil
}
// Create new record for this user reusing existing ObjectKey
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First()
var tid int64 = 0
if err == nil {
tid = t.ID
// 优先使用路径租户,避免跨租户写入。
tenant, err := s.resolveTenant(ctx, tenantID, userID)
if err != nil {
return nil, err
}
var tid int64
if tenant != nil {
tid = tenant.ID
}
asset := &models.MediaAsset{
@@ -107,13 +115,47 @@ func (s *common) buildObjectKey(tenant *models.Tenant, hash, filename string) st
return path.Join("quyun", tenantUUID, hash+ext)
}
func (s *common) InitUpload(ctx context.Context, userID int64, form *common_dto.UploadInitForm) (*common_dto.UploadInitResponse, error) {
func (s *common) resolveTenant(ctx context.Context, tenantID, userID int64) (*models.Tenant, error) {
if tenantID > 0 {
tbl, q := models.TenantQuery.QueryContext(ctx)
tenant, err := q.Where(tbl.ID.Eq(tenantID)).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.ErrRecordNotFound.WithMsg("租户不存在")
}
return nil, errorx.ErrDatabaseError.WithCause(err)
}
return tenant, nil
}
if userID == 0 {
return nil, nil
}
tbl, q := models.TenantQuery.QueryContext(ctx)
tenant, err := q.Where(tbl.UserID.Eq(userID)).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, errorx.ErrDatabaseError.WithCause(err)
}
return tenant, nil
}
func (s *common) uploadTempDir(localPath string, tenantID int64, uploadID string) string {
tenantKey := "public"
if tenantID > 0 {
tenantKey = strconv.FormatInt(tenantID, 10)
}
return filepath.Join(localPath, "temp", tenantKey, uploadID)
}
func (s *common) InitUpload(ctx context.Context, tenantID, userID int64, form *common_dto.UploadInitForm) (*common_dto.UploadInitResponse, error) {
uploadID := uuid.NewString()
localPath := s.storage.Config.LocalPath
if localPath == "" {
localPath = "./storage"
}
tempDir := filepath.Join(localPath, "temp", uploadID)
tempDir := s.uploadTempDir(localPath, tenantID, uploadID)
if err := os.MkdirAll(tempDir, 0o755); err != nil {
return nil, errorx.ErrInternalError.WithCause(err)
}
@@ -134,12 +176,12 @@ func (s *common) InitUpload(ctx context.Context, userID int64, form *common_dto.
}, nil
}
func (s *common) UploadPart(ctx context.Context, userID int64, file *multipart.FileHeader, form *common_dto.UploadPartForm) error {
func (s *common) UploadPart(ctx context.Context, tenantID, userID int64, file *multipart.FileHeader, form *common_dto.UploadPartForm) error {
localPath := s.storage.Config.LocalPath
if localPath == "" {
localPath = "./storage"
}
partPath := filepath.Join(localPath, "temp", form.UploadID, strconv.Itoa(form.PartNumber))
partPath := filepath.Join(s.uploadTempDir(localPath, tenantID, form.UploadID), strconv.Itoa(form.PartNumber))
src, err := file.Open()
if err != nil {
@@ -159,12 +201,12 @@ func (s *common) UploadPart(ctx context.Context, userID int64, file *multipart.F
return nil
}
func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_dto.UploadCompleteForm) (*common_dto.UploadResult, error) {
func (s *common) CompleteUpload(ctx context.Context, tenantID, userID int64, form *common_dto.UploadCompleteForm) (*common_dto.UploadResult, error) {
localPath := s.storage.Config.LocalPath
if localPath == "" {
localPath = "./storage"
}
tempDir := filepath.Join(localPath, "temp", form.UploadID)
tempDir := s.uploadTempDir(localPath, tenantID, form.UploadID)
// Read Meta
var meta UploadMeta
@@ -220,21 +262,27 @@ func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_
dst.Close() // Ensure flush before potential removal
// Deduplication Logic (Similar to Upload)
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First()
var tid int64 = 0
if err == nil {
tid = t.ID
tenant, err := s.resolveTenant(ctx, tenantID, userID)
if err != nil {
return nil, err
}
var tid int64
if tenant != nil {
tid = tenant.ID
}
objectKey := s.buildObjectKey(t, hash, meta.Filename)
objectKey := s.buildObjectKey(tenant, hash, meta.Filename)
existing, err := models.MediaAssetQuery.WithContext(ctx).Where(models.MediaAssetQuery.Hash.Eq(hash)).First()
var asset *models.MediaAsset
if err == nil {
os.Remove(mergedPath) // Delete duplicate
myExisting, err := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)).
First()
myQuery := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID))
if tenantID > 0 {
myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
}
myExisting, err := myQuery.First()
if err == nil {
os.RemoveAll(tempDir)
return s.composeUploadResult(myExisting), nil
@@ -282,10 +330,13 @@ func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_
return s.composeUploadResult(asset), nil
}
func (s *common) DeleteMediaAsset(ctx context.Context, userID, id int64) error {
asset, err := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.ID.Eq(id), models.MediaAssetQuery.UserID.Eq(userID)).
First()
func (s *common) DeleteMediaAsset(ctx context.Context, tenantID, userID, id int64) error {
query := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.ID.Eq(id), models.MediaAssetQuery.UserID.Eq(userID))
if tenantID > 0 {
query = query.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
}
asset, err := query.First()
if err != nil {
return errorx.ErrRecordNotFound
}
@@ -308,17 +359,18 @@ func (s *common) DeleteMediaAsset(ctx context.Context, userID, id int64) error {
return nil
}
func (s *common) AbortUpload(ctx context.Context, userID int64, uploadId string) error {
func (s *common) AbortUpload(ctx context.Context, tenantID, userID int64, uploadId string) error {
localPath := s.storage.Config.LocalPath
if localPath == "" {
localPath = "./storage"
}
tempDir := filepath.Join(localPath, "temp", uploadId)
tempDir := s.uploadTempDir(localPath, tenantID, uploadId)
return os.RemoveAll(tempDir)
}
func (s *common) Upload(
ctx context.Context,
tenantID int64,
userID int64,
file *multipart.FileHeader,
typeArg string,
@@ -357,13 +409,16 @@ func (s *common) Upload(
hash := hex.EncodeToString(hasher.Sum(nil))
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First()
var tid int64 = 0
if err == nil {
tid = t.ID
tenant, err := s.resolveTenant(ctx, tenantID, userID)
if err != nil {
return nil, err
}
var tid int64
if tenant != nil {
tid = tenant.ID
}
objectKey := s.buildObjectKey(t, hash, file.Filename)
objectKey := s.buildObjectKey(tenant, hash, file.Filename)
var asset *models.MediaAsset
// Deduplication Check
@@ -374,9 +429,12 @@ func (s *common) Upload(
os.RemoveAll(tmpDir)
// Check if user already has it (Logic Deduplication)
myExisting, err := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)).
First()
myQuery := models.MediaAssetQuery.WithContext(ctx).
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID))
if tenantID > 0 {
myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
}
myExisting, err := myQuery.First()
if err == nil {
return s.composeUploadResult(myExisting), nil
}

View File

@@ -18,11 +18,14 @@ import (
// @provider
type content struct{}
func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilter) (*requests.Pager, error) {
func (s *content) List(ctx context.Context, tenantID int64, filter *content_dto.ContentListFilter) (*requests.Pager, error) {
tbl, q := models.ContentQuery.QueryContext(ctx)
// Filters
q = q.Where(tbl.Status.Eq(consts.ContentStatusPublished))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
if filter.Keyword != nil && *filter.Keyword != "" {
keyword := "%" + *filter.Keyword + "%"
q = q.Where(tbl.Title.Like(keyword)).Or(tbl.Description.Like(keyword))
@@ -31,6 +34,9 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte
q = q.Where(tbl.Genre.Eq(*filter.Genre))
}
if filter.TenantID != nil && *filter.TenantID > 0 {
if tenantID > 0 && *filter.TenantID != tenantID {
return nil, errorx.ErrForbidden.WithMsg("租户不匹配")
}
q = q.Where(tbl.TenantID.Eq(*filter.TenantID))
}
if filter.IsPinned != nil {
@@ -128,16 +134,22 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte
}, nil
}
func (s *content) Get(ctx context.Context, userID, id int64) (*content_dto.ContentDetail, error) {
func (s *content) Get(ctx context.Context, tenantID, userID, id int64) (*content_dto.ContentDetail, error) {
// Increment Views
_, _ = models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.ID.Eq(id)).
UpdateSimple(models.ContentQuery.Views.Add(1))
update := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id))
if tenantID > 0 {
update = update.Where(models.ContentQuery.TenantID.Eq(tenantID))
}
_, _ = update.UpdateSimple(models.ContentQuery.Views.Add(1))
_, q := models.ContentQuery.QueryContext(ctx)
var item models.Content
err := q.UnderlyingDB().
db := q.UnderlyingDB()
if tenantID > 0 {
db = db.Where("tenant_id = ?", tenantID)
}
err := db.
Preload("Author").
Preload("ContentAssets", func(db *gorm.DB) *gorm.DB {
return db.Order("sort ASC")
@@ -232,10 +244,25 @@ func (s *content) Get(ctx context.Context, userID, id int64) (*content_dto.Conte
return detail, nil
}
func (s *content) ListComments(ctx context.Context, userID, id int64, page int) (*requests.Pager, error) {
func (s *content) ListComments(ctx context.Context, tenantID, userID, id int64, page int) (*requests.Pager, error) {
if tenantID > 0 {
_, err := models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.ID.Eq(id), models.ContentQuery.TenantID.Eq(tenantID)).
First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.ErrRecordNotFound
}
return nil, errorx.ErrDatabaseError.WithCause(err)
}
}
tbl, q := models.CommentQuery.QueryContext(ctx)
q = q.Where(tbl.ContentID.Eq(id)).Preload(tbl.User)
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
q = q.Order(tbl.CreatedAt.Desc())
p := requests.Pagination{Page: int64(page), Limit: 10}
@@ -291,6 +318,7 @@ func (s *content) ListComments(ctx context.Context, userID, id int64, page int)
func (s *content) CreateComment(
ctx context.Context,
tenantID int64,
userID int64,
id int64,
form *content_dto.CommentCreateForm,
@@ -300,7 +328,11 @@ func (s *content) CreateComment(
}
uid := userID
c, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id)).First()
query := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id))
if tenantID > 0 {
query = query.Where(models.ContentQuery.TenantID.Eq(tenantID))
}
c, err := query.First()
if err != nil {
return errorx.ErrRecordNotFound
}
@@ -319,14 +351,18 @@ func (s *content) CreateComment(
return nil
}
func (s *content) LikeComment(ctx context.Context, userID, id int64) error {
func (s *content) LikeComment(ctx context.Context, tenantID, userID, id int64) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := userID
// Fetch comment for author
cm, err := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ID.Eq(id)).First()
query := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ID.Eq(id))
if tenantID > 0 {
query = query.Where(models.CommentQuery.TenantID.Eq(tenantID))
}
cm, err := query.First()
if err != nil {
return errorx.ErrRecordNotFound
}
@@ -357,14 +393,18 @@ func (s *content) LikeComment(ctx context.Context, userID, id int64) error {
return nil
}
func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
func (s *content) GetLibrary(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := userID
tbl, q := models.ContentAccessQuery.QueryContext(ctx)
accessList, err := q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive)).Find()
q = q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
accessList, err := q.Find()
if err != nil {
return nil, errorx.ErrDatabaseError.WithCause(err)
}
@@ -380,7 +420,11 @@ func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.Cont
ctbl, cq := models.ContentQuery.QueryContext(ctx)
var list []*models.Content
err = cq.Where(ctbl.ID.In(contentIDs...)).
cq = cq.Where(ctbl.ID.In(contentIDs...))
if tenantID > 0 {
cq = cq.Where(ctbl.TenantID.Eq(tenantID))
}
err = cq.
UnderlyingDB().
Preload("Author").
Preload("ContentAssets.Asset").
@@ -398,36 +442,40 @@ func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.Cont
return data, nil
}
func (s *content) GetFavorites(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, userID, "favorite")
func (s *content) GetFavorites(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, tenantID, userID, "favorite")
}
func (s *content) AddFavorite(ctx context.Context, userID, contentId int64) error {
return s.addInteract(ctx, userID, contentId, "favorite")
func (s *content) AddFavorite(ctx context.Context, tenantID, userID, contentId int64) error {
return s.addInteract(ctx, tenantID, userID, contentId, "favorite")
}
func (s *content) RemoveFavorite(ctx context.Context, userID, contentId int64) error {
return s.removeInteract(ctx, userID, contentId, "favorite")
func (s *content) RemoveFavorite(ctx context.Context, tenantID, userID, contentId int64) error {
return s.removeInteract(ctx, tenantID, userID, contentId, "favorite")
}
func (s *content) GetLikes(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, userID, "like")
func (s *content) GetLikes(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) {
return s.getInteractList(ctx, tenantID, userID, "like")
}
func (s *content) AddLike(ctx context.Context, userID, contentId int64) error {
return s.addInteract(ctx, userID, contentId, "like")
func (s *content) AddLike(ctx context.Context, tenantID, userID, contentId int64) error {
return s.addInteract(ctx, tenantID, userID, contentId, "like")
}
func (s *content) RemoveLike(ctx context.Context, userID, contentId int64) error {
return s.removeInteract(ctx, userID, contentId, "like")
func (s *content) RemoveLike(ctx context.Context, tenantID, userID, contentId int64) error {
return s.removeInteract(ctx, tenantID, userID, contentId, "like")
}
func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) {
func (s *content) ListTopics(ctx context.Context, tenantID int64) ([]content_dto.Topic, error) {
var results []struct {
Genre string
Count int
}
err := models.ContentQuery.WithContext(ctx).UnderlyingDB().
db := models.ContentQuery.WithContext(ctx).UnderlyingDB()
if tenantID > 0 {
db = db.Where("tenant_id = ?", tenantID)
}
err := db.
Model(&models.Content{}).
Where("status = ?", consts.ContentStatusPublished).
Select("genre, count(*) as count").
@@ -445,8 +493,12 @@ func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) {
// Fetch latest content in this genre to get a cover
var c models.Content
models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.Genre.Eq(r.Genre), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
query := models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.Genre.Eq(r.Genre), models.ContentQuery.Status.Eq(consts.ContentStatusPublished))
if tenantID > 0 {
query = query.Where(models.ContentQuery.TenantID.Eq(tenantID))
}
query.
Order(models.ContentQuery.PublishedAt.Desc()).
UnderlyingDB().
Preload("ContentAssets").
@@ -554,15 +606,19 @@ func (s *content) toMediaURLs(assets []*models.ContentAsset) []content_dto.Media
return urls
}
func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ string) error {
func (s *content) addInteract(ctx context.Context, tenantID, userID, contentId int64, typ string) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := userID
// Fetch content for author
c, err := models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.ID.Eq(contentId)).
query := models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.ID.Eq(contentId))
if tenantID > 0 {
query = query.Where(models.ContentQuery.TenantID.Eq(tenantID))
}
c, err := query.
Select(models.ContentQuery.UserID, models.ContentQuery.Title).
First()
if err != nil {
@@ -583,7 +639,11 @@ func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ
}
if typ == "like" {
_, err := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)).UpdateSimple(tx.Content.Likes.Add(1))
contentQuery := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId))
if tenantID > 0 {
contentQuery = contentQuery.Where(tx.Content.TenantID.Eq(tenantID))
}
_, err := contentQuery.UpdateSimple(tx.Content.Likes.Add(1))
return err
}
return nil
@@ -605,7 +665,7 @@ func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ
return nil
}
func (s *content) removeInteract(ctx context.Context, userID, contentId int64, typ string) error {
func (s *content) removeInteract(ctx context.Context, tenantID, userID, contentId int64, typ string) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
@@ -623,14 +683,18 @@ func (s *content) removeInteract(ctx context.Context, userID, contentId int64, t
}
if typ == "like" {
_, err := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)).UpdateSimple(tx.Content.Likes.Sub(1))
contentQuery := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId))
if tenantID > 0 {
contentQuery = contentQuery.Where(tx.Content.TenantID.Eq(tenantID))
}
_, err := contentQuery.UpdateSimple(tx.Content.Likes.Sub(1))
return err
}
return nil
})
}
func (s *content) getInteractList(ctx context.Context, userID int64, typ string) ([]user_dto.ContentItem, error) {
func (s *content) getInteractList(ctx context.Context, tenantID, userID int64, typ string) ([]user_dto.ContentItem, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
@@ -653,7 +717,11 @@ func (s *content) getInteractList(ctx context.Context, userID int64, typ string)
ctbl, cq := models.ContentQuery.QueryContext(ctx)
var list []*models.Content
err = cq.Where(ctbl.ID.In(contentIDs...)).
cq = cq.Where(ctbl.ID.In(contentIDs...))
if tenantID > 0 {
cq = cq.Where(ctbl.TenantID.Eq(tenantID))
}
err = cq.
UnderlyingDB().
Preload("Author").
Preload("ContentAssets.Asset").

View File

@@ -41,6 +41,7 @@ func Test_Content(t *testing.T) {
func (s *ContentTestSuite) Test_List() {
Convey("List", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser)
// Create Author
@@ -73,7 +74,7 @@ func (s *ContentTestSuite) Test_List() {
Limit: 10,
},
}
res, err := Content.List(ctx, filter)
res, err := Content.List(ctx, tenantID, filter)
So(err, ShouldBeNil)
So(res.Total, ShouldEqual, 1)
items := res.Items.([]content_dto.ContentItem)
@@ -86,6 +87,7 @@ func (s *ContentTestSuite) Test_List() {
func (s *ContentTestSuite) Test_Get() {
Convey("Get", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameMediaAsset, models.TableNameContentAsset, models.TableNameUser)
// Author
@@ -125,7 +127,7 @@ func (s *ContentTestSuite) Test_Get() {
ctx = context.WithValue(ctx, consts.CtxKeyUser, author.ID)
Convey("should get detail with assets", func() {
detail, err := Content.Get(ctx, author.ID, content.ID)
detail, err := Content.Get(ctx, tenantID, author.ID, content.ID)
So(err, ShouldBeNil)
So(detail.Title, ShouldEqual, "Detail Content")
So(detail.AuthorName, ShouldEqual, "Author1")
@@ -138,6 +140,7 @@ func (s *ContentTestSuite) Test_Get() {
func (s *ContentTestSuite) Test_CreateComment() {
Convey("CreateComment", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameComment, models.TableNameUser)
// User & Content
@@ -153,7 +156,7 @@ func (s *ContentTestSuite) Test_CreateComment() {
form := &content_dto.CommentCreateForm{
Content: "Nice!",
}
err := Content.CreateComment(ctx, u.ID, c.ID, form)
err := Content.CreateComment(ctx, tenantID, u.ID, c.ID, form)
So(err, ShouldBeNil)
count, _ := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ContentID.Eq(c.ID)).Count()
@@ -165,6 +168,7 @@ func (s *ContentTestSuite) Test_CreateComment() {
func (s *ContentTestSuite) Test_Library() {
Convey("Library", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameContentAccess, models.TableNameUser, models.TableNameContentAsset, models.TableNameMediaAsset)
// User
@@ -192,7 +196,7 @@ func (s *ContentTestSuite) Test_Library() {
})
Convey("should get library content with details", func() {
list, err := Content.GetLibrary(ctx, u.ID)
list, err := Content.GetLibrary(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(len(list), ShouldEqual, 1)
So(list[0].Title, ShouldEqual, "Paid Content")
@@ -206,6 +210,7 @@ func (s *ContentTestSuite) Test_Library() {
func (s *ContentTestSuite) Test_Interact() {
Convey("Interact", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUserContentAction, models.TableNameUser)
// User & Content
@@ -218,7 +223,7 @@ func (s *ContentTestSuite) Test_Interact() {
Convey("Like flow", func() {
// Add Like
err := Content.AddLike(ctx, u.ID, c.ID)
err := Content.AddLike(ctx, tenantID, u.ID, c.ID)
So(err, ShouldBeNil)
// Verify count
@@ -226,13 +231,13 @@ func (s *ContentTestSuite) Test_Interact() {
So(cReload.Likes, ShouldEqual, 1)
// Get Likes
likes, err := Content.GetLikes(ctx, u.ID)
likes, err := Content.GetLikes(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(len(likes), ShouldEqual, 1)
So(likes[0].ID, ShouldEqual, c.ID)
// Remove Like
err = Content.RemoveLike(ctx, u.ID, c.ID)
err = Content.RemoveLike(ctx, tenantID, u.ID, c.ID)
So(err, ShouldBeNil)
// Verify count
@@ -242,21 +247,21 @@ func (s *ContentTestSuite) Test_Interact() {
Convey("Favorite flow", func() {
// Add Favorite
err := Content.AddFavorite(ctx, u.ID, c.ID)
err := Content.AddFavorite(ctx, tenantID, u.ID, c.ID)
So(err, ShouldBeNil)
// Get Favorites
favs, err := Content.GetFavorites(ctx, u.ID)
favs, err := Content.GetFavorites(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(len(favs), ShouldEqual, 1)
So(favs[0].ID, ShouldEqual, c.ID)
// Remove Favorite
err = Content.RemoveFavorite(ctx, u.ID, c.ID)
err = Content.RemoveFavorite(ctx, tenantID, u.ID, c.ID)
So(err, ShouldBeNil)
// Get Favorites
favs, err = Content.GetFavorites(ctx, u.ID)
favs, err = Content.GetFavorites(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(len(favs), ShouldEqual, 0)
})
@@ -266,6 +271,7 @@ func (s *ContentTestSuite) Test_Interact() {
func (s *ContentTestSuite) Test_ListTopics() {
Convey("ListTopics", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser)
u := &models.User{Username: "user_t", Phone: "13900000005"}
@@ -280,7 +286,7 @@ func (s *ContentTestSuite) Test_ListTopics() {
)
Convey("should aggregate topics", func() {
topics, err := Content.ListTopics(ctx)
topics, err := Content.ListTopics(ctx, tenantID)
So(err, ShouldBeNil)
So(len(topics), ShouldBeGreaterThanOrEqualTo, 2)
@@ -302,6 +308,7 @@ func (s *ContentTestSuite) Test_ListTopics() {
func (s *ContentTestSuite) Test_PreviewLogic() {
Convey("Preview Logic", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameContentAsset, models.TableNameContentAccess, models.TableNameUser, models.TableNameMediaAsset)
author := &models.User{Username: "author_p", Phone: "13900000006"}
@@ -324,7 +331,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
models.UserQuery.WithContext(ctx).Create(guest)
guestCtx := context.WithValue(ctx, consts.CtxKeyUser, guest.ID)
detail, err := Content.Get(guestCtx, 0, c.ID)
detail, err := Content.Get(guestCtx, tenantID, 0, c.ID)
So(err, ShouldBeNil)
So(len(detail.MediaUrls), ShouldEqual, 1)
So(detail.MediaUrls[0].URL, ShouldContainSubstring, "preview.mp4")
@@ -333,7 +340,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
Convey("owner should see all", func() {
ownerCtx := context.WithValue(ctx, consts.CtxKeyUser, author.ID)
detail, err := Content.Get(ownerCtx, author.ID, c.ID)
detail, err := Content.Get(ownerCtx, tenantID, author.ID, c.ID)
So(err, ShouldBeNil)
So(len(detail.MediaUrls), ShouldEqual, 2)
So(detail.IsPurchased, ShouldBeTrue)
@@ -348,7 +355,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
UserID: buyer.ID, ContentID: c.ID, Status: consts.ContentAccessStatusActive,
})
detail, err := Content.Get(buyerCtx, buyer.ID, c.ID)
detail, err := Content.Get(buyerCtx, tenantID, buyer.ID, c.ID)
So(err, ShouldBeNil)
So(len(detail.MediaUrls), ShouldEqual, 2)
So(detail.IsPurchased, ShouldBeTrue)
@@ -359,6 +366,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
func (s *ContentTestSuite) Test_ViewCounting() {
Convey("ViewCounting", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser)
author := &models.User{Username: "author_v", Phone: "13900000009"}
@@ -368,7 +376,7 @@ func (s *ContentTestSuite) Test_ViewCounting() {
models.ContentQuery.WithContext(ctx).Create(c)
Convey("should increment views", func() {
_, err := Content.Get(ctx, 0, c.ID)
_, err := Content.Get(ctx, tenantID, 0, c.ID)
So(err, ShouldBeNil)
cReload, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(c.ID)).First()

View File

@@ -39,6 +39,7 @@ func Test_Coupon(t *testing.T) {
func (s *CouponTestSuite) Test_CouponFlow() {
Convey("Coupon Flow", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(
ctx,
s.DB,
@@ -83,9 +84,10 @@ func (s *CouponTestSuite) Test_CouponFlow() {
Convey("should apply in Order.Create", func() {
// Setup Content
c := &models.Content{UserID: 99, Title: "Test", Status: consts.ContentStatusPublished}
c := &models.Content{TenantID: tenantID, UserID: 99, Title: "Test", Status: consts.ContentStatusPublished}
models.ContentQuery.WithContext(ctx).Create(c)
models.ContentPriceQuery.WithContext(ctx).Create(&models.ContentPrice{
TenantID: tenantID,
ContentID: c.ID,
PriceAmount: 2000, // 20.00 CNY
Currency: "CNY",
@@ -96,7 +98,7 @@ func (s *CouponTestSuite) Test_CouponFlow() {
UserCouponID: uc.ID,
}
// Simulate Auth context for Order service
res, err := Order.Create(ctx, user.ID, form)
res, err := Order.Create(ctx, tenantID, user.ID, form)
So(err, ShouldBeNil)
// Verify Order

View File

@@ -31,7 +31,7 @@ var genreMap = map[string]string{
"Qinqiang": "秦腔",
}
func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.ApplyForm) error {
func (s *creator) Apply(ctx context.Context, tenantID, userID int64, form *creator_dto.ApplyForm) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
@@ -72,8 +72,8 @@ func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.App
return nil
}
func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.DashboardStats, error) {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) Dashboard(ctx context.Context, tenantID, userID int64) (*creator_dto.DashboardStats, error) {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return nil, err
}
@@ -107,10 +107,11 @@ func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.Das
func (s *creator) ListContents(
ctx context.Context,
tenantID int64,
userID int64,
filter *creator_dto.CreatorContentListFilter,
) (*requests.Pager, error) {
tid, err := s.getTenantID(ctx, userID)
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return nil, err
}
@@ -248,8 +249,8 @@ func (s *creator) ListContents(
}, nil
}
func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator_dto.ContentCreateForm) error {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) CreateContent(ctx context.Context, tenantID, userID int64, form *creator_dto.ContentCreateForm) error {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -321,11 +322,12 @@ func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator
func (s *creator) UpdateContent(
ctx context.Context,
tenantID int64,
userID int64,
id int64,
form *creator_dto.ContentUpdateForm,
) error {
tid, err := s.getTenantID(ctx, userID)
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -451,8 +453,8 @@ func (s *creator) UpdateContent(
})
}
func (s *creator) DeleteContent(ctx context.Context, userID, id int64) error {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) DeleteContent(ctx context.Context, tenantID, userID, id int64) error {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -472,8 +474,8 @@ func (s *creator) DeleteContent(ctx context.Context, userID, id int64) error {
return nil
}
func (s *creator) GetContent(ctx context.Context, userID, id int64) (*creator_dto.ContentEditDTO, error) {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) GetContent(ctx context.Context, tenantID, userID, id int64) (*creator_dto.ContentEditDTO, error) {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return nil, err
}
@@ -548,10 +550,11 @@ func (s *creator) GetContent(ctx context.Context, userID, id int64) (*creator_dt
func (s *creator) ListOrders(
ctx context.Context,
tenantID int64,
userID int64,
filter *creator_dto.CreatorOrderListFilter,
) ([]creator_dto.Order, error) {
tid, err := s.getTenantID(ctx, userID)
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return nil, err
}
@@ -634,8 +637,8 @@ func (s *creator) ListOrders(
return data, nil
}
func (s *creator) ProcessRefund(ctx context.Context, userID, id int64, form *creator_dto.RefundForm) error {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) ProcessRefund(ctx context.Context, tenantID, userID, id int64, form *creator_dto.RefundForm) error {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -738,8 +741,8 @@ func (s *creator) ProcessRefund(ctx context.Context, userID, id int64, form *cre
return errorx.ErrBadRequest.WithMsg("无效的操作")
}
func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.Settings, error) {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) GetSettings(ctx context.Context, tenantID, userID int64) (*creator_dto.Settings, error) {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return nil, err
}
@@ -758,8 +761,8 @@ func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.S
}, nil
}
func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creator_dto.Settings) error {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) UpdateSettings(ctx context.Context, tenantID, userID int64, form *creator_dto.Settings) error {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -782,8 +785,8 @@ func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creato
return err
}
func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creator_dto.PayoutAccount, error) {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) ListPayoutAccounts(ctx context.Context, tenantID, userID int64) ([]creator_dto.PayoutAccount, error) {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return nil, err
}
@@ -806,8 +809,8 @@ func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creat
return data, nil
}
func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *creator_dto.PayoutAccount) error {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) AddPayoutAccount(ctx context.Context, tenantID, userID int64, form *creator_dto.PayoutAccount) error {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -827,8 +830,8 @@ func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *crea
return nil
}
func (s *creator) RemovePayoutAccount(ctx context.Context, userID, id int64) error {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) RemovePayoutAccount(ctx context.Context, tenantID, userID, id int64) error {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -842,8 +845,8 @@ func (s *creator) RemovePayoutAccount(ctx context.Context, userID, id int64) err
return nil
}
func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto.WithdrawForm) error {
tid, err := s.getTenantID(ctx, userID)
func (s *creator) Withdraw(ctx context.Context, tenantID, userID int64, form *creator_dto.WithdrawForm) error {
tid, err := s.getTenantID(ctx, tenantID, userID)
if err != nil {
return err
}
@@ -920,7 +923,7 @@ func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto.
// Helpers
func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error) {
func (s *creator) getTenantID(ctx context.Context, tenantID, userID int64) (int64, error) {
if userID == 0 {
return 0, errorx.ErrUnauthorized
}
@@ -934,5 +937,8 @@ func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error)
}
return 0, errorx.ErrDatabaseError.WithCause(err)
}
if tenantID > 0 && t.ID != tenantID {
return 0, errorx.ErrPermissionDenied.WithMsg("无权限访问该租户")
}
return t.ID, nil
}

View File

@@ -40,6 +40,7 @@ func Test_Creator(t *testing.T) {
func (s *CreatorTestSuite) Test_Apply() {
Convey("Apply", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameTenantUser, models.TableNameUser)
u := &models.User{Username: "creator1", Phone: "13700000001"}
@@ -50,7 +51,7 @@ func (s *CreatorTestSuite) Test_Apply() {
form := &creator_dto.ApplyForm{
Name: "My Channel",
}
err := Creator.Apply(ctx, u.ID, form)
err := Creator.Apply(ctx, tenantID, u.ID, form)
So(err, ShouldBeNil)
t, _ := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(u.ID)).First()
@@ -72,6 +73,7 @@ func (s *CreatorTestSuite) Test_Apply() {
func (s *CreatorTestSuite) Test_CreateContent() {
Convey("CreateContent", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(
ctx,
s.DB,
@@ -89,6 +91,7 @@ func (s *CreatorTestSuite) Test_CreateContent() {
// Create Tenant manually
t := &models.Tenant{UserID: u.ID, Name: "Channel 2", Code: "123", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
Convey("should create content and assets", func() {
form := &creator_dto.ContentCreateForm{
@@ -97,7 +100,7 @@ func (s *CreatorTestSuite) Test_CreateContent() {
Price: 9.99,
// MediaIDs: ... need media asset
}
err := Creator.CreateContent(ctx, u.ID, form)
err := Creator.CreateContent(ctx, tenantID, u.ID, form)
So(err, ShouldBeNil)
c, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.Title.Eq("New Song")).First()
@@ -116,6 +119,7 @@ func (s *CreatorTestSuite) Test_CreateContent() {
func (s *CreatorTestSuite) Test_UpdateContent() {
Convey("UpdateContent", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(
ctx,
s.DB,
@@ -132,6 +136,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() {
t := &models.Tenant{UserID: u.ID, Name: "Channel 3", Code: "124", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
c := &models.Content{TenantID: t.ID, UserID: u.ID, Title: "Old Title", Genre: "audio"}
models.ContentQuery.WithContext(ctx).Create(c)
@@ -145,7 +150,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() {
Genre: "video",
Price: &price,
}
err := Creator.UpdateContent(ctx, u.ID, c.ID, form)
err := Creator.UpdateContent(ctx, tenantID, u.ID, c.ID, form)
So(err, ShouldBeNil)
// Verify
@@ -162,6 +167,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() {
func (s *CreatorTestSuite) Test_Dashboard() {
Convey("Dashboard", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(
ctx,
s.DB,
@@ -178,6 +184,7 @@ func (s *CreatorTestSuite) Test_Dashboard() {
t := &models.Tenant{UserID: u.ID, Name: "Channel 4", Code: "125", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
// Mock Data
// 1. Followers
@@ -198,7 +205,7 @@ func (s *CreatorTestSuite) Test_Dashboard() {
)
Convey("should get stats", func() {
stats, err := Creator.Dashboard(ctx, u.ID)
stats, err := Creator.Dashboard(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(stats.TotalFollowers.Value, ShouldEqual, 2)
// Implementation sums 'debit_purchase' only based on my code
@@ -210,6 +217,7 @@ func (s *CreatorTestSuite) Test_Dashboard() {
func (s *CreatorTestSuite) Test_PayoutAccount() {
Convey("PayoutAccount", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNamePayoutAccount, models.TableNameUser)
u := &models.User{Username: "creator5", Phone: "13700000005"}
@@ -218,6 +226,7 @@ func (s *CreatorTestSuite) Test_PayoutAccount() {
t := &models.Tenant{UserID: u.ID, Name: "Channel 5", Code: "126", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
Convey("should CRUD payout account", func() {
// Add
@@ -227,21 +236,21 @@ func (s *CreatorTestSuite) Test_PayoutAccount() {
Account: "user@example.com",
Realname: "John Doe",
}
err := Creator.AddPayoutAccount(ctx, u.ID, form)
err := Creator.AddPayoutAccount(ctx, tenantID, u.ID, form)
So(err, ShouldBeNil)
// List
list, err := Creator.ListPayoutAccounts(ctx, u.ID)
list, err := Creator.ListPayoutAccounts(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(len(list), ShouldEqual, 1)
So(list[0].Account, ShouldEqual, "user@example.com")
// Remove
err = Creator.RemovePayoutAccount(ctx, u.ID, list[0].ID)
err = Creator.RemovePayoutAccount(ctx, tenantID, u.ID, list[0].ID)
So(err, ShouldBeNil)
// Verify Empty
list, err = Creator.ListPayoutAccounts(ctx, u.ID)
list, err = Creator.ListPayoutAccounts(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(len(list), ShouldEqual, 0)
})
@@ -251,6 +260,7 @@ func (s *CreatorTestSuite) Test_PayoutAccount() {
func (s *CreatorTestSuite) Test_Withdraw() {
Convey("Withdraw", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(
ctx,
s.DB,
@@ -267,6 +277,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
t := &models.Tenant{UserID: u.ID, Name: "Channel 6", Code: "127", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
pa := &models.PayoutAccount{
TenantID: t.ID,
@@ -283,7 +294,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
Amount: 20.00,
AccountID: pa.ID,
}
err := Creator.Withdraw(ctx, u.ID, form)
err := Creator.Withdraw(ctx, tenantID, u.ID, form)
So(err, ShouldBeNil)
// Verify Balance Deducted
@@ -308,7 +319,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
Amount: 100.00,
AccountID: pa.ID,
}
err := Creator.Withdraw(ctx, u.ID, form)
err := Creator.Withdraw(ctx, tenantID, u.ID, form)
So(err, ShouldNotBeNil)
})
})
@@ -317,6 +328,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
func (s *CreatorTestSuite) Test_Refund() {
Convey("Refund", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(ctx, s.DB,
models.TableNameTenant, models.TableNameUser, models.TableNameOrder,
models.TableNameOrderItem, models.TableNameContentAccess, models.TableNameTenantLedger,
@@ -330,6 +342,7 @@ func (s *CreatorTestSuite) Test_Refund() {
// Tenant
t := &models.Tenant{UserID: creator.ID, Name: "Channel 7", Code: "128", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
// Buyer
buyer := &models.User{Username: "buyer7", Phone: "13900000007", Balance: 0}
@@ -349,7 +362,7 @@ func (s *CreatorTestSuite) Test_Refund() {
Convey("should accept refund", func() {
form := &creator_dto.RefundForm{Action: "accept", Reason: "Defective"}
err := Creator.ProcessRefund(ctx, creator.ID, o.ID, form)
err := Creator.ProcessRefund(ctx, tenantID, creator.ID, o.ID, form)
So(err, ShouldBeNil)
// Verify Order

View File

@@ -21,14 +21,19 @@ import (
// @provider
type order struct{}
func (s *order) ListUserOrders(ctx context.Context, userID int64, status string) ([]user_dto.Order, error) {
func (s *order) ListUserOrders(ctx context.Context, tenantID, userID int64, status string) ([]user_dto.Order, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := userID
tbl, q := models.OrderQuery.QueryContext(ctx)
q = q.Where(tbl.UserID.Eq(uid))
if tenantID > 0 {
q = q.Where(tbl.UserID.Eq(uid), tbl.TenantID.Eq(tenantID)).
Or(tbl.UserID.Eq(uid), tbl.Type.Eq(consts.OrderTypeRecharge))
} else {
q = q.Where(tbl.UserID.Eq(uid))
}
if status != "" && status != "all" {
q = q.Where(tbl.Status.Eq(consts.OrderStatus(status)))
@@ -46,14 +51,21 @@ func (s *order) ListUserOrders(ctx context.Context, userID int64, status string)
return data, nil
}
func (s *order) GetUserOrder(ctx context.Context, userID, id int64) (*user_dto.Order, error) {
func (s *order) GetUserOrder(ctx context.Context, tenantID, userID, id int64) (*user_dto.Order, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := userID
tbl, q := models.OrderQuery.QueryContext(ctx)
item, err := q.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid)).First()
itemQuery := q
if tenantID > 0 {
itemQuery = itemQuery.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid), tbl.TenantID.Eq(tenantID)).
Or(tbl.ID.Eq(id), tbl.UserID.Eq(uid), tbl.Type.Eq(consts.OrderTypeRecharge))
} else {
itemQuery = itemQuery.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid))
}
item, err := itemQuery.First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.ErrRecordNotFound
@@ -70,6 +82,7 @@ func (s *order) GetUserOrder(ctx context.Context, userID, id int64) (*user_dto.O
func (s *order) Create(
ctx context.Context,
tenantID int64,
userID int64,
form *transaction_dto.OrderCreateForm,
) (*transaction_dto.OrderCreateResponse, error) {
@@ -86,7 +99,11 @@ func (s *order) Create(
}
if idempotencyKey != "" {
tbl, q := models.OrderQuery.QueryContext(ctx)
existing, err := q.Where(tbl.UserID.Eq(uid), tbl.IdempotencyKey.Eq(idempotencyKey)).First()
q = q.Where(tbl.UserID.Eq(uid), tbl.IdempotencyKey.Eq(idempotencyKey))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
existing, err := q.First()
if err == nil {
return &transaction_dto.OrderCreateResponse{OrderID: existing.ID}, nil
}
@@ -96,7 +113,11 @@ func (s *order) Create(
}
// 1. Fetch Content & Price
content, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid)).First()
contentQuery := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid))
if tenantID > 0 {
contentQuery = contentQuery.Where(models.ContentQuery.TenantID.Eq(tenantID))
}
content, err := contentQuery.First()
if err != nil {
return nil, errorx.ErrRecordNotFound.WithMsg("内容不存在")
}
@@ -188,6 +209,7 @@ func (s *order) Create(
func (s *order) Pay(
ctx context.Context,
tenantID int64,
userID int64,
id int64,
form *transaction_dto.OrderPayForm,
@@ -204,6 +226,9 @@ func (s *order) Pay(
if err != nil {
return nil, errorx.ErrRecordNotFound
}
if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID {
return nil, errorx.ErrForbidden.WithMsg("租户不匹配")
}
if o.Status != consts.OrderStatusCreated {
return nil, errorx.ErrStatusConflict.WithMsg("订单状态不可支付")
}
@@ -219,11 +244,14 @@ func (s *order) Pay(
}
// ProcessExternalPayment handles callback from payment gateway
func (s *order) ProcessExternalPayment(ctx context.Context, orderID int64, externalID string) error {
func (s *order) ProcessExternalPayment(ctx context.Context, tenantID, orderID int64, externalID string) error {
o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(orderID)).First()
if err != nil {
return errorx.ErrRecordNotFound
}
if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID {
return errorx.ErrForbidden.WithMsg("租户不匹配")
}
if o.Status != consts.OrderStatusCreated {
return nil // Already processed idempotency
}
@@ -365,7 +393,7 @@ func (s *order) settleOrder(ctx context.Context, o *models.Order, method, extern
return nil
}
func (s *order) Status(ctx context.Context, id int64) (*transaction_dto.OrderStatusResponse, error) {
func (s *order) Status(ctx context.Context, tenantID, id int64) (*transaction_dto.OrderStatusResponse, error) {
o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(id)).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -373,6 +401,9 @@ func (s *order) Status(ctx context.Context, id int64) (*transaction_dto.OrderSta
}
return nil, errorx.ErrDatabaseError.WithCause(err)
}
if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID {
return nil, errorx.ErrForbidden.WithMsg("租户不匹配")
}
return &transaction_dto.OrderStatusResponse{
Status: string(o.Status),

View File

@@ -39,6 +39,7 @@ func Test_Order(t *testing.T) {
func (s *OrderTestSuite) Test_PurchaseFlow() {
Convey("Purchase Flow", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(ctx, s.DB,
models.TableNameOrder, models.TableNameOrderItem, models.TableNameUser,
models.TableNameContent, models.TableNameContentPrice, models.TableNameTenant,
@@ -57,6 +58,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
Status: consts.TenantStatusVerified,
}
models.TenantQuery.WithContext(ctx).Create(tenant)
tenantID = tenant.ID
// Content
content := &models.Content{
TenantID: tenant.ID,
@@ -83,7 +85,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
Convey("should create and pay order successfully", func() {
// Step 1: Create Order
form := &order_dto.OrderCreateForm{ContentID: content.ID}
createRes, err := Order.Create(ctx, buyer.ID, form)
createRes, err := Order.Create(ctx, tenantID, buyer.ID, form)
So(err, ShouldBeNil)
So(createRes.OrderID, ShouldNotBeEmpty)
@@ -95,7 +97,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
// Step 2: Pay Order
payForm := &order_dto.OrderPayForm{Method: "balance"}
_, err = Order.Pay(ctx, buyer.ID, createRes.OrderID, payForm)
_, err = Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, payForm)
So(err, ShouldBeNil)
// Verify Order Paid
@@ -130,11 +132,11 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
Update(models.UserQuery.Balance, 500)
form := &order_dto.OrderCreateForm{ContentID: content.ID}
createRes, err := Order.Create(ctx, buyer.ID, form)
createRes, err := Order.Create(ctx, tenantID, buyer.ID, form)
So(err, ShouldBeNil)
payForm := &order_dto.OrderPayForm{Method: "balance"}
_, err = Order.Pay(ctx, buyer.ID, createRes.OrderID, payForm)
_, err = Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, payForm)
So(err, ShouldNotBeNil)
// Error should be QuotaExceeded or similar
})
@@ -144,6 +146,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
func (s *OrderTestSuite) Test_OrderDetails() {
Convey("Order Details", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(
ctx,
s.DB,
@@ -164,6 +167,7 @@ func (s *OrderTestSuite) Test_OrderDetails() {
models.UserQuery.WithContext(ctx).Create(creator)
tenant := &models.Tenant{UserID: creator.ID, Name: "Best Shop", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(tenant)
tenantID = tenant.ID
content := &models.Content{
TenantID: tenant.ID,
UserID: creator.ID,
@@ -199,13 +203,14 @@ func (s *OrderTestSuite) Test_OrderDetails() {
// Create & Pay
createRes, _ := Order.Create(
ctx,
tenantID,
buyer.ID,
&order_dto.OrderCreateForm{ContentID: content.ID},
)
Order.Pay(ctx, buyer.ID, createRes.OrderID, &order_dto.OrderPayForm{Method: "balance"})
Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, &order_dto.OrderPayForm{Method: "balance"})
// Get Detail
detail, err := Order.GetUserOrder(ctx, buyer.ID, createRes.OrderID)
detail, err := Order.GetUserOrder(ctx, tenantID, buyer.ID, createRes.OrderID)
So(err, ShouldBeNil)
So(detail.TenantName, ShouldEqual, "Best Shop")
So(len(detail.Items), ShouldEqual, 1)
@@ -219,6 +224,7 @@ func (s *OrderTestSuite) Test_OrderDetails() {
func (s *OrderTestSuite) Test_PlatformCommission() {
Convey("Platform Commission", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(
ctx,
s.DB,
@@ -236,6 +242,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() {
// Tenant
t := &models.Tenant{UserID: creator.ID, Name: "Shop C", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
// Buyer
buyer := &models.User{Username: "buyer_c", Balance: 2000}
models.UserQuery.WithContext(ctx).Create(buyer)
@@ -253,7 +260,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() {
Convey("should deduct 10% fee", func() {
payForm := &order_dto.OrderPayForm{Method: "balance"}
_, err := Order.Pay(ctx, buyer.ID, o.ID, payForm)
_, err := Order.Pay(ctx, tenantID, buyer.ID, o.ID, payForm)
So(err, ShouldBeNil)
// Verify Creator Balance (1000 - 10% = 900)
@@ -270,6 +277,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() {
func (s *OrderTestSuite) Test_ExternalPayment() {
Convey("External Payment", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(
ctx,
s.DB,
@@ -287,6 +295,7 @@ func (s *OrderTestSuite) Test_ExternalPayment() {
// Tenant
t := &models.Tenant{UserID: creator.ID, Name: "Shop Ext", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
// Buyer (Balance 0)
buyer := &models.User{Username: "buyer_ext", Balance: 0}
models.UserQuery.WithContext(ctx).Create(buyer)
@@ -302,7 +311,7 @@ func (s *OrderTestSuite) Test_ExternalPayment() {
models.OrderItemQuery.WithContext(ctx).Create(&models.OrderItem{OrderID: o.ID, ContentID: 999})
Convey("should process external payment callback", func() {
err := Order.ProcessExternalPayment(ctx, o.ID, "ext_tx_id_123")
err := Order.ProcessExternalPayment(ctx, tenantID, o.ID, "ext_tx_id_123")
So(err, ShouldBeNil)
// Verify Status

View File

@@ -704,7 +704,7 @@ func (s *super) RefundOrder(ctx context.Context, id int64, form *super_dto.Super
return errorx.ErrRecordNotFound.WithMsg("租户不存在")
}
return Creator.ProcessRefund(ctx, t.UserID, id, &v1_dto.RefundForm{
return Creator.ProcessRefund(ctx, t.ID, t.UserID, id, &v1_dto.RefundForm{
Action: "accept",
Reason: form.Reason,
})

View File

@@ -17,9 +17,12 @@ import (
// @provider
type tenant struct{}
func (s *tenant) List(ctx context.Context, filter *dto.TenantListFilter) (*requests.Pager, error) {
func (s *tenant) List(ctx context.Context, tenantID int64, filter *dto.TenantListFilter) (*requests.Pager, error) {
tbl, q := models.TenantQuery.QueryContext(ctx)
q = q.Where(tbl.Status.Eq(consts.TenantStatusVerified))
if tenantID > 0 {
q = q.Where(tbl.ID.Eq(tenantID))
}
if filter.Keyword != nil && *filter.Keyword != "" {
q = q.Where(tbl.Name.Like("%" + *filter.Keyword + "%"))
@@ -73,8 +76,8 @@ func (s *tenant) List(ctx context.Context, filter *dto.TenantListFilter) (*reque
}, nil
}
func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.TenantProfile, error) {
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(id)).First()
func (s *tenant) GetPublicProfile(ctx context.Context, tenantID, userID int64) (*dto.TenantProfile, error) {
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tenantID)).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errorx.ErrRecordNotFound
@@ -83,9 +86,9 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T
}
// Stats
followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(id)).Count()
followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(tenantID)).Count()
contents, _ := models.ContentQuery.WithContext(ctx).
Where(models.ContentQuery.TenantID.Eq(id), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
Where(models.ContentQuery.TenantID.Eq(tenantID), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
Count()
// Following status
@@ -93,7 +96,7 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T
if userID > 0 {
uid := userID
isFollowing, _ = models.TenantUserQuery.WithContext(ctx).
Where(models.TenantUserQuery.TenantID.Eq(id), models.TenantUserQuery.UserID.Eq(uid)).
Where(models.TenantUserQuery.TenantID.Eq(tenantID), models.TenantUserQuery.UserID.Eq(uid)).
Exists()
}
@@ -113,20 +116,20 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T
}, nil
}
func (s *tenant) Follow(ctx context.Context, userID, id int64) error {
func (s *tenant) Follow(ctx context.Context, tenantID, userID int64) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := userID
// Check if tenant exists
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(id)).First()
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tenantID)).First()
if err != nil {
return errorx.ErrRecordNotFound
}
tu := &models.TenantUser{
TenantID: id,
TenantID: tenantID,
UserID: uid,
Role: types.Array[consts.TenantUserRole]{consts.TenantUserRoleMember},
Status: consts.UserStatusVerified,
@@ -142,14 +145,14 @@ func (s *tenant) Follow(ctx context.Context, userID, id int64) error {
return nil
}
func (s *tenant) Unfollow(ctx context.Context, userID, id int64) error {
func (s *tenant) Unfollow(ctx context.Context, tenantID, userID int64) error {
if userID == 0 {
return errorx.ErrUnauthorized
}
uid := userID
_, err := models.TenantUserQuery.WithContext(ctx).
Where(models.TenantUserQuery.TenantID.Eq(id), models.TenantUserQuery.UserID.Eq(uid)).
Where(models.TenantUserQuery.TenantID.Eq(tenantID), models.TenantUserQuery.UserID.Eq(uid)).
Delete()
if err != nil {
return errorx.ErrDatabaseError.WithCause(err)
@@ -157,14 +160,18 @@ func (s *tenant) Unfollow(ctx context.Context, userID, id int64) error {
return nil
}
func (s *tenant) ListFollowed(ctx context.Context, userID int64) ([]dto.TenantProfile, error) {
func (s *tenant) ListFollowed(ctx context.Context, tenantID, userID int64) ([]dto.TenantProfile, error) {
if userID == 0 {
return nil, errorx.ErrUnauthorized
}
uid := userID
tbl, q := models.TenantUserQuery.QueryContext(ctx)
list, err := q.Where(tbl.UserID.Eq(uid)).Find()
q = q.Where(tbl.UserID.Eq(uid))
if tenantID > 0 {
q = q.Where(tbl.TenantID.Eq(tenantID))
}
list, err := q.Find()
if err != nil {
return nil, errorx.ErrDatabaseError.WithCause(err)
}

View File

@@ -39,6 +39,7 @@ func Test_Tenant(t *testing.T) {
func (s *TenantTestSuite) Test_Follow() {
Convey("Follow Flow", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(0)
database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameTenantUser, models.TableNameUser)
// User
@@ -49,29 +50,30 @@ func (s *TenantTestSuite) Test_Follow() {
// Tenant
t := &models.Tenant{Name: "Tenant A", Status: consts.TenantStatusVerified}
models.TenantQuery.WithContext(ctx).Create(t)
tenantID = t.ID
Convey("should follow tenant", func() {
err := Tenant.Follow(ctx, u.ID, t.ID)
err := Tenant.Follow(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
// Verify stats
profile, err := Tenant.GetPublicProfile(ctx, u.ID, t.ID)
profile, err := Tenant.GetPublicProfile(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(profile.IsFollowing, ShouldBeTrue)
So(profile.Stats.Followers, ShouldEqual, 1)
// List Followed
list, err := Tenant.ListFollowed(ctx, u.ID)
list, err := Tenant.ListFollowed(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(len(list), ShouldEqual, 1)
So(list[0].Name, ShouldEqual, "Tenant A")
// Unfollow
err = Tenant.Unfollow(ctx, u.ID, t.ID)
err = Tenant.Unfollow(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
// Verify
profile, err = Tenant.GetPublicProfile(ctx, u.ID, t.ID)
profile, err = Tenant.GetPublicProfile(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(profile.IsFollowing, ShouldBeFalse)
So(profile.Stats.Followers, ShouldEqual, 0)

View File

@@ -31,7 +31,7 @@ func (s *user) SendOTP(ctx context.Context, phone string) error {
}
// LoginWithOTP 手机号验证码登录/注册
func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.LoginResponse, error) {
func (s *user) LoginWithOTP(ctx context.Context, tenantID int64, phone, otp string) (*auth_dto.LoginResponse, error) {
// 1. 校验验证码 (模拟:固定 123456)
if otp != "1234" {
return nil, errorx.ErrInvalidCredentials.WithMsg("验证码错误")
@@ -67,8 +67,8 @@ func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.L
// 4. 生成 Token
token, err := s.jwt.CreateToken(s.jwt.CreateClaims(jwt.BaseClaims{
UserID: u.ID,
// TenantID: 0, // 初始登录无租户上下文
UserID: u.ID,
TenantID: tenantID,
}))
if err != nil {
return nil, errorx.ErrInternalError.WithMsg("生成令牌失败")

View File

@@ -40,11 +40,12 @@ func Test_User(t *testing.T) {
func (s *UserTestSuite) Test_LoginWithOTP() {
Convey("LoginWithOTP", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameUser)
Convey("should create user and login success with correct OTP", func() {
phone := "13800138000"
resp, err := User.LoginWithOTP(ctx, phone, "1234")
resp, err := User.LoginWithOTP(ctx, tenantID, phone, "1234")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.Token, ShouldNotBeEmpty)
@@ -55,17 +56,17 @@ func (s *UserTestSuite) Test_LoginWithOTP() {
Convey("should login existing user", func() {
phone := "13800138001"
// Pre-create user
_, err := User.LoginWithOTP(ctx, phone, "1234")
_, err := User.LoginWithOTP(ctx, tenantID, phone, "1234")
So(err, ShouldBeNil)
// Login again
resp, err := User.LoginWithOTP(ctx, phone, "1234")
resp, err := User.LoginWithOTP(ctx, tenantID, phone, "1234")
So(err, ShouldBeNil)
So(resp.User.Phone, ShouldEqual, phone)
})
Convey("should fail with incorrect OTP", func() {
resp, err := User.LoginWithOTP(ctx, "13800138002", "000000")
resp, err := User.LoginWithOTP(ctx, tenantID, "13800138002", "000000")
So(err, ShouldNotBeNil)
So(resp, ShouldBeNil)
})
@@ -75,11 +76,12 @@ func (s *UserTestSuite) Test_LoginWithOTP() {
func (s *UserTestSuite) Test_Me() {
Convey("Me", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameUser)
// Create user
phone := "13800138003"
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
userID := resp.User.ID
Convey("should return user profile", func() {
@@ -104,10 +106,11 @@ func (s *UserTestSuite) Test_Me() {
func (s *UserTestSuite) Test_Update() {
Convey("Update", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameUser)
phone := "13800138004"
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
userID := resp.User.ID
ctx = context.WithValue(ctx, consts.CtxKeyUser, userID)
@@ -132,10 +135,11 @@ func (s *UserTestSuite) Test_Update() {
func (s *UserTestSuite) Test_RealName() {
Convey("RealName", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameUser)
phone := "13800138005"
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
userID := resp.User.ID
ctx = context.WithValue(ctx, consts.CtxKeyUser, userID)
@@ -157,10 +161,11 @@ func (s *UserTestSuite) Test_RealName() {
func (s *UserTestSuite) Test_GetNotifications() {
Convey("GetNotifications", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameNotification)
phone := "13800138006"
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
userID := resp.User.ID
ctx = context.WithValue(ctx, consts.CtxKeyUser, userID)

View File

@@ -20,7 +20,7 @@ import (
// @provider
type wallet struct{}
func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletResponse, error) {
func (s *wallet) GetWallet(ctx context.Context, tenantID, userID int64) (*user_dto.WalletResponse, error) {
// Get Balance
u, err := models.UserQuery.WithContext(ctx).Where(models.UserQuery.ID.Eq(userID)).First()
if err != nil {
@@ -33,7 +33,13 @@ func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletR
// Get Transactions (Orders)
// Both purchase (expense) and recharge (income - if paid)
tbl, q := models.OrderQuery.QueryContext(ctx)
orders, err := q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid)).
if tenantID > 0 {
q = q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid), tbl.TenantID.Eq(tenantID)).
Or(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid), tbl.Type.Eq(consts.OrderTypeRecharge))
} else {
q = q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid))
}
orders, err := q.
Order(tbl.CreatedAt.Desc()).
Limit(20). // Limit to recent 20
Find()
@@ -71,6 +77,7 @@ func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletR
func (s *wallet) Recharge(
ctx context.Context,
tenantID int64,
userID int64,
form *user_dto.RechargeForm,
) (*user_dto.RechargeResponse, error) {
@@ -98,7 +105,7 @@ func (s *wallet) Recharge(
// MOCK: Automatically pay for recharge order to close the loop
// In production, this would be a callback from payment gateway
if err := Order.ProcessExternalPayment(ctx, order.ID, "mock_auto_pay"); err != nil {
if err := Order.ProcessExternalPayment(ctx, tenantID, order.ID, "mock_auto_pay"); err != nil {
return nil, err
}

View File

@@ -40,6 +40,7 @@ func Test_Wallet(t *testing.T) {
func (s *WalletTestSuite) Test_GetWallet() {
Convey("GetWallet", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameOrder)
u := &models.User{Username: "wallet_user", Balance: 5000} // 50.00
@@ -58,7 +59,7 @@ func (s *WalletTestSuite) Test_GetWallet() {
models.OrderQuery.WithContext(ctx).Create(o1, o2)
Convey("should return balance and transactions", func() {
res, err := Wallet.GetWallet(ctx, u.ID)
res, err := Wallet.GetWallet(ctx, tenantID, u.ID)
So(err, ShouldBeNil)
So(res.Balance, ShouldEqual, 50.0)
So(len(res.Transactions), ShouldEqual, 2)
@@ -74,6 +75,7 @@ func (s *WalletTestSuite) Test_GetWallet() {
func (s *WalletTestSuite) Test_Recharge() {
Convey("Recharge", s.T(), func() {
ctx := s.T().Context()
tenantID := int64(1)
database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameOrder)
u := &models.User{Username: "recharge_user"}
@@ -82,7 +84,7 @@ func (s *WalletTestSuite) Test_Recharge() {
Convey("should create recharge order", func() {
form := &user_dto.RechargeForm{Amount: 100.0}
res, err := Wallet.Recharge(ctx, u.ID, form)
res, err := Wallet.Recharge(ctx, tenantID, u.ID, form)
So(err, ShouldBeNil)
So(res.OrderID, ShouldNotBeEmpty)