package model import ( "context" "time" "quyun/app/requests" "quyun/database/fields" . "github.com/go-jet/jet/v2/postgres" "github.com/samber/lo" ) var tblPostsUpdateMutableColumns = tblPosts.MutableColumns.Except( tblPosts.CreatedAt, tblPosts.DeletedAt, tblPosts.Views, tblPosts.Likes, ) func (m *Posts) CondStatus(s fields.PostStatus) Cond { return func(cond BoolExpression) BoolExpression { return cond.AND(tblPosts.Status.EQ(Int(int64(s)))) } } func (m *Posts) CondLike(key *string) Cond { return func(cond BoolExpression) BoolExpression { tbl := tblPosts if key == nil || *key == "" { return cond } cond = cond.AND( tbl.Title.LIKE(String("%" + *key + "%")). OR( tbl.Content.LIKE(String("%" + *key + "%")), ). OR( tbl.Description.LIKE(String("%" + *key + "%")), ), ) return cond } } func (m *Posts) IncrViewCount(ctx context.Context) error { tbl := tblPosts stmt := tbl.UPDATE(tbl.Views). SET(tbl.Views.ADD(Int64(1))). WHERE(tbl.ID.EQ(Int64(m.ID))). RETURNING(tblPosts.AllColumns) m.log().Infof("sql: %s", stmt.DebugSql()) if err := stmt.QueryContext(ctx, db, m); err != nil { m.log().Errorf("error updating post view count: %v", err) return err } return nil } func (m *Posts) List(ctx context.Context, pagination *requests.Pagination, conds ...Cond) (*requests.Pager, error) { pagination.Format() cond := CondJoin(m.CondNotDeleted(), conds...) tbl := tblPosts stmt := tbl. SELECT(tbl.AllColumns). WHERE(CondTrue(cond...)). ORDER_BY(tbl.ID.DESC()). LIMIT(pagination.Limit). OFFSET(pagination.Offset) m.log().Infof("sql: %s", stmt.DebugSql()) var posts []Posts = make([]Posts, 0) err := stmt.QueryContext(ctx, db, &posts) if err != nil { m.log().Errorf("error querying post items: %v", err) return nil, err } count, err := m.Count(ctx, CondJoin(m.CondNotDeleted(), conds...)...) if err != nil { m.log().Errorf("error getting post count: %v", err) return nil, err } return &requests.Pager{ Items: posts, Total: count, Pagination: *pagination, }, nil } // SendTo func (m *Posts) SendTo(ctx context.Context, userId int64) error { // add record to user_posts tbl := tblUserPosts stmt := tbl.INSERT(tbl.MutableColumns).MODEL(UserPosts{ CreatedAt: time.Now(), UpdatedAt: time.Now(), UserID: userId, PostID: m.ID, Price: -1, }) m.log().Infof("sql: %s", stmt.DebugSql()) if _, err := stmt.ExecContext(ctx, db); err != nil { m.log().Errorf("error sending post to user: %v", err) return err } return nil } // PostBoughtStatistics 获取指定文件 ID 的购买次数 func (m *Posts) BoughtStatistics(ctx context.Context, postIds []int64) (map[int64]int64, error) { tbl := tblUserPosts // select count(user_id), post_id from user_posts up where post_id in (1, 2,3,4,5,6,7,8,9,10) group by post_id stmt := tbl. SELECT( COUNT(tbl.UserID).AS("cnt"), tbl.PostID.AS("post_id"), ). WHERE( tbl.PostID.IN(IntExprSlice(postIds)...), ). GROUP_BY( tbl.PostID, ) m.log().Infof("sql: %s", stmt.DebugSql()) var result []struct { Cnt int64 PostId int64 } if err := stmt.QueryContext(ctx, db, &result); err != nil { m.log().Errorf("error getting post bought statistics: %v", err) return nil, err } // convert to map resultMap := make(map[int64]int64) for _, item := range result { resultMap[item.PostId] = item.Cnt } return resultMap, nil } // Bought func (m *Posts) Bought(ctx context.Context, userId int64, pagination *requests.Pagination) (*requests.Pager, error) { pagination.Format() // select up.price,up.created_at,p.* from user_posts up left join posts p on up.post_id = p.id where up.user_id =1 tbl := tblUserPosts stmt := tbl. SELECT( tbl.Price.AS("price"), tbl.CreatedAt.AS("bought_at"), tblPosts.Title.AS("title"), ). FROM( tbl.INNER_JOIN(tblPosts, tblPosts.ID.EQ(tbl.PostID)), ). WHERE( tbl.UserID.EQ(Int64(userId)), ). ORDER_BY(tbl.ID.DESC()). LIMIT(pagination.Limit). OFFSET(pagination.Offset) m.log().Infof("sql: %s", stmt.DebugSql()) var items []struct { Title string `json:"title"` Price int64 `json:"price"` BoughtAt time.Time `json:"bought_at"` } if err := stmt.QueryContext(ctx, db, &items); err != nil { m.log().Errorf("error getting bought posts: %v", err) return nil, err } // convert to Posts var cnt struct { Cnt int64 } stmtCnt := tbl. SELECT(COUNT(tbl.ID).AS("cnt")). WHERE( tbl.UserID.EQ(Int64(userId)), ) if err := stmtCnt.QueryContext(ctx, db, &cnt); err != nil { m.log().Errorf("error getting bought posts count: %v", err) return nil, err } return &requests.Pager{ Items: items, Total: cnt.Cnt, Pagination: *pagination, }, nil } // GetPostsMapByIDs func (m *Posts) GetPostsMapByIDs(ctx context.Context, ids []int64) (map[int64]Posts, error) { if len(ids) == 0 { return nil, nil } tbl := tblPosts stmt := tbl. SELECT(tbl.AllColumns). WHERE( tbl.ID.IN(IntExprSlice(ids)...), ) m.log().Infof("sql: %s", stmt.DebugSql()) var posts []Posts = make([]Posts, 0) err := stmt.QueryContext(ctx, db, &posts) if err != nil { m.log().Errorf("error querying posts: %v", err) return nil, err } return lo.SliceToMap(posts, func(item Posts) (int64, Posts) { return item.ID, item }), nil } // GetMediaByIds func (m *Posts) GetMediaByIds(ctx context.Context, ids []int64) ([]Medias, error) { if len(ids) == 0 { return nil, nil } tbl := tblMedias stmt := tbl. SELECT(tbl.AllColumns). WHERE( tbl.ID.IN(IntExprSlice(ids)...), ) m.log().Infof("sql: %s", stmt.DebugSql()) var medias []Medias if err := stmt.QueryContext(ctx, db, &medias); err != nil { m.log().Errorf("error querying media: %v", err) return nil, err } return medias, nil } func (m *Posts) PayPrice() int64 { return m.Price * int64(m.Discount) / 100 }