diff --git a/backend/database/.transform.yaml b/backend/database/.transform.yaml index 0d987e6..27f3027 100755 --- a/backend/database/.transform.yaml +++ b/backend/database/.transform.yaml @@ -4,3 +4,6 @@ types: oauth: backend/pkg/pg.UserOAuth medias: resources: backend/pkg/pg.MediaResources + user_balance_histories: + target: backend/pkg/pg.BalanceTarget + type: backend/pkg/pg.BalanceType diff --git a/backend/database/migrations/20241128075611_init.sql b/backend/database/migrations/20241128075611_init.sql index f26884a..f874055 100644 --- a/backend/database/migrations/20241128075611_init.sql +++ b/backend/database/migrations/20241128075611_init.sql @@ -50,10 +50,12 @@ CREATE TABLE user_balance_histories ( balance INT8 NOT NULL, target jsonb default '{}'::jsonb, type VARCHAR(128) NOT NULL, -- charge, consume, refund + code VARCHAR(128) NOT NULL default '', created_at timestamp NOT NULL default now() ); CREATE INDEX idx_user_balance_histories_user_id ON user_balance_histories (user_id); CREATE INDEX idx_user_balance_histories_tenant_id ON user_balance_histories (tenant_id); +CREATE INDEX idx_user_balance_histories_code ON user_balance_histories (code); -- medias CREATE TABLE @@ -95,7 +97,6 @@ CREATE INDEX idx_user_medias_tenant_id ON user_medias (tenant_id); DROP TABLE users; DROP TABLE tenants; DROP TABLE users_tenants; -DROP TABLE tenant_user_balances; DROP TABLE user_balance_histories; DROP TABLE medias; DROP TABLE user_medias; diff --git a/backend/database/models/qvyun/public/model/user_balance_histories.go b/backend/database/models/qvyun/public/model/user_balance_histories.go index 1ed9529..024de59 100644 --- a/backend/database/models/qvyun/public/model/user_balance_histories.go +++ b/backend/database/models/qvyun/public/model/user_balance_histories.go @@ -8,15 +8,17 @@ package model import ( + "backend/pkg/pg" "time" ) type UserBalanceHistories struct { - ID int64 `sql:"primary_key" json:"id"` - UserID int64 `json:"user_id"` - TenantID int64 `json:"tenant_id"` - Balance int64 `json:"balance"` - Target *string `json:"target"` - Type string `json:"type"` - CreatedAt time.Time `json:"created_at"` + ID int64 `sql:"primary_key" json:"id"` + UserID int64 `json:"user_id"` + TenantID int64 `json:"tenant_id"` + Balance int64 `json:"balance"` + Target pg.BalanceTarget `json:"target"` + Type pg.BalanceType `json:"type"` + Code string `json:"code"` + CreatedAt time.Time `json:"created_at"` } diff --git a/backend/database/models/qvyun/public/table/user_balance_histories.go b/backend/database/models/qvyun/public/table/user_balance_histories.go index a77643f..f40ac5a 100644 --- a/backend/database/models/qvyun/public/table/user_balance_histories.go +++ b/backend/database/models/qvyun/public/table/user_balance_histories.go @@ -23,6 +23,7 @@ type userBalanceHistoriesTable struct { Balance postgres.ColumnInteger Target postgres.ColumnString Type postgres.ColumnString + Code postgres.ColumnString CreatedAt postgres.ColumnTimestamp AllColumns postgres.ColumnList @@ -70,9 +71,10 @@ func newUserBalanceHistoriesTableImpl(schemaName, tableName, alias string) userB BalanceColumn = postgres.IntegerColumn("balance") TargetColumn = postgres.StringColumn("target") TypeColumn = postgres.StringColumn("type") + CodeColumn = postgres.StringColumn("code") CreatedAtColumn = postgres.TimestampColumn("created_at") - allColumns = postgres.ColumnList{IDColumn, UserIDColumn, TenantIDColumn, BalanceColumn, TargetColumn, TypeColumn, CreatedAtColumn} - mutableColumns = postgres.ColumnList{UserIDColumn, TenantIDColumn, BalanceColumn, TargetColumn, TypeColumn, CreatedAtColumn} + allColumns = postgres.ColumnList{IDColumn, UserIDColumn, TenantIDColumn, BalanceColumn, TargetColumn, TypeColumn, CodeColumn, CreatedAtColumn} + mutableColumns = postgres.ColumnList{UserIDColumn, TenantIDColumn, BalanceColumn, TargetColumn, TypeColumn, CodeColumn, CreatedAtColumn} ) return userBalanceHistoriesTable{ @@ -85,6 +87,7 @@ func newUserBalanceHistoriesTableImpl(schemaName, tableName, alias string) userB Balance: BalanceColumn, Target: TargetColumn, Type: TypeColumn, + Code: CodeColumn, CreatedAt: CreatedAtColumn, AllColumns: allColumns, diff --git a/backend/modules/users/controller.go b/backend/modules/users/controller.go index 6a9ee9a..32b982d 100644 --- a/backend/modules/users/controller.go +++ b/backend/modules/users/controller.go @@ -1,14 +1,35 @@ package users -import "github.com/gofiber/fiber/v3" +import ( + "backend/pkg/consts" + "backend/providers/jwt" + + "github.com/gofiber/fiber/v3" + log "github.com/sirupsen/logrus" + hashids "github.com/speps/go-hashids/v2" +) // @provider type Controller struct { - svc *Service + svc *Service + hashIds *hashids.HashID } // List func (c *Controller) List(ctx fiber.Ctx) error { - return ctx.SendString(ctx.Params("tenant", "no user")) + return ctx.JSON(nil) +} + +// Charge +func (c *Controller) Charge(ctx fiber.Ctx) error { + claim := fiber.Locals[*jwt.Claims](ctx, consts.CtxKeyClaim) + log.Debug(claim) + + // [tenantId, chargeAmount, timestamp] + code := ctx.Params("code") + if err := c.svc.Charge(ctx.Context(), claim, code); err != nil { + return err + } + return ctx.JSON(nil) } diff --git a/backend/modules/users/provider.gen.go b/backend/modules/users/provider.gen.go index 12c86d0..23b8505 100755 --- a/backend/modules/users/provider.gen.go +++ b/backend/modules/users/provider.gen.go @@ -7,14 +7,17 @@ import ( "git.ipao.vip/rogeecn/atom/container" "git.ipao.vip/rogeecn/atom/contracts" "git.ipao.vip/rogeecn/atom/utils/opt" + hashids "github.com/speps/go-hashids/v2" ) func Provide(opts ...opt.Option) error { if err := container.Container.Provide(func( + hashIds *hashids.HashID, svc *Service, ) (*Controller, error) { obj := &Controller{ - svc: svc, + hashIds: hashIds, + svc: svc, } return obj, nil }); err != nil { @@ -37,9 +40,11 @@ func Provide(opts ...opt.Option) error { if err := container.Container.Provide(func( db *sql.DB, + hashIds *hashids.HashID, ) (*Service, error) { obj := &Service{ - db: db, + db: db, + hashIds: hashIds, } if err := obj.Prepare(); err != nil { return nil, err diff --git a/backend/modules/users/router.go b/backend/modules/users/router.go index 58c7821..4451767 100755 --- a/backend/modules/users/router.go +++ b/backend/modules/users/router.go @@ -27,4 +27,5 @@ func (r *Router) Name() string { func (r *Router) Register(router fiber.Router) { group := router.Group(r.Name()) group.Get("", r.controller.List) + group.Patch("charge/:code", r.controller.Charge) } diff --git a/backend/modules/users/service.go b/backend/modules/users/service.go index cf07f65..8c39d17 100644 --- a/backend/modules/users/service.go +++ b/backend/modules/users/service.go @@ -9,18 +9,22 @@ import ( "backend/database/models/qvyun/public/table" "backend/pkg/consts" "backend/pkg/db" + "backend/pkg/errorx" "backend/pkg/pg" + "backend/providers/jwt" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/qrm" "github.com/pkg/errors" "github.com/sirupsen/logrus" + hashids "github.com/speps/go-hashids/v2" ) // @provider:except type Service struct { - db *sql.DB - log *logrus.Entry `inject:"false"` + db *sql.DB + hashIds *hashids.HashID + log *logrus.Entry `inject:"false"` } func (svc *Service) Prepare() error { @@ -255,3 +259,123 @@ func (svc *Service) SetTenantExpireAtBySlug(ctx context.Context, slug string, ex return nil } + +func (svc *Service) GenerateChargeCode(ctx context.Context, tenantID, chargeAmount int64) (string, error) { + log := svc.log.WithField("method", "GenerateChargeCode") + + timestamp := time.Now().Unix() + code, err := svc.hashIds.EncodeInt64([]int64{tenantID, chargeAmount, timestamp}) + if err != nil { + return "", errors.Wrap(err, "failed to encode charge code") + } + log.Infof("generate charge code: %s", code) + + return code, nil +} + +// Charge +func (svc *Service) Charge(ctx context.Context, claim *jwt.Claims, code string) error { + log := svc.log.WithField("method", "Charge") + raw, err := svc.hashIds.DecodeInt64WithError(code) + if err != nil { + return errorx.InvalidChargeCode + } + + if len(raw) != 3 { + return errorx.InvalidChargeCode + } + + tenantId, chargeAmount, timestamp := raw[0], raw[1], raw[2] + if tenantId != claim.TenantID { + return errorx.InvalidChargeCode + } + generatedAt := time.Unix(timestamp, 0) + log.Infof("charge code %s generated at: %s", code, generatedAt) + + if chargeAmount <= 0 { + return errorx.InvalidChargeCode + } + + t := table.UserBalanceHistories + st := t.SELECT(COUNT(t.ID).AS("cnt")).WHERE(t.Code.EQ(String(code))) + log.Debug(st.DebugSql()) + + var result struct { + Cnt int64 + } + if err := st.QueryContext(ctx, db.FromContext(ctx, svc.db), &result); err != nil { + return errors.Wrap(err, "failed to query charge code") + } + + if result.Cnt > 0 { + return errorx.InvalidChargeCode + } + + has, err := svc.TenantHasUser(ctx, claim.UserID, tenantId) + if err != nil { + return errors.Wrap(err, "failed to check user-tenant relation") + } + + if !has { + return errorx.InvalidChargeCode + } + + log.Infof("charge tenant: %d, user: %d, amount: %d", claim.TenantID, claim.UserID, chargeAmount) + + tx, err := svc.db.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "failed to begin transaction") + } + defer tx.Rollback() + + // update user balance in users_tenants + tbl := table.UsersTenants + stmt := tbl. + UPDATE(). + SET( + tbl.Balance.SET( + tbl.Balance.ADD(Int64(chargeAmount)), + ), + ). + WHERE( + tbl.UserID.EQ(Int64(claim.UserID)).AND( + tbl.TenantID.EQ(Int64(claim.TenantID)), + ), + ) + log.Debug(stmt.DebugSql()) + + if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil { + return errors.Wrap(err, "failed to charge user balance") + } + + // insert charge record + chargeTbl := table.UserBalanceHistories + chargeStmt := chargeTbl. + INSERT( + chargeTbl.UserID, + chargeTbl.TenantID, + chargeTbl.Balance, + chargeTbl.Target, + chargeTbl.Type, + chargeTbl.Code, + ). + VALUES( + Int64(claim.UserID), + Int64(claim.TenantID), + Int64(chargeAmount), + Json(pg.BalanceTarget{}.MustValue()), + String(pg.BalanceTypeCharge.String()), + String(code), + ) + log.Debug(chargeStmt.DebugSql()) + + if _, err := chargeStmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil { + return errors.Wrap(err, "failed to insert charge record") + } + + if err := tx.Commit(); err != nil { + return errors.Wrap(err, "failed to commit transaction") + } + + return nil +} diff --git a/backend/modules/users/service_test.go b/backend/modules/users/service_test.go index c2b7699..e64b7b4 100644 --- a/backend/modules/users/service_test.go +++ b/backend/modules/users/service_test.go @@ -4,64 +4,56 @@ import ( "context" "testing" - "backend/fixtures" - dbUtil "backend/pkg/db" - "backend/pkg/pg" + "backend/pkg/service/testx" + "backend/providers/hashids" + "backend/providers/jwt" + "backend/providers/postgres" + "backend/providers/storage" + log "github.com/sirupsen/logrus" . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/suite" + "go.uber.org/dig" ) -func TestService_GetOrNew(t *testing.T) { - FocusConvey("Test GetOrNew", t, func() { - // So(dbUtil.TruncateAllTables(context.TODO(), db, "users", "users_tenants"), ShouldBeNil) - db, err := fixtures.GetDB() - So(err, ShouldBeNil) - defer db.Close() +type ServiceInjectParams struct { + dig.In + Svc *Service +} - Convey("Test GetOrNew", func() { - svc := &Service{db: db} - So(svc.Prepare(), ShouldBeNil) +type ServiceTestSuite struct { + suite.Suite + ServiceInjectParams +} - user, err := svc.GetByOpenID(context.Background(), "hello") - So(err, ShouldBeNil) +func Test_DiscoverMedias(t *testing.T) { + log.SetLevel(log.DebugLevel) - So(user, ShouldNotBeNil) - So(user.OpenID, ShouldEqual, "hello") - }) + providers := testx.Default( + postgres.DefaultProvider(), + storage.DefaultProvider(), + hashids.DefaultProvider(), + ).With( + Provide, + ) - FocusConvey("Test GetOrNew", func() { - svc := &Service{db: db} - So(svc.Prepare(), ShouldBeNil) - - openid := "test_openid" - authInfo := pg.UserOAuth{ - AccessToken: "test_access_token", - } - - user, err := svc.GetOrNew(context.Background(), 1, openid, authInfo) - So(err, ShouldBeNil) - So(user.OpenID, ShouldEqual, openid) - }) + testx.Serve(providers, t, func(params ServiceInjectParams) { + suite.Run(t, &ServiceTestSuite{ServiceInjectParams: params}) }) } -func TestService_CreateTenantUser(t *testing.T) { - FocusConvey("Test CreateTenantUser", t, func() { - db, err := fixtures.GetDB() +func (t *ServiceTestSuite) Test_Charge() { + Convey("Charge", t.T(), func() { + code, err := t.Svc.GenerateChargeCode(context.Background(), 1, 100) So(err, ShouldBeNil) - defer db.Close() + code = "b8TDWf59wvPw" - So(dbUtil.TruncateAllTables(context.TODO(), db, "users", "users_tenants"), ShouldBeNil) - - FocusConvey("Test Create", func() { - svc := &Service{db: db} - So(svc.Prepare(), ShouldBeNil) - - err := svc.CreateTenantUser(context.Background(), 1, 1) - So(err, ShouldBeNil) - - err = svc.CreateTenantUser(context.Background(), 1, 1) - So(err, ShouldBeNil) - }) + err = t.Svc.Charge(context.Background(), &jwt.Claims{ + BaseClaims: jwt.BaseClaims{ + TenantID: 1, + UserID: 1, + }, + }, code) + So(err, ShouldBeNil) }) } diff --git a/backend/pkg/errorx/error.go b/backend/pkg/errorx/error.go index 0dc9ecf..f0debd7 100644 --- a/backend/pkg/errorx/error.go +++ b/backend/pkg/errorx/error.go @@ -29,4 +29,5 @@ var ( RequestParseError = Response{http.StatusBadRequest, http.StatusBadRequest, "请求解析错误"} InternalError = Response{http.StatusInternalServerError, http.StatusInternalServerError, "内部错误"} UserBalanceNotEnough = Response{http.StatusPaymentRequired, 1001, "余额不足,请充值"} + InvalidChargeCode = Response{http.StatusPaymentRequired, 1002, "无效的充值码"} ) diff --git a/backend/pkg/pg/users.gen.go b/backend/pkg/pg/users.gen.go new file mode 100644 index 0000000..4443398 --- /dev/null +++ b/backend/pkg/pg/users.gen.go @@ -0,0 +1,179 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: - +// Revision: - +// Build Date: - +// Built By: - + +package pg + +import ( + "database/sql/driver" + "errors" + "fmt" + "strings" +) + +const ( + // BalanceTypeCharge is a BalanceType of type Charge. + BalanceTypeCharge BalanceType = "charge" + // BalanceTypeConsume is a BalanceType of type Consume. + BalanceTypeConsume BalanceType = "consume" + // BalanceTypeRefund is a BalanceType of type Refund. + BalanceTypeRefund BalanceType = "refund" +) + +var ErrInvalidBalanceType = fmt.Errorf("not a valid BalanceType, try [%s]", strings.Join(_BalanceTypeNames, ", ")) + +var _BalanceTypeNames = []string{ + string(BalanceTypeCharge), + string(BalanceTypeConsume), + string(BalanceTypeRefund), +} + +// BalanceTypeNames returns a list of possible string values of BalanceType. +func BalanceTypeNames() []string { + tmp := make([]string, len(_BalanceTypeNames)) + copy(tmp, _BalanceTypeNames) + return tmp +} + +// BalanceTypeValues returns a list of the values for BalanceType +func BalanceTypeValues() []BalanceType { + return []BalanceType{ + BalanceTypeCharge, + BalanceTypeConsume, + BalanceTypeRefund, + } +} + +// String implements the Stringer interface. +func (x BalanceType) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x BalanceType) IsValid() bool { + _, err := ParseBalanceType(string(x)) + return err == nil +} + +var _BalanceTypeValue = map[string]BalanceType{ + "charge": BalanceTypeCharge, + "consume": BalanceTypeConsume, + "refund": BalanceTypeRefund, +} + +// ParseBalanceType attempts to convert a string to a BalanceType. +func ParseBalanceType(name string) (BalanceType, error) { + if x, ok := _BalanceTypeValue[name]; ok { + return x, nil + } + return BalanceType(""), fmt.Errorf("%s is %w", name, ErrInvalidBalanceType) +} + +var errBalanceTypeNilPtr = errors.New("value pointer is nil") // one per type for package clashes + +// Scan implements the Scanner interface. +func (x *BalanceType) Scan(value interface{}) (err error) { + if value == nil { + *x = BalanceType("") + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case string: + *x, err = ParseBalanceType(v) + case []byte: + *x, err = ParseBalanceType(string(v)) + case BalanceType: + *x = v + case *BalanceType: + if v == nil { + return errBalanceTypeNilPtr + } + *x = *v + case *string: + if v == nil { + return errBalanceTypeNilPtr + } + *x, err = ParseBalanceType(*v) + default: + return errors.New("invalid type for BalanceType") + } + + return +} + +// Value implements the driver Valuer interface. +func (x BalanceType) Value() (driver.Value, error) { + return x.String(), nil +} + +// Set implements the Golang flag.Value interface func. +func (x *BalanceType) Set(val string) error { + v, err := ParseBalanceType(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *BalanceType) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *BalanceType) Type() string { + return "BalanceType" +} + +type NullBalanceType struct { + BalanceType BalanceType + Valid bool +} + +func NewNullBalanceType(val interface{}) (x NullBalanceType) { + err := x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + _ = err // make any errcheck linters happy + return +} + +// Scan implements the Scanner interface. +func (x *NullBalanceType) Scan(value interface{}) (err error) { + if value == nil { + x.BalanceType, x.Valid = BalanceType(""), false + return + } + + err = x.BalanceType.Scan(value) + x.Valid = (err == nil) + return +} + +// Value implements the driver Valuer interface. +func (x NullBalanceType) Value() (driver.Value, error) { + if !x.Valid { + return nil, nil + } + // driver.Value accepts int64 for int values. + return string(x.BalanceType), nil +} + +type NullBalanceTypeStr struct { + NullBalanceType +} + +func NewNullBalanceTypeStr(val interface{}) (x NullBalanceTypeStr) { + x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + return +} + +// Value implements the driver Valuer interface. +func (x NullBalanceTypeStr) Value() (driver.Value, error) { + if !x.Valid { + return nil, nil + } + return x.BalanceType.String(), nil +} diff --git a/backend/pkg/pg/users.go b/backend/pkg/pg/users.go index ffbb7f0..415bd02 100644 --- a/backend/pkg/pg/users.go +++ b/backend/pkg/pg/users.go @@ -4,6 +4,8 @@ import ( "database/sql/driver" "encoding/json" "errors" + + "github.com/samber/lo" ) type UserOAuth struct { @@ -16,7 +18,7 @@ type UserOAuth struct { Unionid string `json:"unionid,omitempty"` } -func (x UserOAuth) Scan(value interface{}) (err error) { +func (x *UserOAuth) Scan(value interface{}) (err error) { switch v := value.(type) { case string: return json.Unmarshal([]byte(v), &x) @@ -31,3 +33,36 @@ func (x UserOAuth) Scan(value interface{}) (err error) { func (x UserOAuth) Value() (driver.Value, error) { return json.Marshal(x) } + +type BalanceTarget struct { + ID int64 `json:"id,omitempty"` + Name string `json:"name,omitempty"` +} + +func (x BalanceTarget) MustValue() driver.Value { + return lo.Must(json.Marshal(x)) +} + +func (x *BalanceTarget) Scan(value interface{}) (err error) { + switch v := value.(type) { + case string: + return json.Unmarshal([]byte(v), &x) + case []byte: + return json.Unmarshal(v, &x) + case *string: + return json.Unmarshal([]byte(*v), &x) + } + return errors.New("Unknown type for ") +} + +func (x BalanceTarget) Value() (driver.Value, error) { + return json.Marshal(x) +} + +// swagger:enum BalanceType +// ENUM( +// Charge = "charge", +// Consume = "consume", +// Refund = "refund", +// ) +type BalanceType string