tenant: admin orders sort whitelist

This commit is contained in:
2025-12-18 23:36:57 +08:00
parent 71bd15024e
commit 549339be74
8 changed files with 144 additions and 2 deletions

View File

@@ -14,6 +14,9 @@ type AdminOrderListFilter struct {
// Pagination 分页参数page/limit通用
requests.Pagination `json:",inline" query:",inline"`
// SortQueryFilter 排序参数asc/desc逗号分隔字段名字段白名单在 service 层统一校验。
requests.SortQueryFilter `json:",inline" query:",inline"`
// UserID 下单用户ID可选按买家用户ID精确过滤。
UserID *int64 `json:"user_id,omitempty" query:"user_id"`

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"quyun/v2/app/errorx"
@@ -18,6 +19,7 @@ import (
"github.com/samber/lo"
"github.com/sirupsen/logrus"
"go.ipao.vip/gen"
"go.ipao.vip/gen/field"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@@ -428,7 +430,47 @@ func (s *order) AdminOrderPage(
query = query.Group(tbl.ID)
}
items, total, err := query.Where(conds...).Order(tbl.ID.Desc()).FindByPage(int(filter.Offset()), int(filter.Limit))
// 排序白名单:避免把任意字符串拼进 SQL 导致注入或慢查询。
// 约定:只允许按以下字段排序;未指定时默认按 id desc。
orderBys := make([]field.Expr, 0, 4)
allowedAsc := map[string]field.Expr{
"id": tbl.ID.Asc(),
"created_at": tbl.CreatedAt.Asc(),
"paid_at": tbl.PaidAt.Asc(),
"amount_paid": tbl.AmountPaid.Asc(),
}
allowedDesc := map[string]field.Expr{
"id": tbl.ID.Desc(),
"created_at": tbl.CreatedAt.Desc(),
"paid_at": tbl.PaidAt.Desc(),
"amount_paid": tbl.AmountPaid.Desc(),
}
for _, f := range filter.AscFields() {
f = strings.TrimSpace(f)
if f == "" {
continue
}
if ob, ok := allowedAsc[f]; ok {
orderBys = append(orderBys, ob)
}
}
for _, f := range filter.DescFields() {
f = strings.TrimSpace(f)
if f == "" {
continue
}
if ob, ok := allowedDesc[f]; ok {
orderBys = append(orderBys, ob)
}
}
// 默认加上 id desc 作为稳定排序(尤其是 join + group 的场景)。
if len(orderBys) == 0 {
orderBys = append(orderBys, tbl.ID.Desc())
} else {
orderBys = append(orderBys, tbl.ID.Desc())
}
items, total, err := query.Where(conds...).Order(orderBys...).FindByPage(int(filter.Offset()), int(filter.Limit))
if err != nil {
return nil, err
}

View File

@@ -12,6 +12,7 @@ import (
"quyun/v2/app/commands/testx"
"quyun/v2/app/errorx"
"quyun/v2/app/http/tenant/dto"
"quyun/v2/app/requests"
"quyun/v2/database"
"quyun/v2/database/models"
"quyun/v2/pkg/consts"
@@ -590,6 +591,58 @@ func (s *OrderTestSuite) Test_AdminOrderPage() {
So(pager.Total, ShouldEqual, 1)
})
Convey("按排序字段asc/desc排序白名单", func() {
s.truncate(ctx, models.TableNameOrderItem, models.TableNameOrder)
o1 := &models.Order{
TenantID: tenantID,
UserID: 2,
Type: consts.OrderTypeContentPurchase,
Status: consts.OrderStatusPaid,
Currency: consts.CurrencyCNY,
AmountPaid: 500,
Snapshot: types.JSON([]byte("{}")),
PaidAt: now,
CreatedAt: now.Add(-time.Hour),
UpdatedAt: now.Add(-time.Hour),
}
So(o1.Create(ctx), ShouldBeNil)
o2 := &models.Order{
TenantID: tenantID,
UserID: 3,
Type: consts.OrderTypeContentPurchase,
Status: consts.OrderStatusPaid,
Currency: consts.CurrencyCNY,
AmountPaid: 100,
Snapshot: types.JSON([]byte("{}")),
PaidAt: now,
CreatedAt: now,
UpdatedAt: now,
}
So(o2.Create(ctx), ShouldBeNil)
asc := "amount_paid"
pagerAsc, err := Order.AdminOrderPage(ctx, tenantID, &dto.AdminOrderListFilter{
SortQueryFilter: requests.SortQueryFilter{Asc: &asc},
})
So(err, ShouldBeNil)
So(pagerAsc.Total, ShouldEqual, 2)
itemsAsc, ok := pagerAsc.Items.([]*models.Order)
So(ok, ShouldBeTrue)
So(itemsAsc[0].AmountPaid, ShouldEqual, 100)
desc := "created_at"
pagerDesc, err := Order.AdminOrderPage(ctx, tenantID, &dto.AdminOrderListFilter{
SortQueryFilter: requests.SortQueryFilter{Desc: &desc},
})
So(err, ShouldBeNil)
So(pagerDesc.Total, ShouldEqual, 2)
itemsDesc, ok := pagerDesc.Items.([]*models.Order)
So(ok, ShouldBeTrue)
So(itemsDesc[0].CreatedAt.After(itemsDesc[1].CreatedAt), ShouldBeTrue)
})
Convey("按 type 过滤", func() {
s.truncate(ctx, models.TableNameOrderItem, models.TableNameOrder)