From 2b4cfb1a1e1da55ed4f97ec770592813389d5e5b Mon Sep 17 00:00:00 2001 From: Rogee Date: Mon, 2 Dec 2024 11:29:15 +0800 Subject: [PATCH] feat: uniq tenant_users --- .../migrations/20241128075611_init.sql | 35 +++++----- backend/fixtures/db.go | 26 +++++++ backend/modules/users/service.go | 58 ++++++++++++---- backend/modules/users/service_test.go | 67 +++++++++++++++++++ backend/pkg/db/db.go | 11 +++ backend/pkg/pg/users.go | 36 ++++++++-- backend/providers/app | 2 +- 7 files changed, 198 insertions(+), 37 deletions(-) create mode 100644 backend/fixtures/db.go create mode 100644 backend/modules/users/service_test.go diff --git a/backend/database/migrations/20241128075611_init.sql b/backend/database/migrations/20241128075611_init.sql index 583a260..6770d2c 100644 --- a/backend/database/migrations/20241128075611_init.sql +++ b/backend/database/migrations/20241128075611_init.sql @@ -4,16 +4,15 @@ CREATE TABLE users ( id SERIAL8 PRIMARY KEY, - open_id VARCHAR(128) NOT NULL, - union_id VARCHAR(128), - oauth jsonb , + open_id VARCHAR(128) NOT NULL UNIQUE, + union_id VARCHAR(128) , + oauth jsonb default '{}'::jsonb, expire_in timestamp NOT NULL, - created_at timestamp NOT NULL, - updated_at timestamp NOT NULL + created_at timestamp NOT NULL default now(), + updated_at timestamp NOT NULL default now() ); CREATE INDEX idx_users_open_id ON users (open_id); - CREATE INDEX idx_users_union_id ON users (union_id); -- table tenants @@ -24,8 +23,8 @@ CREATE TABLE slug VARCHAR(128) NOT NULL, description VARCHAR(128), expire_at timestamp NOT NULL, - created_at timestamp NOT NULL, - updated_at timestamp NOT NULL + created_at timestamp NOT NULL default now(), + updated_at timestamp NOT NULL default now() ); -- table users_tenants @@ -34,11 +33,13 @@ CREATE TABLE id SERIAL8 PRIMARY KEY, user_id INT8 NOT NULL, tenant_id INT8 NOT NULL, - created_at timestamp NOT NULL + created_at timestamp NOT NULL default now() ); CREATE INDEX idx_users_tenants_user_id ON users_tenants (user_id); CREATE INDEX idx_users_tenants_tenant_id ON users_tenants (tenant_id); +-- uniq user_id, tenant_id +CREATE UNIQUE INDEX idx_users_tenants_user_id_tenant_id ON users_tenants (user_id, tenant_id); CREATE TABLE tenant_user_balances ( id SERIAL8 PRIMARY KEY, @@ -56,9 +57,9 @@ CREATE TABLE user_balance_histories ( user_id INT8 NOT NULL, tenant_id INT8 NOT NULL, balance INT8 NOT NULL, - target jsonb , + target jsonb default '{}'::jsonb, type VARCHAR(128) NOT NULL, -- charge, consume, refund - created_at timestamp NOT NULL + created_at timestamp NOT NULL default now() ); CREATE INDEX idx_user_balance_histories_user_id ON user_balance_histories (user_id); CREATE INDEX idx_user_balance_histories_tenant_id ON user_balance_histories (tenant_id); @@ -73,8 +74,8 @@ CREATE TABLE price INT8 NOT NULL, discount INT8 NOT NULL, publish BOOL NOT NULL, - created_at timestamp NOT NULL, - updated_at timestamp NOT NULL + created_at timestamp NOT NULL default now(), + updated_at timestamp NOT NULL default now() ); CREATE INDEX idx_medias_tenant_id ON medias (tenant_id); @@ -86,11 +87,11 @@ CREATE TABLE id SERIAL8 PRIMARY KEY, media_id INT8 NOT NULL, type VARCHAR(128) NOT NULL, - source jsonb , + source jsonb default '{}'::jsonb, size INT8 NOT NULL, publish BOOL NOT NULL, - created_at timestamp NOT NULL, - updated_at timestamp NOT NULL + created_at timestamp NOT NULL default now(), + updated_at timestamp NOT NULL default now() ); CREATE INDEX idx_media_resources_media_id ON media_resources (media_id); @@ -101,7 +102,7 @@ CREATE TABLE user_medias ( tenant_id INT8 NOT NULL, media_id INT8 NOT NULL, price INT8 NOT NULL, - created_at timestamp NOT NULL + created_at timestamp NOT NULL default now() ); CREATE INDEX idx_user_medias_user_id ON user_medias (user_id); diff --git a/backend/fixtures/db.go b/backend/fixtures/db.go new file mode 100644 index 0000000..4c1c825 --- /dev/null +++ b/backend/fixtures/db.go @@ -0,0 +1,26 @@ +package fixtures + +import ( + "database/sql" + + // . "github.com/go-jet/jet/v2/postgres" + _ "github.com/lib/pq" + "github.com/sirupsen/logrus" +) + +func GetDB() (*sql.DB, error) { + logrus.SetLevel(logrus.DebugLevel) + + dsn := "postgres://postgres:xixi0202@10.1.1.3:5432/qvyun?sslmode=disable" + db, err := sql.Open("postgres", dsn) + if err != nil { + return nil, err + } + + err = db.Ping() + if err != nil { + return nil, err + } + + return db, nil +} diff --git a/backend/modules/users/service.go b/backend/modules/users/service.go index ca3fa44..3c7a209 100644 --- a/backend/modules/users/service.go +++ b/backend/modules/users/service.go @@ -12,6 +12,7 @@ import ( "backend/pkg/pg" . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/qrm" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -27,6 +28,23 @@ func (svc *Service) Prepare() error { return nil } +// GetByID +func (svc *Service) GetByID(ctx context.Context, id int64) (*model.Users, error) { + tbl := table.Users + stmt := tbl. + SELECT(tbl.AllColumns). + WHERE( + tbl.ID.EQ(Int64(id)), + ) + svc.log.WithField("method", "GetByID").Debug(stmt.DebugSql()) + + var item model.Users + if err := stmt.QueryContext(ctx, svc.db, &item); err != nil { + return nil, errors.Wrap(err, "failed to query user by id") + } + return &item, nil +} + // GetByOpenID func (svc *Service) GetByOpenID(ctx context.Context, openid string) (*model.Users, error) { tbl := table.Users @@ -51,14 +69,28 @@ func (svc *Service) GetOrNew(ctx context.Context, tenantID int64, openid string, user, err := svc.GetByOpenID(ctx, openid) if err == nil { // check: if tenant has user + hasUser, err := svc.TenantHasUser(ctx, user.ID, tenantID) + if err != nil { + return nil, errors.Wrap(err, "failed to check user-tenant relation") + } + + if !hasUser { + // create user-tenant relation + if err := svc.CreateTenantUser(ctx, user.ID, tenantID); err != nil { + return nil, errors.Wrap(err, "failed to create user-tenant relation") + } + } + return user, nil } - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, qrm.ErrNoRows) { user = &model.Users{ - OpenID: openid, - OAuth: authInfo, - ExpireIn: time.Now().Add(time.Minute * time.Duration(authInfo.ExpiresIn)), + OpenID: openid, + OAuth: authInfo, + ExpireIn: time.Now().Add(time.Minute * time.Duration(authInfo.ExpiresIn)), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), } tx, err := svc.db.BeginTx(ctx, nil) @@ -93,17 +125,16 @@ func (svc *Service) GetOrNew(ctx context.Context, tenantID int64, openid string, func (svc *Service) CreateFromModel(ctx context.Context, user *model.Users) (*model.Users, error) { log := svc.log.WithField("method", "CreateFromModel") - stmt := table.Users.INSERT().MODEL(user).RETURNING(table.Users.AllColumns) + tbl := table.Users + stmt := tbl.INSERT(tbl.AllColumns.Except(tbl.ID)).MODEL(user).RETURNING(tbl.AllColumns) log.Debug(stmt.DebugSql()) - // get tx from context - - var item model.Users - if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil { + var userModel model.Users + err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &userModel) + if err != nil { return nil, errors.Wrap(err, "failed to create user") } - - return &item, nil + return &userModel, nil } // GetTenantByID @@ -152,7 +183,10 @@ func (svc *Service) CreateTenantUser(ctx context.Context, userID, tenantID int64 ).VALUES( Int64(userID), Int64(tenantID), - ) + ).ON_CONFLICT( + table.UsersTenants.UserID, + table.UsersTenants.TenantID, + ).DO_NOTHING() log.Debug(stmt.DebugSql()) if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil { diff --git a/backend/modules/users/service_test.go b/backend/modules/users/service_test.go new file mode 100644 index 0000000..c2b7699 --- /dev/null +++ b/backend/modules/users/service_test.go @@ -0,0 +1,67 @@ +package users + +import ( + "context" + "testing" + + "backend/fixtures" + dbUtil "backend/pkg/db" + "backend/pkg/pg" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestService_GetOrNew(t *testing.T) { + FocusConvey("Test GetOrNew", t, func() { + // So(dbUtil.TruncateAllTables(context.TODO(), db, "users", "users_tenants"), ShouldBeNil) + db, err := fixtures.GetDB() + So(err, ShouldBeNil) + defer db.Close() + + Convey("Test GetOrNew", func() { + svc := &Service{db: db} + So(svc.Prepare(), ShouldBeNil) + + user, err := svc.GetByOpenID(context.Background(), "hello") + So(err, ShouldBeNil) + + So(user, ShouldNotBeNil) + So(user.OpenID, ShouldEqual, "hello") + }) + + FocusConvey("Test GetOrNew", func() { + svc := &Service{db: db} + So(svc.Prepare(), ShouldBeNil) + + openid := "test_openid" + authInfo := pg.UserOAuth{ + AccessToken: "test_access_token", + } + + user, err := svc.GetOrNew(context.Background(), 1, openid, authInfo) + So(err, ShouldBeNil) + So(user.OpenID, ShouldEqual, openid) + }) + }) +} + +func TestService_CreateTenantUser(t *testing.T) { + FocusConvey("Test CreateTenantUser", t, func() { + db, err := fixtures.GetDB() + So(err, ShouldBeNil) + defer db.Close() + + So(dbUtil.TruncateAllTables(context.TODO(), db, "users", "users_tenants"), ShouldBeNil) + + FocusConvey("Test Create", func() { + svc := &Service{db: db} + So(svc.Prepare(), ShouldBeNil) + + err := svc.CreateTenantUser(context.Background(), 1, 1) + So(err, ShouldBeNil) + + err = svc.CreateTenantUser(context.Background(), 1, 1) + So(err, ShouldBeNil) + }) + }) +} diff --git a/backend/pkg/db/db.go b/backend/pkg/db/db.go index 1565aa2..7774f2e 100644 --- a/backend/pkg/db/db.go +++ b/backend/pkg/db/db.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "fmt" "backend/common/consts" @@ -15,3 +16,13 @@ func FromContext(ctx context.Context, db *sql.DB) qrm.DB { } return db } + +func TruncateAllTables(ctx context.Context, db *sql.DB, tableName ...string) error { + for _, name := range tableName { + sql := fmt.Sprintf("TRUNCATE TABLE %s CASCADE", name) + if _, err := db.ExecContext(ctx, sql); err != nil { + return err + } + } + return nil +} diff --git a/backend/pkg/pg/users.go b/backend/pkg/pg/users.go index 69aa98a..ffbb7f0 100644 --- a/backend/pkg/pg/users.go +++ b/backend/pkg/pg/users.go @@ -1,11 +1,33 @@ package pg +import ( + "database/sql/driver" + "encoding/json" + "errors" +) + type UserOAuth struct { - AccessToken string `json:"access_token"` - ExpiresIn int64 `json:"expires_in"` - IsSnapshotuser int64 `json:"is_snapshotuser"` - Openid string `json:"openid"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - Unionid string `json:"unionid"` + AccessToken string `json:"access_token,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + IsSnapshotuser int64 `json:"is_snapshotuser,omitempty"` + Openid string `json:"openid,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + Unionid string `json:"unionid,omitempty"` +} + +func (x UserOAuth) Scan(value interface{}) (err error) { + switch v := value.(type) { + case string: + return json.Unmarshal([]byte(v), &x) + case []byte: + return json.Unmarshal(v, &x) + case *string: + return json.Unmarshal([]byte(*v), &x) + } + return errors.New("Unknown type for ") +} + +func (x UserOAuth) Value() (driver.Value, error) { + return json.Marshal(x) } diff --git a/backend/providers/app b/backend/providers/app index 6ea394c..9de1e7b 160000 --- a/backend/providers/app +++ b/backend/providers/app @@ -1 +1 @@ -Subproject commit 6ea394c47455a3e828bf138fed1f4cac5e39eebf +Subproject commit 9de1e7b7b0acc99a9833bc289f133d1488ed0730