feat: tenant-scoped routing and portal navigation
This commit is contained in:
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgconn"
|
||||
"go.ipao.vip/gen/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// @provider
|
||||
@@ -49,7 +50,7 @@ func (s *common) Options(ctx context.Context) (*common_dto.OptionsResponse, erro
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *common) CheckHash(ctx context.Context, userID int64, hash string) (*common_dto.UploadResult, error) {
|
||||
func (s *common) CheckHash(ctx context.Context, tenantID, userID int64, hash string) (*common_dto.UploadResult, error) {
|
||||
existing, err := models.MediaAssetQuery.WithContext(ctx).Where(models.MediaAssetQuery.Hash.Eq(hash)).First()
|
||||
if err != nil {
|
||||
return nil, nil // Not found, proceed to upload
|
||||
@@ -58,18 +59,25 @@ func (s *common) CheckHash(ctx context.Context, userID int64, hash string) (*com
|
||||
// Found existing file (Global deduplication hit)
|
||||
|
||||
// Check if user already has it (Logic deduplication hit)
|
||||
myExisting, err := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)).
|
||||
First()
|
||||
myQuery := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID))
|
||||
if tenantID > 0 {
|
||||
myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
myExisting, err := myQuery.First()
|
||||
if err == nil {
|
||||
return s.composeUploadResult(myExisting), nil
|
||||
}
|
||||
|
||||
// Create new record for this user reusing existing ObjectKey
|
||||
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First()
|
||||
var tid int64 = 0
|
||||
if err == nil {
|
||||
tid = t.ID
|
||||
// 优先使用路径租户,避免跨租户写入。
|
||||
tenant, err := s.resolveTenant(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tid int64
|
||||
if tenant != nil {
|
||||
tid = tenant.ID
|
||||
}
|
||||
|
||||
asset := &models.MediaAsset{
|
||||
@@ -107,13 +115,47 @@ func (s *common) buildObjectKey(tenant *models.Tenant, hash, filename string) st
|
||||
return path.Join("quyun", tenantUUID, hash+ext)
|
||||
}
|
||||
|
||||
func (s *common) InitUpload(ctx context.Context, userID int64, form *common_dto.UploadInitForm) (*common_dto.UploadInitResponse, error) {
|
||||
func (s *common) resolveTenant(ctx context.Context, tenantID, userID int64) (*models.Tenant, error) {
|
||||
if tenantID > 0 {
|
||||
tbl, q := models.TenantQuery.QueryContext(ctx)
|
||||
tenant, err := q.Where(tbl.ID.Eq(tenantID)).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errorx.ErrRecordNotFound.WithMsg("租户不存在")
|
||||
}
|
||||
return nil, errorx.ErrDatabaseError.WithCause(err)
|
||||
}
|
||||
return tenant, nil
|
||||
}
|
||||
if userID == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
tbl, q := models.TenantQuery.QueryContext(ctx)
|
||||
tenant, err := q.Where(tbl.UserID.Eq(userID)).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errorx.ErrDatabaseError.WithCause(err)
|
||||
}
|
||||
return tenant, nil
|
||||
}
|
||||
|
||||
func (s *common) uploadTempDir(localPath string, tenantID int64, uploadID string) string {
|
||||
tenantKey := "public"
|
||||
if tenantID > 0 {
|
||||
tenantKey = strconv.FormatInt(tenantID, 10)
|
||||
}
|
||||
return filepath.Join(localPath, "temp", tenantKey, uploadID)
|
||||
}
|
||||
|
||||
func (s *common) InitUpload(ctx context.Context, tenantID, userID int64, form *common_dto.UploadInitForm) (*common_dto.UploadInitResponse, error) {
|
||||
uploadID := uuid.NewString()
|
||||
localPath := s.storage.Config.LocalPath
|
||||
if localPath == "" {
|
||||
localPath = "./storage"
|
||||
}
|
||||
tempDir := filepath.Join(localPath, "temp", uploadID)
|
||||
tempDir := s.uploadTempDir(localPath, tenantID, uploadID)
|
||||
if err := os.MkdirAll(tempDir, 0o755); err != nil {
|
||||
return nil, errorx.ErrInternalError.WithCause(err)
|
||||
}
|
||||
@@ -134,12 +176,12 @@ func (s *common) InitUpload(ctx context.Context, userID int64, form *common_dto.
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *common) UploadPart(ctx context.Context, userID int64, file *multipart.FileHeader, form *common_dto.UploadPartForm) error {
|
||||
func (s *common) UploadPart(ctx context.Context, tenantID, userID int64, file *multipart.FileHeader, form *common_dto.UploadPartForm) error {
|
||||
localPath := s.storage.Config.LocalPath
|
||||
if localPath == "" {
|
||||
localPath = "./storage"
|
||||
}
|
||||
partPath := filepath.Join(localPath, "temp", form.UploadID, strconv.Itoa(form.PartNumber))
|
||||
partPath := filepath.Join(s.uploadTempDir(localPath, tenantID, form.UploadID), strconv.Itoa(form.PartNumber))
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
@@ -159,12 +201,12 @@ func (s *common) UploadPart(ctx context.Context, userID int64, file *multipart.F
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_dto.UploadCompleteForm) (*common_dto.UploadResult, error) {
|
||||
func (s *common) CompleteUpload(ctx context.Context, tenantID, userID int64, form *common_dto.UploadCompleteForm) (*common_dto.UploadResult, error) {
|
||||
localPath := s.storage.Config.LocalPath
|
||||
if localPath == "" {
|
||||
localPath = "./storage"
|
||||
}
|
||||
tempDir := filepath.Join(localPath, "temp", form.UploadID)
|
||||
tempDir := s.uploadTempDir(localPath, tenantID, form.UploadID)
|
||||
|
||||
// Read Meta
|
||||
var meta UploadMeta
|
||||
@@ -220,21 +262,27 @@ func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_
|
||||
dst.Close() // Ensure flush before potential removal
|
||||
|
||||
// Deduplication Logic (Similar to Upload)
|
||||
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First()
|
||||
var tid int64 = 0
|
||||
if err == nil {
|
||||
tid = t.ID
|
||||
tenant, err := s.resolveTenant(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tid int64
|
||||
if tenant != nil {
|
||||
tid = tenant.ID
|
||||
}
|
||||
|
||||
objectKey := s.buildObjectKey(t, hash, meta.Filename)
|
||||
objectKey := s.buildObjectKey(tenant, hash, meta.Filename)
|
||||
existing, err := models.MediaAssetQuery.WithContext(ctx).Where(models.MediaAssetQuery.Hash.Eq(hash)).First()
|
||||
var asset *models.MediaAsset
|
||||
|
||||
if err == nil {
|
||||
os.Remove(mergedPath) // Delete duplicate
|
||||
myExisting, err := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)).
|
||||
First()
|
||||
myQuery := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID))
|
||||
if tenantID > 0 {
|
||||
myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
myExisting, err := myQuery.First()
|
||||
if err == nil {
|
||||
os.RemoveAll(tempDir)
|
||||
return s.composeUploadResult(myExisting), nil
|
||||
@@ -282,10 +330,13 @@ func (s *common) CompleteUpload(ctx context.Context, userID int64, form *common_
|
||||
return s.composeUploadResult(asset), nil
|
||||
}
|
||||
|
||||
func (s *common) DeleteMediaAsset(ctx context.Context, userID, id int64) error {
|
||||
asset, err := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.ID.Eq(id), models.MediaAssetQuery.UserID.Eq(userID)).
|
||||
First()
|
||||
func (s *common) DeleteMediaAsset(ctx context.Context, tenantID, userID, id int64) error {
|
||||
query := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.ID.Eq(id), models.MediaAssetQuery.UserID.Eq(userID))
|
||||
if tenantID > 0 {
|
||||
query = query.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
asset, err := query.First()
|
||||
if err != nil {
|
||||
return errorx.ErrRecordNotFound
|
||||
}
|
||||
@@ -308,17 +359,18 @@ func (s *common) DeleteMediaAsset(ctx context.Context, userID, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *common) AbortUpload(ctx context.Context, userID int64, uploadId string) error {
|
||||
func (s *common) AbortUpload(ctx context.Context, tenantID, userID int64, uploadId string) error {
|
||||
localPath := s.storage.Config.LocalPath
|
||||
if localPath == "" {
|
||||
localPath = "./storage"
|
||||
}
|
||||
tempDir := filepath.Join(localPath, "temp", uploadId)
|
||||
tempDir := s.uploadTempDir(localPath, tenantID, uploadId)
|
||||
return os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
func (s *common) Upload(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
file *multipart.FileHeader,
|
||||
typeArg string,
|
||||
@@ -357,13 +409,16 @@ func (s *common) Upload(
|
||||
|
||||
hash := hex.EncodeToString(hasher.Sum(nil))
|
||||
|
||||
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(userID)).First()
|
||||
var tid int64 = 0
|
||||
if err == nil {
|
||||
tid = t.ID
|
||||
tenant, err := s.resolveTenant(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tid int64
|
||||
if tenant != nil {
|
||||
tid = tenant.ID
|
||||
}
|
||||
|
||||
objectKey := s.buildObjectKey(t, hash, file.Filename)
|
||||
objectKey := s.buildObjectKey(tenant, hash, file.Filename)
|
||||
var asset *models.MediaAsset
|
||||
|
||||
// Deduplication Check
|
||||
@@ -374,9 +429,12 @@ func (s *common) Upload(
|
||||
os.RemoveAll(tmpDir)
|
||||
|
||||
// Check if user already has it (Logic Deduplication)
|
||||
myExisting, err := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID)).
|
||||
First()
|
||||
myQuery := models.MediaAssetQuery.WithContext(ctx).
|
||||
Where(models.MediaAssetQuery.Hash.Eq(hash), models.MediaAssetQuery.UserID.Eq(userID))
|
||||
if tenantID > 0 {
|
||||
myQuery = myQuery.Where(models.MediaAssetQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
myExisting, err := myQuery.First()
|
||||
if err == nil {
|
||||
return s.composeUploadResult(myExisting), nil
|
||||
}
|
||||
|
||||
@@ -18,11 +18,14 @@ import (
|
||||
// @provider
|
||||
type content struct{}
|
||||
|
||||
func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilter) (*requests.Pager, error) {
|
||||
func (s *content) List(ctx context.Context, tenantID int64, filter *content_dto.ContentListFilter) (*requests.Pager, error) {
|
||||
tbl, q := models.ContentQuery.QueryContext(ctx)
|
||||
|
||||
// Filters
|
||||
q = q.Where(tbl.Status.Eq(consts.ContentStatusPublished))
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.TenantID.Eq(tenantID))
|
||||
}
|
||||
if filter.Keyword != nil && *filter.Keyword != "" {
|
||||
keyword := "%" + *filter.Keyword + "%"
|
||||
q = q.Where(tbl.Title.Like(keyword)).Or(tbl.Description.Like(keyword))
|
||||
@@ -31,6 +34,9 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte
|
||||
q = q.Where(tbl.Genre.Eq(*filter.Genre))
|
||||
}
|
||||
if filter.TenantID != nil && *filter.TenantID > 0 {
|
||||
if tenantID > 0 && *filter.TenantID != tenantID {
|
||||
return nil, errorx.ErrForbidden.WithMsg("租户不匹配")
|
||||
}
|
||||
q = q.Where(tbl.TenantID.Eq(*filter.TenantID))
|
||||
}
|
||||
if filter.IsPinned != nil {
|
||||
@@ -128,16 +134,22 @@ func (s *content) List(ctx context.Context, filter *content_dto.ContentListFilte
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *content) Get(ctx context.Context, userID, id int64) (*content_dto.ContentDetail, error) {
|
||||
func (s *content) Get(ctx context.Context, tenantID, userID, id int64) (*content_dto.ContentDetail, error) {
|
||||
// Increment Views
|
||||
_, _ = models.ContentQuery.WithContext(ctx).
|
||||
Where(models.ContentQuery.ID.Eq(id)).
|
||||
UpdateSimple(models.ContentQuery.Views.Add(1))
|
||||
update := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id))
|
||||
if tenantID > 0 {
|
||||
update = update.Where(models.ContentQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
_, _ = update.UpdateSimple(models.ContentQuery.Views.Add(1))
|
||||
|
||||
_, q := models.ContentQuery.QueryContext(ctx)
|
||||
|
||||
var item models.Content
|
||||
err := q.UnderlyingDB().
|
||||
db := q.UnderlyingDB()
|
||||
if tenantID > 0 {
|
||||
db = db.Where("tenant_id = ?", tenantID)
|
||||
}
|
||||
err := db.
|
||||
Preload("Author").
|
||||
Preload("ContentAssets", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("sort ASC")
|
||||
@@ -232,10 +244,25 @@ func (s *content) Get(ctx context.Context, userID, id int64) (*content_dto.Conte
|
||||
return detail, nil
|
||||
}
|
||||
|
||||
func (s *content) ListComments(ctx context.Context, userID, id int64, page int) (*requests.Pager, error) {
|
||||
func (s *content) ListComments(ctx context.Context, tenantID, userID, id int64, page int) (*requests.Pager, error) {
|
||||
if tenantID > 0 {
|
||||
_, err := models.ContentQuery.WithContext(ctx).
|
||||
Where(models.ContentQuery.ID.Eq(id), models.ContentQuery.TenantID.Eq(tenantID)).
|
||||
First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errorx.ErrRecordNotFound
|
||||
}
|
||||
return nil, errorx.ErrDatabaseError.WithCause(err)
|
||||
}
|
||||
}
|
||||
|
||||
tbl, q := models.CommentQuery.QueryContext(ctx)
|
||||
|
||||
q = q.Where(tbl.ContentID.Eq(id)).Preload(tbl.User)
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.TenantID.Eq(tenantID))
|
||||
}
|
||||
q = q.Order(tbl.CreatedAt.Desc())
|
||||
|
||||
p := requests.Pagination{Page: int64(page), Limit: 10}
|
||||
@@ -291,6 +318,7 @@ func (s *content) ListComments(ctx context.Context, userID, id int64, page int)
|
||||
|
||||
func (s *content) CreateComment(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
id int64,
|
||||
form *content_dto.CommentCreateForm,
|
||||
@@ -300,7 +328,11 @@ func (s *content) CreateComment(
|
||||
}
|
||||
uid := userID
|
||||
|
||||
c, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id)).First()
|
||||
query := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(id))
|
||||
if tenantID > 0 {
|
||||
query = query.Where(models.ContentQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
c, err := query.First()
|
||||
if err != nil {
|
||||
return errorx.ErrRecordNotFound
|
||||
}
|
||||
@@ -319,14 +351,18 @@ func (s *content) CreateComment(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *content) LikeComment(ctx context.Context, userID, id int64) error {
|
||||
func (s *content) LikeComment(ctx context.Context, tenantID, userID, id int64) error {
|
||||
if userID == 0 {
|
||||
return errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
// Fetch comment for author
|
||||
cm, err := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ID.Eq(id)).First()
|
||||
query := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ID.Eq(id))
|
||||
if tenantID > 0 {
|
||||
query = query.Where(models.CommentQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
cm, err := query.First()
|
||||
if err != nil {
|
||||
return errorx.ErrRecordNotFound
|
||||
}
|
||||
@@ -357,14 +393,18 @@ func (s *content) LikeComment(ctx context.Context, userID, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
|
||||
func (s *content) GetLibrary(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) {
|
||||
if userID == 0 {
|
||||
return nil, errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
tbl, q := models.ContentAccessQuery.QueryContext(ctx)
|
||||
accessList, err := q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive)).Find()
|
||||
q = q.Where(tbl.UserID.Eq(uid), tbl.Status.Eq(consts.ContentAccessStatusActive))
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.TenantID.Eq(tenantID))
|
||||
}
|
||||
accessList, err := q.Find()
|
||||
if err != nil {
|
||||
return nil, errorx.ErrDatabaseError.WithCause(err)
|
||||
}
|
||||
@@ -380,7 +420,11 @@ func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.Cont
|
||||
|
||||
ctbl, cq := models.ContentQuery.QueryContext(ctx)
|
||||
var list []*models.Content
|
||||
err = cq.Where(ctbl.ID.In(contentIDs...)).
|
||||
cq = cq.Where(ctbl.ID.In(contentIDs...))
|
||||
if tenantID > 0 {
|
||||
cq = cq.Where(ctbl.TenantID.Eq(tenantID))
|
||||
}
|
||||
err = cq.
|
||||
UnderlyingDB().
|
||||
Preload("Author").
|
||||
Preload("ContentAssets.Asset").
|
||||
@@ -398,36 +442,40 @@ func (s *content) GetLibrary(ctx context.Context, userID int64) ([]user_dto.Cont
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *content) GetFavorites(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
|
||||
return s.getInteractList(ctx, userID, "favorite")
|
||||
func (s *content) GetFavorites(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) {
|
||||
return s.getInteractList(ctx, tenantID, userID, "favorite")
|
||||
}
|
||||
|
||||
func (s *content) AddFavorite(ctx context.Context, userID, contentId int64) error {
|
||||
return s.addInteract(ctx, userID, contentId, "favorite")
|
||||
func (s *content) AddFavorite(ctx context.Context, tenantID, userID, contentId int64) error {
|
||||
return s.addInteract(ctx, tenantID, userID, contentId, "favorite")
|
||||
}
|
||||
|
||||
func (s *content) RemoveFavorite(ctx context.Context, userID, contentId int64) error {
|
||||
return s.removeInteract(ctx, userID, contentId, "favorite")
|
||||
func (s *content) RemoveFavorite(ctx context.Context, tenantID, userID, contentId int64) error {
|
||||
return s.removeInteract(ctx, tenantID, userID, contentId, "favorite")
|
||||
}
|
||||
|
||||
func (s *content) GetLikes(ctx context.Context, userID int64) ([]user_dto.ContentItem, error) {
|
||||
return s.getInteractList(ctx, userID, "like")
|
||||
func (s *content) GetLikes(ctx context.Context, tenantID, userID int64) ([]user_dto.ContentItem, error) {
|
||||
return s.getInteractList(ctx, tenantID, userID, "like")
|
||||
}
|
||||
|
||||
func (s *content) AddLike(ctx context.Context, userID, contentId int64) error {
|
||||
return s.addInteract(ctx, userID, contentId, "like")
|
||||
func (s *content) AddLike(ctx context.Context, tenantID, userID, contentId int64) error {
|
||||
return s.addInteract(ctx, tenantID, userID, contentId, "like")
|
||||
}
|
||||
|
||||
func (s *content) RemoveLike(ctx context.Context, userID, contentId int64) error {
|
||||
return s.removeInteract(ctx, userID, contentId, "like")
|
||||
func (s *content) RemoveLike(ctx context.Context, tenantID, userID, contentId int64) error {
|
||||
return s.removeInteract(ctx, tenantID, userID, contentId, "like")
|
||||
}
|
||||
|
||||
func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) {
|
||||
func (s *content) ListTopics(ctx context.Context, tenantID int64) ([]content_dto.Topic, error) {
|
||||
var results []struct {
|
||||
Genre string
|
||||
Count int
|
||||
}
|
||||
err := models.ContentQuery.WithContext(ctx).UnderlyingDB().
|
||||
db := models.ContentQuery.WithContext(ctx).UnderlyingDB()
|
||||
if tenantID > 0 {
|
||||
db = db.Where("tenant_id = ?", tenantID)
|
||||
}
|
||||
err := db.
|
||||
Model(&models.Content{}).
|
||||
Where("status = ?", consts.ContentStatusPublished).
|
||||
Select("genre, count(*) as count").
|
||||
@@ -445,8 +493,12 @@ func (s *content) ListTopics(ctx context.Context) ([]content_dto.Topic, error) {
|
||||
|
||||
// Fetch latest content in this genre to get a cover
|
||||
var c models.Content
|
||||
models.ContentQuery.WithContext(ctx).
|
||||
Where(models.ContentQuery.Genre.Eq(r.Genre), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
|
||||
query := models.ContentQuery.WithContext(ctx).
|
||||
Where(models.ContentQuery.Genre.Eq(r.Genre), models.ContentQuery.Status.Eq(consts.ContentStatusPublished))
|
||||
if tenantID > 0 {
|
||||
query = query.Where(models.ContentQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
query.
|
||||
Order(models.ContentQuery.PublishedAt.Desc()).
|
||||
UnderlyingDB().
|
||||
Preload("ContentAssets").
|
||||
@@ -554,15 +606,19 @@ func (s *content) toMediaURLs(assets []*models.ContentAsset) []content_dto.Media
|
||||
return urls
|
||||
}
|
||||
|
||||
func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ string) error {
|
||||
func (s *content) addInteract(ctx context.Context, tenantID, userID, contentId int64, typ string) error {
|
||||
if userID == 0 {
|
||||
return errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
// Fetch content for author
|
||||
c, err := models.ContentQuery.WithContext(ctx).
|
||||
Where(models.ContentQuery.ID.Eq(contentId)).
|
||||
query := models.ContentQuery.WithContext(ctx).
|
||||
Where(models.ContentQuery.ID.Eq(contentId))
|
||||
if tenantID > 0 {
|
||||
query = query.Where(models.ContentQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
c, err := query.
|
||||
Select(models.ContentQuery.UserID, models.ContentQuery.Title).
|
||||
First()
|
||||
if err != nil {
|
||||
@@ -583,7 +639,11 @@ func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ
|
||||
}
|
||||
|
||||
if typ == "like" {
|
||||
_, err := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)).UpdateSimple(tx.Content.Likes.Add(1))
|
||||
contentQuery := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId))
|
||||
if tenantID > 0 {
|
||||
contentQuery = contentQuery.Where(tx.Content.TenantID.Eq(tenantID))
|
||||
}
|
||||
_, err := contentQuery.UpdateSimple(tx.Content.Likes.Add(1))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -605,7 +665,7 @@ func (s *content) addInteract(ctx context.Context, userID, contentId int64, typ
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *content) removeInteract(ctx context.Context, userID, contentId int64, typ string) error {
|
||||
func (s *content) removeInteract(ctx context.Context, tenantID, userID, contentId int64, typ string) error {
|
||||
if userID == 0 {
|
||||
return errorx.ErrUnauthorized
|
||||
}
|
||||
@@ -623,14 +683,18 @@ func (s *content) removeInteract(ctx context.Context, userID, contentId int64, t
|
||||
}
|
||||
|
||||
if typ == "like" {
|
||||
_, err := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId)).UpdateSimple(tx.Content.Likes.Sub(1))
|
||||
contentQuery := tx.Content.WithContext(ctx).Where(tx.Content.ID.Eq(contentId))
|
||||
if tenantID > 0 {
|
||||
contentQuery = contentQuery.Where(tx.Content.TenantID.Eq(tenantID))
|
||||
}
|
||||
_, err := contentQuery.UpdateSimple(tx.Content.Likes.Sub(1))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *content) getInteractList(ctx context.Context, userID int64, typ string) ([]user_dto.ContentItem, error) {
|
||||
func (s *content) getInteractList(ctx context.Context, tenantID, userID int64, typ string) ([]user_dto.ContentItem, error) {
|
||||
if userID == 0 {
|
||||
return nil, errorx.ErrUnauthorized
|
||||
}
|
||||
@@ -653,7 +717,11 @@ func (s *content) getInteractList(ctx context.Context, userID int64, typ string)
|
||||
|
||||
ctbl, cq := models.ContentQuery.QueryContext(ctx)
|
||||
var list []*models.Content
|
||||
err = cq.Where(ctbl.ID.In(contentIDs...)).
|
||||
cq = cq.Where(ctbl.ID.In(contentIDs...))
|
||||
if tenantID > 0 {
|
||||
cq = cq.Where(ctbl.TenantID.Eq(tenantID))
|
||||
}
|
||||
err = cq.
|
||||
UnderlyingDB().
|
||||
Preload("Author").
|
||||
Preload("ContentAssets.Asset").
|
||||
|
||||
@@ -41,6 +41,7 @@ func Test_Content(t *testing.T) {
|
||||
func (s *ContentTestSuite) Test_List() {
|
||||
Convey("List", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser)
|
||||
|
||||
// Create Author
|
||||
@@ -73,7 +74,7 @@ func (s *ContentTestSuite) Test_List() {
|
||||
Limit: 10,
|
||||
},
|
||||
}
|
||||
res, err := Content.List(ctx, filter)
|
||||
res, err := Content.List(ctx, tenantID, filter)
|
||||
So(err, ShouldBeNil)
|
||||
So(res.Total, ShouldEqual, 1)
|
||||
items := res.Items.([]content_dto.ContentItem)
|
||||
@@ -86,6 +87,7 @@ func (s *ContentTestSuite) Test_List() {
|
||||
func (s *ContentTestSuite) Test_Get() {
|
||||
Convey("Get", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameMediaAsset, models.TableNameContentAsset, models.TableNameUser)
|
||||
|
||||
// Author
|
||||
@@ -125,7 +127,7 @@ func (s *ContentTestSuite) Test_Get() {
|
||||
ctx = context.WithValue(ctx, consts.CtxKeyUser, author.ID)
|
||||
|
||||
Convey("should get detail with assets", func() {
|
||||
detail, err := Content.Get(ctx, author.ID, content.ID)
|
||||
detail, err := Content.Get(ctx, tenantID, author.ID, content.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(detail.Title, ShouldEqual, "Detail Content")
|
||||
So(detail.AuthorName, ShouldEqual, "Author1")
|
||||
@@ -138,6 +140,7 @@ func (s *ContentTestSuite) Test_Get() {
|
||||
func (s *ContentTestSuite) Test_CreateComment() {
|
||||
Convey("CreateComment", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameComment, models.TableNameUser)
|
||||
|
||||
// User & Content
|
||||
@@ -153,7 +156,7 @@ func (s *ContentTestSuite) Test_CreateComment() {
|
||||
form := &content_dto.CommentCreateForm{
|
||||
Content: "Nice!",
|
||||
}
|
||||
err := Content.CreateComment(ctx, u.ID, c.ID, form)
|
||||
err := Content.CreateComment(ctx, tenantID, u.ID, c.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
count, _ := models.CommentQuery.WithContext(ctx).Where(models.CommentQuery.ContentID.Eq(c.ID)).Count()
|
||||
@@ -165,6 +168,7 @@ func (s *ContentTestSuite) Test_CreateComment() {
|
||||
func (s *ContentTestSuite) Test_Library() {
|
||||
Convey("Library", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameContentAccess, models.TableNameUser, models.TableNameContentAsset, models.TableNameMediaAsset)
|
||||
|
||||
// User
|
||||
@@ -192,7 +196,7 @@ func (s *ContentTestSuite) Test_Library() {
|
||||
})
|
||||
|
||||
Convey("should get library content with details", func() {
|
||||
list, err := Content.GetLibrary(ctx, u.ID)
|
||||
list, err := Content.GetLibrary(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(list), ShouldEqual, 1)
|
||||
So(list[0].Title, ShouldEqual, "Paid Content")
|
||||
@@ -206,6 +210,7 @@ func (s *ContentTestSuite) Test_Library() {
|
||||
func (s *ContentTestSuite) Test_Interact() {
|
||||
Convey("Interact", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUserContentAction, models.TableNameUser)
|
||||
|
||||
// User & Content
|
||||
@@ -218,7 +223,7 @@ func (s *ContentTestSuite) Test_Interact() {
|
||||
|
||||
Convey("Like flow", func() {
|
||||
// Add Like
|
||||
err := Content.AddLike(ctx, u.ID, c.ID)
|
||||
err := Content.AddLike(ctx, tenantID, u.ID, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify count
|
||||
@@ -226,13 +231,13 @@ func (s *ContentTestSuite) Test_Interact() {
|
||||
So(cReload.Likes, ShouldEqual, 1)
|
||||
|
||||
// Get Likes
|
||||
likes, err := Content.GetLikes(ctx, u.ID)
|
||||
likes, err := Content.GetLikes(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(likes), ShouldEqual, 1)
|
||||
So(likes[0].ID, ShouldEqual, c.ID)
|
||||
|
||||
// Remove Like
|
||||
err = Content.RemoveLike(ctx, u.ID, c.ID)
|
||||
err = Content.RemoveLike(ctx, tenantID, u.ID, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify count
|
||||
@@ -242,21 +247,21 @@ func (s *ContentTestSuite) Test_Interact() {
|
||||
|
||||
Convey("Favorite flow", func() {
|
||||
// Add Favorite
|
||||
err := Content.AddFavorite(ctx, u.ID, c.ID)
|
||||
err := Content.AddFavorite(ctx, tenantID, u.ID, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Get Favorites
|
||||
favs, err := Content.GetFavorites(ctx, u.ID)
|
||||
favs, err := Content.GetFavorites(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(favs), ShouldEqual, 1)
|
||||
So(favs[0].ID, ShouldEqual, c.ID)
|
||||
|
||||
// Remove Favorite
|
||||
err = Content.RemoveFavorite(ctx, u.ID, c.ID)
|
||||
err = Content.RemoveFavorite(ctx, tenantID, u.ID, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Get Favorites
|
||||
favs, err = Content.GetFavorites(ctx, u.ID)
|
||||
favs, err = Content.GetFavorites(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(favs), ShouldEqual, 0)
|
||||
})
|
||||
@@ -266,6 +271,7 @@ func (s *ContentTestSuite) Test_Interact() {
|
||||
func (s *ContentTestSuite) Test_ListTopics() {
|
||||
Convey("ListTopics", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser)
|
||||
|
||||
u := &models.User{Username: "user_t", Phone: "13900000005"}
|
||||
@@ -280,7 +286,7 @@ func (s *ContentTestSuite) Test_ListTopics() {
|
||||
)
|
||||
|
||||
Convey("should aggregate topics", func() {
|
||||
topics, err := Content.ListTopics(ctx)
|
||||
topics, err := Content.ListTopics(ctx, tenantID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(topics), ShouldBeGreaterThanOrEqualTo, 2)
|
||||
|
||||
@@ -302,6 +308,7 @@ func (s *ContentTestSuite) Test_ListTopics() {
|
||||
func (s *ContentTestSuite) Test_PreviewLogic() {
|
||||
Convey("Preview Logic", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameContentAsset, models.TableNameContentAccess, models.TableNameUser, models.TableNameMediaAsset)
|
||||
|
||||
author := &models.User{Username: "author_p", Phone: "13900000006"}
|
||||
@@ -324,7 +331,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
|
||||
models.UserQuery.WithContext(ctx).Create(guest)
|
||||
guestCtx := context.WithValue(ctx, consts.CtxKeyUser, guest.ID)
|
||||
|
||||
detail, err := Content.Get(guestCtx, 0, c.ID)
|
||||
detail, err := Content.Get(guestCtx, tenantID, 0, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(detail.MediaUrls), ShouldEqual, 1)
|
||||
So(detail.MediaUrls[0].URL, ShouldContainSubstring, "preview.mp4")
|
||||
@@ -333,7 +340,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
|
||||
|
||||
Convey("owner should see all", func() {
|
||||
ownerCtx := context.WithValue(ctx, consts.CtxKeyUser, author.ID)
|
||||
detail, err := Content.Get(ownerCtx, author.ID, c.ID)
|
||||
detail, err := Content.Get(ownerCtx, tenantID, author.ID, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(detail.MediaUrls), ShouldEqual, 2)
|
||||
So(detail.IsPurchased, ShouldBeTrue)
|
||||
@@ -348,7 +355,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
|
||||
UserID: buyer.ID, ContentID: c.ID, Status: consts.ContentAccessStatusActive,
|
||||
})
|
||||
|
||||
detail, err := Content.Get(buyerCtx, buyer.ID, c.ID)
|
||||
detail, err := Content.Get(buyerCtx, tenantID, buyer.ID, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(detail.MediaUrls), ShouldEqual, 2)
|
||||
So(detail.IsPurchased, ShouldBeTrue)
|
||||
@@ -359,6 +366,7 @@ func (s *ContentTestSuite) Test_PreviewLogic() {
|
||||
func (s *ContentTestSuite) Test_ViewCounting() {
|
||||
Convey("ViewCounting", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameContent, models.TableNameUser)
|
||||
|
||||
author := &models.User{Username: "author_v", Phone: "13900000009"}
|
||||
@@ -368,7 +376,7 @@ func (s *ContentTestSuite) Test_ViewCounting() {
|
||||
models.ContentQuery.WithContext(ctx).Create(c)
|
||||
|
||||
Convey("should increment views", func() {
|
||||
_, err := Content.Get(ctx, 0, c.ID)
|
||||
_, err := Content.Get(ctx, tenantID, 0, c.ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
cReload, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(c.ID)).First()
|
||||
|
||||
@@ -39,6 +39,7 @@ func Test_Coupon(t *testing.T) {
|
||||
func (s *CouponTestSuite) Test_CouponFlow() {
|
||||
Convey("Coupon Flow", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -83,9 +84,10 @@ func (s *CouponTestSuite) Test_CouponFlow() {
|
||||
|
||||
Convey("should apply in Order.Create", func() {
|
||||
// Setup Content
|
||||
c := &models.Content{UserID: 99, Title: "Test", Status: consts.ContentStatusPublished}
|
||||
c := &models.Content{TenantID: tenantID, UserID: 99, Title: "Test", Status: consts.ContentStatusPublished}
|
||||
models.ContentQuery.WithContext(ctx).Create(c)
|
||||
models.ContentPriceQuery.WithContext(ctx).Create(&models.ContentPrice{
|
||||
TenantID: tenantID,
|
||||
ContentID: c.ID,
|
||||
PriceAmount: 2000, // 20.00 CNY
|
||||
Currency: "CNY",
|
||||
@@ -96,7 +98,7 @@ func (s *CouponTestSuite) Test_CouponFlow() {
|
||||
UserCouponID: uc.ID,
|
||||
}
|
||||
// Simulate Auth context for Order service
|
||||
res, err := Order.Create(ctx, user.ID, form)
|
||||
res, err := Order.Create(ctx, tenantID, user.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify Order
|
||||
|
||||
@@ -31,7 +31,7 @@ var genreMap = map[string]string{
|
||||
"Qinqiang": "秦腔",
|
||||
}
|
||||
|
||||
func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.ApplyForm) error {
|
||||
func (s *creator) Apply(ctx context.Context, tenantID, userID int64, form *creator_dto.ApplyForm) error {
|
||||
if userID == 0 {
|
||||
return errorx.ErrUnauthorized
|
||||
}
|
||||
@@ -72,8 +72,8 @@ func (s *creator) Apply(ctx context.Context, userID int64, form *creator_dto.App
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.DashboardStats, error) {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) Dashboard(ctx context.Context, tenantID, userID int64) (*creator_dto.DashboardStats, error) {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -107,10 +107,11 @@ func (s *creator) Dashboard(ctx context.Context, userID int64) (*creator_dto.Das
|
||||
|
||||
func (s *creator) ListContents(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
filter *creator_dto.CreatorContentListFilter,
|
||||
) (*requests.Pager, error) {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -248,8 +249,8 @@ func (s *creator) ListContents(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator_dto.ContentCreateForm) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) CreateContent(ctx context.Context, tenantID, userID int64, form *creator_dto.ContentCreateForm) error {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -321,11 +322,12 @@ func (s *creator) CreateContent(ctx context.Context, userID int64, form *creator
|
||||
|
||||
func (s *creator) UpdateContent(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
id int64,
|
||||
form *creator_dto.ContentUpdateForm,
|
||||
) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -451,8 +453,8 @@ func (s *creator) UpdateContent(
|
||||
})
|
||||
}
|
||||
|
||||
func (s *creator) DeleteContent(ctx context.Context, userID, id int64) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) DeleteContent(ctx context.Context, tenantID, userID, id int64) error {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -472,8 +474,8 @@ func (s *creator) DeleteContent(ctx context.Context, userID, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *creator) GetContent(ctx context.Context, userID, id int64) (*creator_dto.ContentEditDTO, error) {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) GetContent(ctx context.Context, tenantID, userID, id int64) (*creator_dto.ContentEditDTO, error) {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -548,10 +550,11 @@ func (s *creator) GetContent(ctx context.Context, userID, id int64) (*creator_dt
|
||||
|
||||
func (s *creator) ListOrders(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
filter *creator_dto.CreatorOrderListFilter,
|
||||
) ([]creator_dto.Order, error) {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -634,8 +637,8 @@ func (s *creator) ListOrders(
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *creator) ProcessRefund(ctx context.Context, userID, id int64, form *creator_dto.RefundForm) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) ProcessRefund(ctx context.Context, tenantID, userID, id int64, form *creator_dto.RefundForm) error {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -738,8 +741,8 @@ func (s *creator) ProcessRefund(ctx context.Context, userID, id int64, form *cre
|
||||
return errorx.ErrBadRequest.WithMsg("无效的操作")
|
||||
}
|
||||
|
||||
func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.Settings, error) {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) GetSettings(ctx context.Context, tenantID, userID int64) (*creator_dto.Settings, error) {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -758,8 +761,8 @@ func (s *creator) GetSettings(ctx context.Context, userID int64) (*creator_dto.S
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creator_dto.Settings) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) UpdateSettings(ctx context.Context, tenantID, userID int64, form *creator_dto.Settings) error {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -782,8 +785,8 @@ func (s *creator) UpdateSettings(ctx context.Context, userID int64, form *creato
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creator_dto.PayoutAccount, error) {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) ListPayoutAccounts(ctx context.Context, tenantID, userID int64) ([]creator_dto.PayoutAccount, error) {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -806,8 +809,8 @@ func (s *creator) ListPayoutAccounts(ctx context.Context, userID int64) ([]creat
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *creator_dto.PayoutAccount) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) AddPayoutAccount(ctx context.Context, tenantID, userID int64, form *creator_dto.PayoutAccount) error {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -827,8 +830,8 @@ func (s *creator) AddPayoutAccount(ctx context.Context, userID int64, form *crea
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *creator) RemovePayoutAccount(ctx context.Context, userID, id int64) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) RemovePayoutAccount(ctx context.Context, tenantID, userID, id int64) error {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -842,8 +845,8 @@ func (s *creator) RemovePayoutAccount(ctx context.Context, userID, id int64) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto.WithdrawForm) error {
|
||||
tid, err := s.getTenantID(ctx, userID)
|
||||
func (s *creator) Withdraw(ctx context.Context, tenantID, userID int64, form *creator_dto.WithdrawForm) error {
|
||||
tid, err := s.getTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -920,7 +923,7 @@ func (s *creator) Withdraw(ctx context.Context, userID int64, form *creator_dto.
|
||||
|
||||
// Helpers
|
||||
|
||||
func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (s *creator) getTenantID(ctx context.Context, tenantID, userID int64) (int64, error) {
|
||||
if userID == 0 {
|
||||
return 0, errorx.ErrUnauthorized
|
||||
}
|
||||
@@ -934,5 +937,8 @@ func (s *creator) getTenantID(ctx context.Context, userID int64) (int64, error)
|
||||
}
|
||||
return 0, errorx.ErrDatabaseError.WithCause(err)
|
||||
}
|
||||
if tenantID > 0 && t.ID != tenantID {
|
||||
return 0, errorx.ErrPermissionDenied.WithMsg("无权限访问该租户")
|
||||
}
|
||||
return t.ID, nil
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ func Test_Creator(t *testing.T) {
|
||||
func (s *CreatorTestSuite) Test_Apply() {
|
||||
Convey("Apply", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameTenantUser, models.TableNameUser)
|
||||
|
||||
u := &models.User{Username: "creator1", Phone: "13700000001"}
|
||||
@@ -50,7 +51,7 @@ func (s *CreatorTestSuite) Test_Apply() {
|
||||
form := &creator_dto.ApplyForm{
|
||||
Name: "My Channel",
|
||||
}
|
||||
err := Creator.Apply(ctx, u.ID, form)
|
||||
err := Creator.Apply(ctx, tenantID, u.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
t, _ := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.UserID.Eq(u.ID)).First()
|
||||
@@ -72,6 +73,7 @@ func (s *CreatorTestSuite) Test_Apply() {
|
||||
func (s *CreatorTestSuite) Test_CreateContent() {
|
||||
Convey("CreateContent", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -89,6 +91,7 @@ func (s *CreatorTestSuite) Test_CreateContent() {
|
||||
// Create Tenant manually
|
||||
t := &models.Tenant{UserID: u.ID, Name: "Channel 2", Code: "123", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
|
||||
Convey("should create content and assets", func() {
|
||||
form := &creator_dto.ContentCreateForm{
|
||||
@@ -97,7 +100,7 @@ func (s *CreatorTestSuite) Test_CreateContent() {
|
||||
Price: 9.99,
|
||||
// MediaIDs: ... need media asset
|
||||
}
|
||||
err := Creator.CreateContent(ctx, u.ID, form)
|
||||
err := Creator.CreateContent(ctx, tenantID, u.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
c, _ := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.Title.Eq("New Song")).First()
|
||||
@@ -116,6 +119,7 @@ func (s *CreatorTestSuite) Test_CreateContent() {
|
||||
func (s *CreatorTestSuite) Test_UpdateContent() {
|
||||
Convey("UpdateContent", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -132,6 +136,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() {
|
||||
|
||||
t := &models.Tenant{UserID: u.ID, Name: "Channel 3", Code: "124", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
|
||||
c := &models.Content{TenantID: t.ID, UserID: u.ID, Title: "Old Title", Genre: "audio"}
|
||||
models.ContentQuery.WithContext(ctx).Create(c)
|
||||
@@ -145,7 +150,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() {
|
||||
Genre: "video",
|
||||
Price: &price,
|
||||
}
|
||||
err := Creator.UpdateContent(ctx, u.ID, c.ID, form)
|
||||
err := Creator.UpdateContent(ctx, tenantID, u.ID, c.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify
|
||||
@@ -162,6 +167,7 @@ func (s *CreatorTestSuite) Test_UpdateContent() {
|
||||
func (s *CreatorTestSuite) Test_Dashboard() {
|
||||
Convey("Dashboard", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -178,6 +184,7 @@ func (s *CreatorTestSuite) Test_Dashboard() {
|
||||
|
||||
t := &models.Tenant{UserID: u.ID, Name: "Channel 4", Code: "125", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
|
||||
// Mock Data
|
||||
// 1. Followers
|
||||
@@ -198,7 +205,7 @@ func (s *CreatorTestSuite) Test_Dashboard() {
|
||||
)
|
||||
|
||||
Convey("should get stats", func() {
|
||||
stats, err := Creator.Dashboard(ctx, u.ID)
|
||||
stats, err := Creator.Dashboard(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(stats.TotalFollowers.Value, ShouldEqual, 2)
|
||||
// Implementation sums 'debit_purchase' only based on my code
|
||||
@@ -210,6 +217,7 @@ func (s *CreatorTestSuite) Test_Dashboard() {
|
||||
func (s *CreatorTestSuite) Test_PayoutAccount() {
|
||||
Convey("PayoutAccount", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNamePayoutAccount, models.TableNameUser)
|
||||
|
||||
u := &models.User{Username: "creator5", Phone: "13700000005"}
|
||||
@@ -218,6 +226,7 @@ func (s *CreatorTestSuite) Test_PayoutAccount() {
|
||||
|
||||
t := &models.Tenant{UserID: u.ID, Name: "Channel 5", Code: "126", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
|
||||
Convey("should CRUD payout account", func() {
|
||||
// Add
|
||||
@@ -227,21 +236,21 @@ func (s *CreatorTestSuite) Test_PayoutAccount() {
|
||||
Account: "user@example.com",
|
||||
Realname: "John Doe",
|
||||
}
|
||||
err := Creator.AddPayoutAccount(ctx, u.ID, form)
|
||||
err := Creator.AddPayoutAccount(ctx, tenantID, u.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// List
|
||||
list, err := Creator.ListPayoutAccounts(ctx, u.ID)
|
||||
list, err := Creator.ListPayoutAccounts(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(list), ShouldEqual, 1)
|
||||
So(list[0].Account, ShouldEqual, "user@example.com")
|
||||
|
||||
// Remove
|
||||
err = Creator.RemovePayoutAccount(ctx, u.ID, list[0].ID)
|
||||
err = Creator.RemovePayoutAccount(ctx, tenantID, u.ID, list[0].ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify Empty
|
||||
list, err = Creator.ListPayoutAccounts(ctx, u.ID)
|
||||
list, err = Creator.ListPayoutAccounts(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(list), ShouldEqual, 0)
|
||||
})
|
||||
@@ -251,6 +260,7 @@ func (s *CreatorTestSuite) Test_PayoutAccount() {
|
||||
func (s *CreatorTestSuite) Test_Withdraw() {
|
||||
Convey("Withdraw", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -267,6 +277,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
|
||||
|
||||
t := &models.Tenant{UserID: u.ID, Name: "Channel 6", Code: "127", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
|
||||
pa := &models.PayoutAccount{
|
||||
TenantID: t.ID,
|
||||
@@ -283,7 +294,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
|
||||
Amount: 20.00,
|
||||
AccountID: pa.ID,
|
||||
}
|
||||
err := Creator.Withdraw(ctx, u.ID, form)
|
||||
err := Creator.Withdraw(ctx, tenantID, u.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify Balance Deducted
|
||||
@@ -308,7 +319,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
|
||||
Amount: 100.00,
|
||||
AccountID: pa.ID,
|
||||
}
|
||||
err := Creator.Withdraw(ctx, u.ID, form)
|
||||
err := Creator.Withdraw(ctx, tenantID, u.ID, form)
|
||||
So(err, ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
@@ -317,6 +328,7 @@ func (s *CreatorTestSuite) Test_Withdraw() {
|
||||
func (s *CreatorTestSuite) Test_Refund() {
|
||||
Convey("Refund", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(ctx, s.DB,
|
||||
models.TableNameTenant, models.TableNameUser, models.TableNameOrder,
|
||||
models.TableNameOrderItem, models.TableNameContentAccess, models.TableNameTenantLedger,
|
||||
@@ -330,6 +342,7 @@ func (s *CreatorTestSuite) Test_Refund() {
|
||||
// Tenant
|
||||
t := &models.Tenant{UserID: creator.ID, Name: "Channel 7", Code: "128", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
|
||||
// Buyer
|
||||
buyer := &models.User{Username: "buyer7", Phone: "13900000007", Balance: 0}
|
||||
@@ -349,7 +362,7 @@ func (s *CreatorTestSuite) Test_Refund() {
|
||||
|
||||
Convey("should accept refund", func() {
|
||||
form := &creator_dto.RefundForm{Action: "accept", Reason: "Defective"}
|
||||
err := Creator.ProcessRefund(ctx, creator.ID, o.ID, form)
|
||||
err := Creator.ProcessRefund(ctx, tenantID, creator.ID, o.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify Order
|
||||
|
||||
@@ -21,14 +21,19 @@ import (
|
||||
// @provider
|
||||
type order struct{}
|
||||
|
||||
func (s *order) ListUserOrders(ctx context.Context, userID int64, status string) ([]user_dto.Order, error) {
|
||||
func (s *order) ListUserOrders(ctx context.Context, tenantID, userID int64, status string) ([]user_dto.Order, error) {
|
||||
if userID == 0 {
|
||||
return nil, errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
tbl, q := models.OrderQuery.QueryContext(ctx)
|
||||
q = q.Where(tbl.UserID.Eq(uid))
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.UserID.Eq(uid), tbl.TenantID.Eq(tenantID)).
|
||||
Or(tbl.UserID.Eq(uid), tbl.Type.Eq(consts.OrderTypeRecharge))
|
||||
} else {
|
||||
q = q.Where(tbl.UserID.Eq(uid))
|
||||
}
|
||||
|
||||
if status != "" && status != "all" {
|
||||
q = q.Where(tbl.Status.Eq(consts.OrderStatus(status)))
|
||||
@@ -46,14 +51,21 @@ func (s *order) ListUserOrders(ctx context.Context, userID int64, status string)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *order) GetUserOrder(ctx context.Context, userID, id int64) (*user_dto.Order, error) {
|
||||
func (s *order) GetUserOrder(ctx context.Context, tenantID, userID, id int64) (*user_dto.Order, error) {
|
||||
if userID == 0 {
|
||||
return nil, errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
tbl, q := models.OrderQuery.QueryContext(ctx)
|
||||
item, err := q.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid)).First()
|
||||
itemQuery := q
|
||||
if tenantID > 0 {
|
||||
itemQuery = itemQuery.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid), tbl.TenantID.Eq(tenantID)).
|
||||
Or(tbl.ID.Eq(id), tbl.UserID.Eq(uid), tbl.Type.Eq(consts.OrderTypeRecharge))
|
||||
} else {
|
||||
itemQuery = itemQuery.Where(tbl.ID.Eq(id), tbl.UserID.Eq(uid))
|
||||
}
|
||||
item, err := itemQuery.First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errorx.ErrRecordNotFound
|
||||
@@ -70,6 +82,7 @@ func (s *order) GetUserOrder(ctx context.Context, userID, id int64) (*user_dto.O
|
||||
|
||||
func (s *order) Create(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
form *transaction_dto.OrderCreateForm,
|
||||
) (*transaction_dto.OrderCreateResponse, error) {
|
||||
@@ -86,7 +99,11 @@ func (s *order) Create(
|
||||
}
|
||||
if idempotencyKey != "" {
|
||||
tbl, q := models.OrderQuery.QueryContext(ctx)
|
||||
existing, err := q.Where(tbl.UserID.Eq(uid), tbl.IdempotencyKey.Eq(idempotencyKey)).First()
|
||||
q = q.Where(tbl.UserID.Eq(uid), tbl.IdempotencyKey.Eq(idempotencyKey))
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.TenantID.Eq(tenantID))
|
||||
}
|
||||
existing, err := q.First()
|
||||
if err == nil {
|
||||
return &transaction_dto.OrderCreateResponse{OrderID: existing.ID}, nil
|
||||
}
|
||||
@@ -96,7 +113,11 @@ func (s *order) Create(
|
||||
}
|
||||
|
||||
// 1. Fetch Content & Price
|
||||
content, err := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid)).First()
|
||||
contentQuery := models.ContentQuery.WithContext(ctx).Where(models.ContentQuery.ID.Eq(cid))
|
||||
if tenantID > 0 {
|
||||
contentQuery = contentQuery.Where(models.ContentQuery.TenantID.Eq(tenantID))
|
||||
}
|
||||
content, err := contentQuery.First()
|
||||
if err != nil {
|
||||
return nil, errorx.ErrRecordNotFound.WithMsg("内容不存在")
|
||||
}
|
||||
@@ -188,6 +209,7 @@ func (s *order) Create(
|
||||
|
||||
func (s *order) Pay(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
id int64,
|
||||
form *transaction_dto.OrderPayForm,
|
||||
@@ -204,6 +226,9 @@ func (s *order) Pay(
|
||||
if err != nil {
|
||||
return nil, errorx.ErrRecordNotFound
|
||||
}
|
||||
if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID {
|
||||
return nil, errorx.ErrForbidden.WithMsg("租户不匹配")
|
||||
}
|
||||
if o.Status != consts.OrderStatusCreated {
|
||||
return nil, errorx.ErrStatusConflict.WithMsg("订单状态不可支付")
|
||||
}
|
||||
@@ -219,11 +244,14 @@ func (s *order) Pay(
|
||||
}
|
||||
|
||||
// ProcessExternalPayment handles callback from payment gateway
|
||||
func (s *order) ProcessExternalPayment(ctx context.Context, orderID int64, externalID string) error {
|
||||
func (s *order) ProcessExternalPayment(ctx context.Context, tenantID, orderID int64, externalID string) error {
|
||||
o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(orderID)).First()
|
||||
if err != nil {
|
||||
return errorx.ErrRecordNotFound
|
||||
}
|
||||
if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID {
|
||||
return errorx.ErrForbidden.WithMsg("租户不匹配")
|
||||
}
|
||||
if o.Status != consts.OrderStatusCreated {
|
||||
return nil // Already processed idempotency
|
||||
}
|
||||
@@ -365,7 +393,7 @@ func (s *order) settleOrder(ctx context.Context, o *models.Order, method, extern
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *order) Status(ctx context.Context, id int64) (*transaction_dto.OrderStatusResponse, error) {
|
||||
func (s *order) Status(ctx context.Context, tenantID, id int64) (*transaction_dto.OrderStatusResponse, error) {
|
||||
o, err := models.OrderQuery.WithContext(ctx).Where(models.OrderQuery.ID.Eq(id)).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -373,6 +401,9 @@ func (s *order) Status(ctx context.Context, id int64) (*transaction_dto.OrderSta
|
||||
}
|
||||
return nil, errorx.ErrDatabaseError.WithCause(err)
|
||||
}
|
||||
if tenantID > 0 && o.TenantID > 0 && o.TenantID != tenantID {
|
||||
return nil, errorx.ErrForbidden.WithMsg("租户不匹配")
|
||||
}
|
||||
|
||||
return &transaction_dto.OrderStatusResponse{
|
||||
Status: string(o.Status),
|
||||
|
||||
@@ -39,6 +39,7 @@ func Test_Order(t *testing.T) {
|
||||
func (s *OrderTestSuite) Test_PurchaseFlow() {
|
||||
Convey("Purchase Flow", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(ctx, s.DB,
|
||||
models.TableNameOrder, models.TableNameOrderItem, models.TableNameUser,
|
||||
models.TableNameContent, models.TableNameContentPrice, models.TableNameTenant,
|
||||
@@ -57,6 +58,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
|
||||
Status: consts.TenantStatusVerified,
|
||||
}
|
||||
models.TenantQuery.WithContext(ctx).Create(tenant)
|
||||
tenantID = tenant.ID
|
||||
// Content
|
||||
content := &models.Content{
|
||||
TenantID: tenant.ID,
|
||||
@@ -83,7 +85,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
|
||||
Convey("should create and pay order successfully", func() {
|
||||
// Step 1: Create Order
|
||||
form := &order_dto.OrderCreateForm{ContentID: content.ID}
|
||||
createRes, err := Order.Create(ctx, buyer.ID, form)
|
||||
createRes, err := Order.Create(ctx, tenantID, buyer.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
So(createRes.OrderID, ShouldNotBeEmpty)
|
||||
|
||||
@@ -95,7 +97,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
|
||||
|
||||
// Step 2: Pay Order
|
||||
payForm := &order_dto.OrderPayForm{Method: "balance"}
|
||||
_, err = Order.Pay(ctx, buyer.ID, createRes.OrderID, payForm)
|
||||
_, err = Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, payForm)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify Order Paid
|
||||
@@ -130,11 +132,11 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
|
||||
Update(models.UserQuery.Balance, 500)
|
||||
|
||||
form := &order_dto.OrderCreateForm{ContentID: content.ID}
|
||||
createRes, err := Order.Create(ctx, buyer.ID, form)
|
||||
createRes, err := Order.Create(ctx, tenantID, buyer.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
payForm := &order_dto.OrderPayForm{Method: "balance"}
|
||||
_, err = Order.Pay(ctx, buyer.ID, createRes.OrderID, payForm)
|
||||
_, err = Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, payForm)
|
||||
So(err, ShouldNotBeNil)
|
||||
// Error should be QuotaExceeded or similar
|
||||
})
|
||||
@@ -144,6 +146,7 @@ func (s *OrderTestSuite) Test_PurchaseFlow() {
|
||||
func (s *OrderTestSuite) Test_OrderDetails() {
|
||||
Convey("Order Details", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -164,6 +167,7 @@ func (s *OrderTestSuite) Test_OrderDetails() {
|
||||
models.UserQuery.WithContext(ctx).Create(creator)
|
||||
tenant := &models.Tenant{UserID: creator.ID, Name: "Best Shop", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(tenant)
|
||||
tenantID = tenant.ID
|
||||
content := &models.Content{
|
||||
TenantID: tenant.ID,
|
||||
UserID: creator.ID,
|
||||
@@ -199,13 +203,14 @@ func (s *OrderTestSuite) Test_OrderDetails() {
|
||||
// Create & Pay
|
||||
createRes, _ := Order.Create(
|
||||
ctx,
|
||||
tenantID,
|
||||
buyer.ID,
|
||||
&order_dto.OrderCreateForm{ContentID: content.ID},
|
||||
)
|
||||
Order.Pay(ctx, buyer.ID, createRes.OrderID, &order_dto.OrderPayForm{Method: "balance"})
|
||||
Order.Pay(ctx, tenantID, buyer.ID, createRes.OrderID, &order_dto.OrderPayForm{Method: "balance"})
|
||||
|
||||
// Get Detail
|
||||
detail, err := Order.GetUserOrder(ctx, buyer.ID, createRes.OrderID)
|
||||
detail, err := Order.GetUserOrder(ctx, tenantID, buyer.ID, createRes.OrderID)
|
||||
So(err, ShouldBeNil)
|
||||
So(detail.TenantName, ShouldEqual, "Best Shop")
|
||||
So(len(detail.Items), ShouldEqual, 1)
|
||||
@@ -219,6 +224,7 @@ func (s *OrderTestSuite) Test_OrderDetails() {
|
||||
func (s *OrderTestSuite) Test_PlatformCommission() {
|
||||
Convey("Platform Commission", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -236,6 +242,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() {
|
||||
// Tenant
|
||||
t := &models.Tenant{UserID: creator.ID, Name: "Shop C", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
// Buyer
|
||||
buyer := &models.User{Username: "buyer_c", Balance: 2000}
|
||||
models.UserQuery.WithContext(ctx).Create(buyer)
|
||||
@@ -253,7 +260,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() {
|
||||
|
||||
Convey("should deduct 10% fee", func() {
|
||||
payForm := &order_dto.OrderPayForm{Method: "balance"}
|
||||
_, err := Order.Pay(ctx, buyer.ID, o.ID, payForm)
|
||||
_, err := Order.Pay(ctx, tenantID, buyer.ID, o.ID, payForm)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify Creator Balance (1000 - 10% = 900)
|
||||
@@ -270,6 +277,7 @@ func (s *OrderTestSuite) Test_PlatformCommission() {
|
||||
func (s *OrderTestSuite) Test_ExternalPayment() {
|
||||
Convey("External Payment", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(
|
||||
ctx,
|
||||
s.DB,
|
||||
@@ -287,6 +295,7 @@ func (s *OrderTestSuite) Test_ExternalPayment() {
|
||||
// Tenant
|
||||
t := &models.Tenant{UserID: creator.ID, Name: "Shop Ext", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
// Buyer (Balance 0)
|
||||
buyer := &models.User{Username: "buyer_ext", Balance: 0}
|
||||
models.UserQuery.WithContext(ctx).Create(buyer)
|
||||
@@ -302,7 +311,7 @@ func (s *OrderTestSuite) Test_ExternalPayment() {
|
||||
models.OrderItemQuery.WithContext(ctx).Create(&models.OrderItem{OrderID: o.ID, ContentID: 999})
|
||||
|
||||
Convey("should process external payment callback", func() {
|
||||
err := Order.ProcessExternalPayment(ctx, o.ID, "ext_tx_id_123")
|
||||
err := Order.ProcessExternalPayment(ctx, tenantID, o.ID, "ext_tx_id_123")
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify Status
|
||||
|
||||
@@ -704,7 +704,7 @@ func (s *super) RefundOrder(ctx context.Context, id int64, form *super_dto.Super
|
||||
return errorx.ErrRecordNotFound.WithMsg("租户不存在")
|
||||
}
|
||||
|
||||
return Creator.ProcessRefund(ctx, t.UserID, id, &v1_dto.RefundForm{
|
||||
return Creator.ProcessRefund(ctx, t.ID, t.UserID, id, &v1_dto.RefundForm{
|
||||
Action: "accept",
|
||||
Reason: form.Reason,
|
||||
})
|
||||
|
||||
@@ -17,9 +17,12 @@ import (
|
||||
// @provider
|
||||
type tenant struct{}
|
||||
|
||||
func (s *tenant) List(ctx context.Context, filter *dto.TenantListFilter) (*requests.Pager, error) {
|
||||
func (s *tenant) List(ctx context.Context, tenantID int64, filter *dto.TenantListFilter) (*requests.Pager, error) {
|
||||
tbl, q := models.TenantQuery.QueryContext(ctx)
|
||||
q = q.Where(tbl.Status.Eq(consts.TenantStatusVerified))
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.ID.Eq(tenantID))
|
||||
}
|
||||
|
||||
if filter.Keyword != nil && *filter.Keyword != "" {
|
||||
q = q.Where(tbl.Name.Like("%" + *filter.Keyword + "%"))
|
||||
@@ -73,8 +76,8 @@ func (s *tenant) List(ctx context.Context, filter *dto.TenantListFilter) (*reque
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.TenantProfile, error) {
|
||||
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(id)).First()
|
||||
func (s *tenant) GetPublicProfile(ctx context.Context, tenantID, userID int64) (*dto.TenantProfile, error) {
|
||||
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tenantID)).First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errorx.ErrRecordNotFound
|
||||
@@ -83,9 +86,9 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T
|
||||
}
|
||||
|
||||
// Stats
|
||||
followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(id)).Count()
|
||||
followers, _ := models.TenantUserQuery.WithContext(ctx).Where(models.TenantUserQuery.TenantID.Eq(tenantID)).Count()
|
||||
contents, _ := models.ContentQuery.WithContext(ctx).
|
||||
Where(models.ContentQuery.TenantID.Eq(id), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
|
||||
Where(models.ContentQuery.TenantID.Eq(tenantID), models.ContentQuery.Status.Eq(consts.ContentStatusPublished)).
|
||||
Count()
|
||||
|
||||
// Following status
|
||||
@@ -93,7 +96,7 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T
|
||||
if userID > 0 {
|
||||
uid := userID
|
||||
isFollowing, _ = models.TenantUserQuery.WithContext(ctx).
|
||||
Where(models.TenantUserQuery.TenantID.Eq(id), models.TenantUserQuery.UserID.Eq(uid)).
|
||||
Where(models.TenantUserQuery.TenantID.Eq(tenantID), models.TenantUserQuery.UserID.Eq(uid)).
|
||||
Exists()
|
||||
}
|
||||
|
||||
@@ -113,20 +116,20 @@ func (s *tenant) GetPublicProfile(ctx context.Context, userID, id int64) (*dto.T
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *tenant) Follow(ctx context.Context, userID, id int64) error {
|
||||
func (s *tenant) Follow(ctx context.Context, tenantID, userID int64) error {
|
||||
if userID == 0 {
|
||||
return errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
// Check if tenant exists
|
||||
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(id)).First()
|
||||
t, err := models.TenantQuery.WithContext(ctx).Where(models.TenantQuery.ID.Eq(tenantID)).First()
|
||||
if err != nil {
|
||||
return errorx.ErrRecordNotFound
|
||||
}
|
||||
|
||||
tu := &models.TenantUser{
|
||||
TenantID: id,
|
||||
TenantID: tenantID,
|
||||
UserID: uid,
|
||||
Role: types.Array[consts.TenantUserRole]{consts.TenantUserRoleMember},
|
||||
Status: consts.UserStatusVerified,
|
||||
@@ -142,14 +145,14 @@ func (s *tenant) Follow(ctx context.Context, userID, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tenant) Unfollow(ctx context.Context, userID, id int64) error {
|
||||
func (s *tenant) Unfollow(ctx context.Context, tenantID, userID int64) error {
|
||||
if userID == 0 {
|
||||
return errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
_, err := models.TenantUserQuery.WithContext(ctx).
|
||||
Where(models.TenantUserQuery.TenantID.Eq(id), models.TenantUserQuery.UserID.Eq(uid)).
|
||||
Where(models.TenantUserQuery.TenantID.Eq(tenantID), models.TenantUserQuery.UserID.Eq(uid)).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return errorx.ErrDatabaseError.WithCause(err)
|
||||
@@ -157,14 +160,18 @@ func (s *tenant) Unfollow(ctx context.Context, userID, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tenant) ListFollowed(ctx context.Context, userID int64) ([]dto.TenantProfile, error) {
|
||||
func (s *tenant) ListFollowed(ctx context.Context, tenantID, userID int64) ([]dto.TenantProfile, error) {
|
||||
if userID == 0 {
|
||||
return nil, errorx.ErrUnauthorized
|
||||
}
|
||||
uid := userID
|
||||
|
||||
tbl, q := models.TenantUserQuery.QueryContext(ctx)
|
||||
list, err := q.Where(tbl.UserID.Eq(uid)).Find()
|
||||
q = q.Where(tbl.UserID.Eq(uid))
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.TenantID.Eq(tenantID))
|
||||
}
|
||||
list, err := q.Find()
|
||||
if err != nil {
|
||||
return nil, errorx.ErrDatabaseError.WithCause(err)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ func Test_Tenant(t *testing.T) {
|
||||
func (s *TenantTestSuite) Test_Follow() {
|
||||
Convey("Follow Flow", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(0)
|
||||
database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameTenantUser, models.TableNameUser)
|
||||
|
||||
// User
|
||||
@@ -49,29 +50,30 @@ func (s *TenantTestSuite) Test_Follow() {
|
||||
// Tenant
|
||||
t := &models.Tenant{Name: "Tenant A", Status: consts.TenantStatusVerified}
|
||||
models.TenantQuery.WithContext(ctx).Create(t)
|
||||
tenantID = t.ID
|
||||
|
||||
Convey("should follow tenant", func() {
|
||||
err := Tenant.Follow(ctx, u.ID, t.ID)
|
||||
err := Tenant.Follow(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify stats
|
||||
profile, err := Tenant.GetPublicProfile(ctx, u.ID, t.ID)
|
||||
profile, err := Tenant.GetPublicProfile(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(profile.IsFollowing, ShouldBeTrue)
|
||||
So(profile.Stats.Followers, ShouldEqual, 1)
|
||||
|
||||
// List Followed
|
||||
list, err := Tenant.ListFollowed(ctx, u.ID)
|
||||
list, err := Tenant.ListFollowed(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(len(list), ShouldEqual, 1)
|
||||
So(list[0].Name, ShouldEqual, "Tenant A")
|
||||
|
||||
// Unfollow
|
||||
err = Tenant.Unfollow(ctx, u.ID, t.ID)
|
||||
err = Tenant.Unfollow(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Verify
|
||||
profile, err = Tenant.GetPublicProfile(ctx, u.ID, t.ID)
|
||||
profile, err = Tenant.GetPublicProfile(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(profile.IsFollowing, ShouldBeFalse)
|
||||
So(profile.Stats.Followers, ShouldEqual, 0)
|
||||
|
||||
@@ -31,7 +31,7 @@ func (s *user) SendOTP(ctx context.Context, phone string) error {
|
||||
}
|
||||
|
||||
// LoginWithOTP 手机号验证码登录/注册
|
||||
func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.LoginResponse, error) {
|
||||
func (s *user) LoginWithOTP(ctx context.Context, tenantID int64, phone, otp string) (*auth_dto.LoginResponse, error) {
|
||||
// 1. 校验验证码 (模拟:固定 123456)
|
||||
if otp != "1234" {
|
||||
return nil, errorx.ErrInvalidCredentials.WithMsg("验证码错误")
|
||||
@@ -67,8 +67,8 @@ func (s *user) LoginWithOTP(ctx context.Context, phone, otp string) (*auth_dto.L
|
||||
|
||||
// 4. 生成 Token
|
||||
token, err := s.jwt.CreateToken(s.jwt.CreateClaims(jwt.BaseClaims{
|
||||
UserID: u.ID,
|
||||
// TenantID: 0, // 初始登录无租户上下文
|
||||
UserID: u.ID,
|
||||
TenantID: tenantID,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, errorx.ErrInternalError.WithMsg("生成令牌失败")
|
||||
|
||||
@@ -40,11 +40,12 @@ func Test_User(t *testing.T) {
|
||||
func (s *UserTestSuite) Test_LoginWithOTP() {
|
||||
Convey("LoginWithOTP", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameUser)
|
||||
|
||||
Convey("should create user and login success with correct OTP", func() {
|
||||
phone := "13800138000"
|
||||
resp, err := User.LoginWithOTP(ctx, phone, "1234")
|
||||
resp, err := User.LoginWithOTP(ctx, tenantID, phone, "1234")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp, ShouldNotBeNil)
|
||||
So(resp.Token, ShouldNotBeEmpty)
|
||||
@@ -55,17 +56,17 @@ func (s *UserTestSuite) Test_LoginWithOTP() {
|
||||
Convey("should login existing user", func() {
|
||||
phone := "13800138001"
|
||||
// Pre-create user
|
||||
_, err := User.LoginWithOTP(ctx, phone, "1234")
|
||||
_, err := User.LoginWithOTP(ctx, tenantID, phone, "1234")
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
// Login again
|
||||
resp, err := User.LoginWithOTP(ctx, phone, "1234")
|
||||
resp, err := User.LoginWithOTP(ctx, tenantID, phone, "1234")
|
||||
So(err, ShouldBeNil)
|
||||
So(resp.User.Phone, ShouldEqual, phone)
|
||||
})
|
||||
|
||||
Convey("should fail with incorrect OTP", func() {
|
||||
resp, err := User.LoginWithOTP(ctx, "13800138002", "000000")
|
||||
resp, err := User.LoginWithOTP(ctx, tenantID, "13800138002", "000000")
|
||||
So(err, ShouldNotBeNil)
|
||||
So(resp, ShouldBeNil)
|
||||
})
|
||||
@@ -75,11 +76,12 @@ func (s *UserTestSuite) Test_LoginWithOTP() {
|
||||
func (s *UserTestSuite) Test_Me() {
|
||||
Convey("Me", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameUser)
|
||||
|
||||
// Create user
|
||||
phone := "13800138003"
|
||||
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
|
||||
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
|
||||
userID := resp.User.ID
|
||||
|
||||
Convey("should return user profile", func() {
|
||||
@@ -104,10 +106,11 @@ func (s *UserTestSuite) Test_Me() {
|
||||
func (s *UserTestSuite) Test_Update() {
|
||||
Convey("Update", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameUser)
|
||||
|
||||
phone := "13800138004"
|
||||
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
|
||||
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
|
||||
userID := resp.User.ID
|
||||
ctx = context.WithValue(ctx, consts.CtxKeyUser, userID)
|
||||
|
||||
@@ -132,10 +135,11 @@ func (s *UserTestSuite) Test_Update() {
|
||||
func (s *UserTestSuite) Test_RealName() {
|
||||
Convey("RealName", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameUser)
|
||||
|
||||
phone := "13800138005"
|
||||
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
|
||||
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
|
||||
userID := resp.User.ID
|
||||
ctx = context.WithValue(ctx, consts.CtxKeyUser, userID)
|
||||
|
||||
@@ -157,10 +161,11 @@ func (s *UserTestSuite) Test_RealName() {
|
||||
func (s *UserTestSuite) Test_GetNotifications() {
|
||||
Convey("GetNotifications", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameNotification)
|
||||
|
||||
phone := "13800138006"
|
||||
resp, _ := User.LoginWithOTP(ctx, phone, "1234")
|
||||
resp, _ := User.LoginWithOTP(ctx, tenantID, phone, "1234")
|
||||
userID := resp.User.ID
|
||||
ctx = context.WithValue(ctx, consts.CtxKeyUser, userID)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
// @provider
|
||||
type wallet struct{}
|
||||
|
||||
func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletResponse, error) {
|
||||
func (s *wallet) GetWallet(ctx context.Context, tenantID, userID int64) (*user_dto.WalletResponse, error) {
|
||||
// Get Balance
|
||||
u, err := models.UserQuery.WithContext(ctx).Where(models.UserQuery.ID.Eq(userID)).First()
|
||||
if err != nil {
|
||||
@@ -33,7 +33,13 @@ func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletR
|
||||
// Get Transactions (Orders)
|
||||
// Both purchase (expense) and recharge (income - if paid)
|
||||
tbl, q := models.OrderQuery.QueryContext(ctx)
|
||||
orders, err := q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid)).
|
||||
if tenantID > 0 {
|
||||
q = q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid), tbl.TenantID.Eq(tenantID)).
|
||||
Or(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid), tbl.Type.Eq(consts.OrderTypeRecharge))
|
||||
} else {
|
||||
q = q.Where(tbl.UserID.Eq(userID), tbl.Status.Eq(consts.OrderStatusPaid))
|
||||
}
|
||||
orders, err := q.
|
||||
Order(tbl.CreatedAt.Desc()).
|
||||
Limit(20). // Limit to recent 20
|
||||
Find()
|
||||
@@ -71,6 +77,7 @@ func (s *wallet) GetWallet(ctx context.Context, userID int64) (*user_dto.WalletR
|
||||
|
||||
func (s *wallet) Recharge(
|
||||
ctx context.Context,
|
||||
tenantID int64,
|
||||
userID int64,
|
||||
form *user_dto.RechargeForm,
|
||||
) (*user_dto.RechargeResponse, error) {
|
||||
@@ -98,7 +105,7 @@ func (s *wallet) Recharge(
|
||||
|
||||
// MOCK: Automatically pay for recharge order to close the loop
|
||||
// In production, this would be a callback from payment gateway
|
||||
if err := Order.ProcessExternalPayment(ctx, order.ID, "mock_auto_pay"); err != nil {
|
||||
if err := Order.ProcessExternalPayment(ctx, tenantID, order.ID, "mock_auto_pay"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ func Test_Wallet(t *testing.T) {
|
||||
func (s *WalletTestSuite) Test_GetWallet() {
|
||||
Convey("GetWallet", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameOrder)
|
||||
|
||||
u := &models.User{Username: "wallet_user", Balance: 5000} // 50.00
|
||||
@@ -58,7 +59,7 @@ func (s *WalletTestSuite) Test_GetWallet() {
|
||||
models.OrderQuery.WithContext(ctx).Create(o1, o2)
|
||||
|
||||
Convey("should return balance and transactions", func() {
|
||||
res, err := Wallet.GetWallet(ctx, u.ID)
|
||||
res, err := Wallet.GetWallet(ctx, tenantID, u.ID)
|
||||
So(err, ShouldBeNil)
|
||||
So(res.Balance, ShouldEqual, 50.0)
|
||||
So(len(res.Transactions), ShouldEqual, 2)
|
||||
@@ -74,6 +75,7 @@ func (s *WalletTestSuite) Test_GetWallet() {
|
||||
func (s *WalletTestSuite) Test_Recharge() {
|
||||
Convey("Recharge", s.T(), func() {
|
||||
ctx := s.T().Context()
|
||||
tenantID := int64(1)
|
||||
database.Truncate(ctx, s.DB, models.TableNameUser, models.TableNameOrder)
|
||||
|
||||
u := &models.User{Username: "recharge_user"}
|
||||
@@ -82,7 +84,7 @@ func (s *WalletTestSuite) Test_Recharge() {
|
||||
|
||||
Convey("should create recharge order", func() {
|
||||
form := &user_dto.RechargeForm{Amount: 100.0}
|
||||
res, err := Wallet.Recharge(ctx, u.ID, form)
|
||||
res, err := Wallet.Recharge(ctx, tenantID, u.ID, form)
|
||||
So(err, ShouldBeNil)
|
||||
So(res.OrderID, ShouldNotBeEmpty)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user