package models import ( "context" "errors" "time" "quyun/app/requests" "quyun/database/conds" "quyun/database/schemas/public/model" "quyun/database/schemas/public/table" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/qrm" "github.com/samber/lo" "github.com/sirupsen/logrus" ) // @provider type postsModel struct { log *logrus.Entry `inject:"false"` } func (m *postsModel) Prepare() error { m.log = logrus.WithField("model", "postsModel") return nil } func (m *postsModel) IncrViewCount(ctx context.Context, id int64) error { tbl := table.Posts stmt := tbl.UPDATE(tbl.Views).SET(tbl.Views.ADD(Int64(1))).WHERE(tbl.ID.EQ(Int64(id))) m.log.Infof("sql: %s", stmt.DebugSql()) var post model.Posts err := stmt.QueryContext(ctx, db, &post) if err != nil { m.log.Errorf("error updating post view count: %v", err) return err } return nil } // GetByID func (m *postsModel) GetByID(ctx context.Context, id int64, cond ...conds.Cond) (*model.Posts, error) { tbl := table.Posts var combinedCond BoolExpression = tbl.ID.EQ(Int64(id)) for _, c := range cond { combinedCond = c(combinedCond) } stmt := tbl.SELECT(tbl.AllColumns).WHERE(combinedCond) m.log.Infof("sql: %s", stmt.DebugSql()) var post model.Posts err := stmt.QueryContext(ctx, db, &post) if err != nil { m.log.Errorf("error getting post: %v", err) return nil, err } return &post, nil } // Create func (m *postsModel) Create(ctx context.Context, model *model.Posts) error { model.CreatedAt = time.Now() model.UpdatedAt = time.Now() tbl := table.Posts stmt := tbl.INSERT(tbl.MutableColumns).MODEL(model) m.log.Infof("sql: %s", stmt.DebugSql()) _, err := stmt.ExecContext(ctx, db) if err != nil { m.log.Errorf("error creating post: %v", err) return err } return nil } // Update func (m *postsModel) Update(ctx context.Context, id int64, model *model.Posts) error { model.UpdatedAt = time.Now() tbl := table.Posts stmt := tbl.UPDATE(tbl.MutableColumns.Except(tbl.CreatedAt, tbl.DeletedAt)).MODEL(model).WHERE(tbl.ID.EQ(Int64(id))) m.log.Infof("sql: %s", stmt.DebugSql()) _, err := stmt.ExecContext(ctx, db) if err != nil { m.log.Errorf("error updating post: %v", err) return err } return nil } // countByCond func (m *postsModel) countByCondition(ctx context.Context, expr BoolExpression) (int64, error) { var cnt struct { Cnt int64 } tbl := table.Posts stmt := SELECT(COUNT(tbl.ID).AS("cnt")).FROM(tbl).WHERE(expr) m.log.Infof("sql: %s", stmt.DebugSql()) err := stmt.QueryContext(ctx, db, &cnt) if err != nil { m.log.Errorf("error counting post items: %v", err) return 0, err } return cnt.Cnt, nil } func (m *postsModel) List(ctx context.Context, pagination *requests.Pagination, cond ...conds.Cond) (*requests.Pager, error) { pagination.Format() combinedCond := table.Posts.DeletedAt.IS_NULL() for _, c := range cond { combinedCond = c(combinedCond) } tbl := table.Posts stmt := tbl. SELECT(tbl.AllColumns). WHERE(combinedCond). ORDER_BY(tbl.ID.DESC()). LIMIT(pagination.Limit). OFFSET(pagination.Offset) m.log.Infof("sql: %s", stmt.DebugSql()) var posts []model.Posts = make([]model.Posts, 0) err := stmt.QueryContext(ctx, db, &posts) if err != nil { m.log.Errorf("error querying post items: %v", err) return nil, err } count, err := m.countByCondition(ctx, combinedCond) if err != nil { m.log.Errorf("error getting post count: %v", err) return nil, err } return &requests.Pager{ Items: posts, Total: count, Pagination: *pagination, }, nil } func (m *postsModel) IsUserBought(ctx context.Context, userId, postId int64) (bool, error) { tbl := table.UserPosts stmt := tbl. SELECT(tbl.ID). WHERE( tbl.UserID.EQ(Int64(userId)).AND( tbl.PostID.EQ(Int64(postId)), ), ) m.log.Infof("sql: %s", stmt.DebugSql()) var userPost model.UserPosts err := stmt.QueryContext(ctx, db, &userPost) if err != nil { if errors.Is(err, qrm.ErrNoRows) { return false, nil } m.log.Errorf("error querying user post item: %v", err) return false, err } return userPost.ID > 0, nil } func (m *postsModel) Buy(ctx context.Context, userId, postId int64) error { tbl := table.UserPosts post, err := m.GetByID(ctx, postId) if err != nil { m.log.Errorf("error getting post by ID: %v", err) return err } user, err := Users.GetByID(ctx, userId) if err != nil { m.log.Errorf("error getting user by ID: %v", err) return err } record := model.UserPosts{ UserID: user.ID, PostID: post.ID, Price: post.Price * int64(post.Discount) / 100, } stmt := tbl.INSERT(tbl.MutableColumns).MODEL(record) m.log.Infof("sql: %s", stmt.DebugSql()) if _, err := stmt.ExecContext(ctx, db); err != nil { m.log.Errorf("error buying post: %v", err) return err } return nil } // DeleteByID soft delete item func (m *postsModel) DeleteByID(ctx context.Context, id int64) error { tbl := table.Posts stmt := tbl. UPDATE(tbl.DeletedAt). SET(TimestampT(time.Now())). WHERE( tbl.ID.EQ(Int64(id)), ) m.log.Infof("sql: %s", stmt.DebugSql()) if _, err := stmt.ExecContext(ctx, db); err != nil { m.log.Errorf("error deleting post: %v", err) return err } return nil } // SendTo func (m *postsModel) SendTo(ctx context.Context, postId, userId int64) error { // add record to user_posts tbl := table.UserPosts stmt := tbl.INSERT(tbl.MutableColumns).MODEL(model.UserPosts{ UserID: userId, PostID: postId, }) m.log.Infof("sql: %s", stmt.DebugSql()) if _, err := stmt.ExecContext(ctx, db); err != nil { m.log.Errorf("error sending post to user: %v", err) return err } return nil } // PostBoughtStatistics 获取指定文件 ID 的购买次数 func (m *postsModel) BoughtStatistics(ctx context.Context, postIds []int64) (map[int64]int64, error) { tbl := table.UserPosts // select count(user_id), post_id from user_posts up where post_id in (1, 2,3,4,5,6,7,8,9,10) group by post_id stmt := tbl. SELECT( COUNT(tbl.UserID).AS("cnt"), tbl.PostID.AS("post_id"), ). WHERE( tbl.PostID.IN(lo.Map(postIds, func(id int64, _ int) Expression { return Int64(id) })...), ). GROUP_BY( tbl.PostID, ) m.log.Infof("sql: %s", stmt.DebugSql()) var result []struct { Cnt int64 PostId int64 } if err := stmt.QueryContext(ctx, db, &result); err != nil { m.log.Errorf("error getting post bought statistics: %v", err) return nil, err } // convert to map resultMap := make(map[int64]int64) for _, item := range result { resultMap[item.PostId] = item.Cnt } return resultMap, nil } // Bought func (m *postsModel) Bought(ctx context.Context, userId int64, pagination *requests.Pagination) (*requests.Pager, error) { pagination.Format() // select up.price,up.created_at,p.* from user_posts up left join posts p on up.post_id = p.id where up.user_id =1 tbl := table.UserPosts stmt := tbl. SELECT( tbl.Price.AS("price"), tbl.CreatedAt.AS("bought_at"), table.Posts.Title.AS("title"), ). FROM( tbl.INNER_JOIN(table.Posts, table.Posts.ID.EQ(tbl.PostID)), ). WHERE( tbl.UserID.EQ(Int64(userId)), ). ORDER_BY(tbl.ID.DESC()). LIMIT(pagination.Limit). OFFSET(pagination.Offset) m.log.Infof("sql: %s", stmt.DebugSql()) var items []struct { Title string `json:"title"` Price int64 `json:"price"` BoughtAt time.Time `json:"bought_at"` } if err := stmt.QueryContext(ctx, db, &items); err != nil { m.log.Errorf("error getting bought posts: %v", err) return nil, err } // convert to model.Posts var cnt struct { Cnt int64 } stmtCnt := tbl. SELECT(COUNT(tbl.ID).AS("cnt")). WHERE( tbl.UserID.EQ(Int64(userId)), ) if err := stmtCnt.QueryContext(ctx, db, &cnt); err != nil { m.log.Errorf("error getting bought posts count: %v", err) return nil, err } return &requests.Pager{ Items: items, Total: cnt.Cnt, Pagination: *pagination, }, nil } // GetPostsMapByIDs func (m *postsModel) GetPostsMapByIDs(ctx context.Context, ids []int64) (map[int64]model.Posts, error) { if len(ids) == 0 { return nil, nil } tbl := table.Posts stmt := tbl. SELECT(tbl.AllColumns). WHERE( tbl.ID.IN(lo.Map(ids, func(id int64, _ int) Expression { return Int64(id) })...), ) m.log.Infof("sql: %s", stmt.DebugSql()) var posts []model.Posts = make([]model.Posts, 0) err := stmt.QueryContext(ctx, db, &posts) if err != nil { m.log.Errorf("error querying posts: %v", err) return nil, err } return lo.SliceToMap(posts, func(item model.Posts) (int64, model.Posts) { return item.ID, item }), nil } // GetMediaByIds func (m *postsModel) GetMediaByIds(ctx context.Context, ids []int64) ([]model.Medias, error) { if len(ids) == 0 { return nil, nil } tbl := table.Medias stmt := tbl. SELECT(tbl.AllColumns). WHERE( tbl.ID.IN(lo.Map(ids, func(id int64, _ int) Expression { return Int64(id) })...), ) m.log.Infof("sql: %s", stmt.DebugSql()) var medias []model.Medias if err := stmt.QueryContext(ctx, db, &medias); err != nil { m.log.Errorf("error querying media: %v", err) return nil, err } return medias, nil } // Count func (m *postsModel) Count(ctx context.Context, cond BoolExpression) (int64, error) { tbl := table.Posts stmt := tbl. SELECT(COUNT(tbl.ID).AS("count")). WHERE(cond) var count struct { Count int64 } if err := stmt.QueryContext(ctx, db, &count); err != nil { m.log.Errorf("error counting posts: %v", err) return 0, err } return count.Count, nil }