Files
quyun/backend/app/models/posts.go
2025-05-06 16:14:36 +08:00

405 lines
9.3 KiB
Go

package models
import (
"context"
"errors"
"time"
"quyun/app/requests"
"quyun/database/conds"
"quyun/database/schemas/public/model"
"quyun/database/schemas/public/table"
. "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/qrm"
"github.com/samber/lo"
"github.com/sirupsen/logrus"
)
// postsModel groups the data-access methods for posts (and their user_posts
// purchase records), built as go-jet statements executed against the `db`
// handle declared elsewhere in this package.
// NOTE(review): the @provider marker below is presumably consumed by a code
// generator for dependency injection — keep it directly above the type.
// @provider
type postsModel struct {
// log is assigned in Prepare; the inject:"false" tag presumably tells the
// injector to leave this field alone.
log *logrus.Entry `inject:"false"`
}
// Prepare initialises the model's logger with a "model"="postsModel" field.
// It presumably runs once after the provider constructs the model; it never
// fails and always returns nil.
func (m *postsModel) Prepare() error {
	m.log = logrus.WithFields(logrus.Fields{"model": "postsModel"})
	return nil
}
// IncrViewCount increments the view counter of the post with the given id by
// one. The addition (views = views + 1) is done in the database, so
// concurrent calls do not overwrite each other's increments.
//
// The UPDATE has no RETURNING clause, so it is executed with ExecContext.
// Running it through QueryContext into a struct destination makes jet's qrm
// report qrm.ErrNoRows, because the statement yields no result rows. An
// unknown id is not reported as an error.
func (m *postsModel) IncrViewCount(ctx context.Context, id int64) error {
	tbl := table.Posts
	stmt := tbl.
		UPDATE(tbl.Views).
		SET(tbl.Views.ADD(Int64(1))).
		WHERE(tbl.ID.EQ(Int64(id)))
	m.log.Infof("sql: %s", stmt.DebugSql())

	if _, err := stmt.ExecContext(ctx, db); err != nil {
		m.log.Errorf("error updating post view count: %v", err)
		return err
	}
	return nil
}
// GetByID loads a single post by primary key.
//
// Each cond receives the predicate built so far (starting from `id = $id`) and
// returns the predicate to use next. The conds are applied in order,
// presumably to add extra filters. Soft-deleted posts are NOT excluded unless
// a cond does so. On failure the query error is logged and returned (for
// example qrm.ErrNoRows when nothing matches).
func (m *postsModel) GetByID(ctx context.Context, id int64, cond ...conds.Cond) (*model.Posts, error) {
	posts := table.Posts

	where := BoolExpression(posts.ID.EQ(Int64(id)))
	for i := range cond {
		where = cond[i](where)
	}

	query := posts.SELECT(posts.AllColumns).WHERE(where)
	m.log.Infof("sql: %s", query.DebugSql())

	found := new(model.Posts)
	if err := query.QueryContext(ctx, db, found); err != nil {
		m.log.Errorf("error getting post: %v", err)
		return nil, err
	}
	return found, nil
}
// Create inserts post as a new posts row, writing every mutable column from
// the struct.
//
// CreatedAt and UpdatedAt are overwritten with a single timestamp taken once.
// A freshly created post therefore has identical created/updated times
// instead of two instants a few nanoseconds apart. The database-generated ID
// is not read back into post (no RETURNING clause).
func (m *postsModel) Create(ctx context.Context, post *model.Posts) error {
	now := time.Now()
	post.CreatedAt = now
	post.UpdatedAt = now

	tbl := table.Posts
	stmt := tbl.INSERT(tbl.MutableColumns).MODEL(post)
	m.log.Infof("sql: %s", stmt.DebugSql())

	if _, err := stmt.ExecContext(ctx, db); err != nil {
		m.log.Errorf("error creating post: %v", err)
		return err
	}
	return nil
}
// Update overwrites the post identified by id with the values in model.
//
// model.UpdatedAt is set to the current time before writing. Every mutable
// column except created_at and deleted_at is written, so zero-valued fields in
// model overwrite the stored values. The affected-row count is not checked,
// so an unknown id is not reported as an error.
func (m *postsModel) Update(ctx context.Context, id int64, model *model.Posts) error {
	model.UpdatedAt = time.Now()

	posts := table.Posts
	columns := posts.MutableColumns.Except(posts.CreatedAt, posts.DeletedAt)
	query := posts.
		UPDATE(columns).
		MODEL(model).
		WHERE(posts.ID.EQ(Int64(id)))
	m.log.Infof("sql: %s", query.DebugSql())

	_, err := query.ExecContext(ctx, db)
	if err == nil {
		return nil
	}
	m.log.Errorf("error updating post: %v", err)
	return err
}
// countByCondition returns the number of posts rows matching expr, i.e.
// SELECT COUNT(id) AS cnt FROM posts WHERE expr.
func (m *postsModel) countByCondition(ctx context.Context, expr BoolExpression) (int64, error) {
	posts := table.Posts
	query := SELECT(COUNT(posts.ID).AS("cnt")).
		FROM(posts).
		WHERE(expr)
	m.log.Infof("sql: %s", query.DebugSql())

	var row struct {
		Cnt int64
	}
	if err := query.QueryContext(ctx, db, &row); err != nil {
		m.log.Errorf("error counting post items: %v", err)
		return 0, err
	}
	return row.Cnt, nil
}
// List returns one page of non-deleted posts, highest id first, together with
// the total number of posts matching the same filter.
//
// pagination.Format() is called before its Limit/Offset are read. Each cond
// is applied in order on top of the base `deleted_at IS NULL` predicate. The
// resulting predicate drives both the page query and the count. Items is
// never nil: an empty page is an empty slice.
func (m *postsModel) List(ctx context.Context, pagination *requests.Pagination, cond ...conds.Cond) (*requests.Pager, error) {
	pagination.Format()

	posts := table.Posts
	where := BoolExpression(posts.DeletedAt.IS_NULL())
	for i := range cond {
		where = cond[i](where)
	}

	query := posts.
		SELECT(posts.AllColumns).
		WHERE(where).
		ORDER_BY(posts.ID.DESC()).
		LIMIT(pagination.Limit).
		OFFSET(pagination.Offset)
	m.log.Infof("sql: %s", query.DebugSql())

	page := []model.Posts{}
	if err := query.QueryContext(ctx, db, &page); err != nil {
		m.log.Errorf("error querying post items: %v", err)
		return nil, err
	}

	total, err := m.countByCondition(ctx, where)
	if err != nil {
		m.log.Errorf("error getting post count: %v", err)
		return nil, err
	}

	return &requests.Pager{
		Pagination: *pagination,
		Items:      page,
		Total:      total,
	}, nil
}
// IsUserBought reports whether a user_posts row links userId to postId.
//
// A missing row is not an error: qrm.ErrNoRows becomes (false, nil). Any other
// query error is logged and returned with false.
func (m *postsModel) IsUserBought(ctx context.Context, userId, postId int64) (bool, error) {
	up := table.UserPosts
	query := up.
		SELECT(up.ID).
		WHERE(up.UserID.EQ(Int64(userId)).AND(up.PostID.EQ(Int64(postId))))
	m.log.Infof("sql: %s", query.DebugSql())

	var row model.UserPosts
	err := query.QueryContext(ctx, db, &row)
	switch {
	case err == nil:
		return row.ID > 0, nil
	case errors.Is(err, qrm.ErrNoRows):
		return false, nil
	default:
		m.log.Errorf("error querying user post item: %v", err)
		return false, err
	}
}
// Buy records a purchase of postId by userId as a new user_posts row.
//
// The post and the user are both loaded first; a failed lookup on either side
// aborts with that lookup's error. The stored price applies the post's
// discount as a percentage using integer arithmetic:
// price * discount / 100, truncated.
//
// NOTE(review): no IsUserBought check is made, so calling this twice inserts
// two rows unless a DB constraint prevents it. The post lookup does not
// exclude soft-deleted posts. CreatedAt is not set on the inserted model —
// confirm the stored created_at is what Bought shows as bought_at.
func (m *postsModel) Buy(ctx context.Context, userId, postId int64) error {
	post, err := m.GetByID(ctx, postId)
	if err != nil {
		m.log.Errorf("error getting post by ID: %v", err)
		return err
	}

	user, err := Users.GetByID(ctx, userId)
	if err != nil {
		m.log.Errorf("error getting user by ID: %v", err)
		return err
	}

	paid := post.Price * int64(post.Discount) / 100
	up := table.UserPosts
	query := up.INSERT(up.MutableColumns).MODEL(model.UserPosts{
		UserID: user.ID,
		PostID: post.ID,
		Price:  paid,
	})
	m.log.Infof("sql: %s", query.DebugSql())

	if _, err = query.ExecContext(ctx, db); err != nil {
		m.log.Errorf("error buying post: %v", err)
		return err
	}
	return nil
}
// DeleteByID soft-deletes a post by setting its deleted_at to the current
// time. The row itself is kept. The affected-row count is not checked, so an
// unknown id is not reported as an error.
func (m *postsModel) DeleteByID(ctx context.Context, id int64) error {
	posts := table.Posts
	deletedAt := TimestampT(time.Now())
	query := posts.UPDATE(posts.DeletedAt).SET(deletedAt).WHERE(posts.ID.EQ(Int64(id)))
	m.log.Infof("sql: %s", query.DebugSql())

	_, err := query.ExecContext(ctx, db)
	if err != nil {
		m.log.Errorf("error deleting post: %v", err)
	}
	return err
}
// SendTo adds a user_posts row linking userId to postId without going through
// Buy. The row's Price (and every other field besides UserID/PostID) is left
// at its zero value.
//
// NOTE(review): no duplicate check is made and CreatedAt is left unset —
// confirm both are intended.
func (m *postsModel) SendTo(ctx context.Context, userId, postId int64) error {
	grant := model.UserPosts{UserID: userId, PostID: postId}
	up := table.UserPosts
	query := up.INSERT(up.MutableColumns).MODEL(grant)
	m.log.Infof("sql: %s", query.DebugSql())

	_, err := query.ExecContext(ctx, db)
	if err != nil {
		m.log.Errorf("error sending post to user: %v", err)
		return err
	}
	return nil
}
// BoughtStatistics returns, for each of the given post IDs, how many
// user_posts rows reference it (COUNT(user_id) grouped by post_id), keyed by
// post ID.
//
// Posts without any row are absent from the map, so a lookup for them yields
// 0. An empty postIds returns an empty map without querying. Otherwise jet
// would render `post_id IN ()`, which PostgreSQL rejects as a syntax error.
func (m *postsModel) BoughtStatistics(ctx context.Context, postIds []int64) (map[int64]int64, error) {
	if len(postIds) == 0 {
		return map[int64]int64{}, nil
	}

	// select count(user_id) as cnt, post_id from user_posts
	// where post_id in (...) group by post_id
	tbl := table.UserPosts
	stmt := tbl.
		SELECT(
			COUNT(tbl.UserID).AS("cnt"),
			tbl.PostID.AS("post_id"),
		).
		WHERE(
			tbl.PostID.IN(lo.Map(postIds, func(id int64, _ int) Expression { return Int64(id) })...),
		).
		GROUP_BY(
			tbl.PostID,
		)
	m.log.Infof("sql: %s", stmt.DebugSql())

	var result []struct {
		Cnt    int64
		PostId int64
	}
	if err := stmt.QueryContext(ctx, db, &result); err != nil {
		m.log.Errorf("error getting post bought statistics: %v", err)
		return nil, err
	}

	resultMap := make(map[int64]int64, len(result))
	for _, item := range result {
		resultMap[item.PostId] = item.Cnt
	}
	return resultMap, nil
}
// Bought returns one page of the posts purchased by userId, most recent
// user_posts row (highest id) first.
//
// Each item carries the post title, the price stored on the user_posts row and
// that row's created_at as bought_at. Only rows whose post still exists are
// listed (INNER JOIN). Total counts all user_posts rows for userId.
// pagination.Format() is called before its Limit/Offset are read.
func (m *postsModel) Bought(ctx context.Context, userId int64, pagination *requests.Pagination) (*requests.Pager, error) {
	pagination.Format()

	// select up.price, up.created_at as bought_at, p.title
	// from user_posts up inner join posts p on up.post_id = p.id
	// where up.user_id = $userId order by up.id desc
	tbl := table.UserPosts
	stmt := tbl.
		SELECT(
			tbl.Price.AS("price"),
			tbl.CreatedAt.AS("bought_at"),
			table.Posts.Title.AS("title"),
		).
		FROM(
			tbl.INNER_JOIN(table.Posts, table.Posts.ID.EQ(tbl.PostID)),
		).
		WHERE(
			// Filter by the requested user; this was previously hard-coded
			// to user 1, leaking that user's purchases to everyone.
			tbl.UserID.EQ(Int64(userId)),
		).
		ORDER_BY(tbl.ID.DESC()).
		LIMIT(pagination.Limit).
		OFFSET(pagination.Offset)
	m.log.Infof("sql: %s", stmt.DebugSql())

	var items []struct {
		Title    string    `json:"title"`
		Price    int64     `json:"price"`
		BoughtAt time.Time `json:"bought_at"`
	}
	if err := stmt.QueryContext(ctx, db, &items); err != nil {
		m.log.Errorf("error getting bought posts: %v", err)
		return nil, err
	}

	// Total number of purchase rows for this user, for pagination.
	var cnt struct {
		Cnt int64
	}
	stmtCnt := tbl.
		SELECT(COUNT(tbl.ID).AS("cnt")).
		WHERE(
			tbl.UserID.EQ(Int64(userId)),
		)
	m.log.Infof("sql: %s", stmtCnt.DebugSql())
	if err := stmtCnt.QueryContext(ctx, db, &cnt); err != nil {
		m.log.Errorf("error getting bought posts count: %v", err)
		return nil, err
	}

	return &requests.Pager{
		Items:      items,
		Total:      cnt.Cnt,
		Pagination: *pagination,
	}, nil
}
// GetPostsMapByIDs loads the posts whose IDs appear in ids and returns them
// keyed by ID. IDs with no matching row are absent from the map, and
// soft-deleted posts are not filtered out. An empty ids returns (nil, nil)
// without querying, since an empty IN list is not valid SQL.
func (m *postsModel) GetPostsMapByIDs(ctx context.Context, ids []int64) (map[int64]model.Posts, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	values := make([]Expression, 0, len(ids))
	for _, id := range ids {
		values = append(values, Int64(id))
	}

	posts := table.Posts
	query := posts.SELECT(posts.AllColumns).WHERE(posts.ID.IN(values...))
	m.log.Infof("sql: %s", query.DebugSql())

	rows := make([]model.Posts, 0)
	if err := query.QueryContext(ctx, db, &rows); err != nil {
		m.log.Errorf("error querying posts: %v", err)
		return nil, err
	}

	byID := make(map[int64]model.Posts, len(rows))
	for _, p := range rows {
		byID[p.ID] = p
	}
	return byID, nil
}
// GetMediaByIds loads the media rows whose IDs appear in ids, in whatever order
// the database returns them (no ORDER BY). An empty ids returns (nil, nil)
// without querying, since an empty IN list is not valid SQL.
func (m *postsModel) GetMediaByIds(ctx context.Context, ids []int64) ([]model.Medias, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	idList := make([]Expression, len(ids))
	for i, id := range ids {
		idList[i] = Int64(id)
	}

	medias := table.Medias
	query := medias.SELECT(medias.AllColumns).WHERE(medias.ID.IN(idList...))
	m.log.Infof("sql: %s", query.DebugSql())

	var found []model.Medias
	if err := query.QueryContext(ctx, db, &found); err != nil {
		m.log.Errorf("error querying media: %v", err)
		return nil, err
	}
	return found, nil
}
// Count returns the number of posts rows matching cond. Unlike List, it adds
// no deleted_at filter of its own, and it does not log the generated SQL.
func (m *postsModel) Count(ctx context.Context, cond BoolExpression) (int64, error) {
	var out struct {
		Count int64
	}

	posts := table.Posts
	err := posts.
		SELECT(COUNT(posts.ID).AS("count")).
		WHERE(cond).
		QueryContext(ctx, db, &out)
	if err != nil {
		m.log.Errorf("error counting posts: %v", err)
		return 0, err
	}
	return out.Count, nil
}