feat: uniq tenant_users
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"backend/pkg/pg"
|
||||
|
||||
. "github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -27,6 +28,23 @@ func (svc *Service) Prepare() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID
|
||||
func (svc *Service) GetByID(ctx context.Context, id int64) (*model.Users, error) {
|
||||
tbl := table.Users
|
||||
stmt := tbl.
|
||||
SELECT(tbl.AllColumns).
|
||||
WHERE(
|
||||
tbl.ID.EQ(Int64(id)),
|
||||
)
|
||||
svc.log.WithField("method", "GetByID").Debug(stmt.DebugSql())
|
||||
|
||||
var item model.Users
|
||||
if err := stmt.QueryContext(ctx, svc.db, &item); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to query user by id")
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
// GetByOpenID
|
||||
func (svc *Service) GetByOpenID(ctx context.Context, openid string) (*model.Users, error) {
|
||||
tbl := table.Users
|
||||
@@ -51,14 +69,28 @@ func (svc *Service) GetOrNew(ctx context.Context, tenantID int64, openid string,
|
||||
user, err := svc.GetByOpenID(ctx, openid)
|
||||
if err == nil {
|
||||
// check: if tenant has user
|
||||
hasUser, err := svc.TenantHasUser(ctx, user.ID, tenantID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to check user-tenant relation")
|
||||
}
|
||||
|
||||
if !hasUser {
|
||||
// create user-tenant relation
|
||||
if err := svc.CreateTenantUser(ctx, user.ID, tenantID); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create user-tenant relation")
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
user = &model.Users{
|
||||
OpenID: openid,
|
||||
OAuth: authInfo,
|
||||
ExpireIn: time.Now().Add(time.Minute * time.Duration(authInfo.ExpiresIn)),
|
||||
OpenID: openid,
|
||||
OAuth: authInfo,
|
||||
ExpireIn: time.Now().Add(time.Minute * time.Duration(authInfo.ExpiresIn)),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
tx, err := svc.db.BeginTx(ctx, nil)
|
||||
@@ -93,17 +125,16 @@ func (svc *Service) GetOrNew(ctx context.Context, tenantID int64, openid string,
|
||||
func (svc *Service) CreateFromModel(ctx context.Context, user *model.Users) (*model.Users, error) {
|
||||
log := svc.log.WithField("method", "CreateFromModel")
|
||||
|
||||
stmt := table.Users.INSERT().MODEL(user).RETURNING(table.Users.AllColumns)
|
||||
tbl := table.Users
|
||||
stmt := tbl.INSERT(tbl.AllColumns.Except(tbl.ID)).MODEL(user).RETURNING(tbl.AllColumns)
|
||||
log.Debug(stmt.DebugSql())
|
||||
|
||||
// get tx from context
|
||||
|
||||
var item model.Users
|
||||
if err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &item); err != nil {
|
||||
var userModel model.Users
|
||||
err := stmt.QueryContext(ctx, db.FromContext(ctx, svc.db), &userModel)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create user")
|
||||
}
|
||||
|
||||
return &item, nil
|
||||
return &userModel, nil
|
||||
}
|
||||
|
||||
// GetTenantByID
|
||||
@@ -152,7 +183,10 @@ func (svc *Service) CreateTenantUser(ctx context.Context, userID, tenantID int64
|
||||
).VALUES(
|
||||
Int64(userID),
|
||||
Int64(tenantID),
|
||||
)
|
||||
).ON_CONFLICT(
|
||||
table.UsersTenants.UserID,
|
||||
table.UsersTenants.TenantID,
|
||||
).DO_NOTHING()
|
||||
log.Debug(stmt.DebugSql())
|
||||
|
||||
if _, err := stmt.ExecContext(ctx, db.FromContext(ctx, svc.db)); err != nil {
|
||||
|
||||
67
backend/modules/users/service_test.go
Normal file
67
backend/modules/users/service_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"backend/fixtures"
|
||||
dbUtil "backend/pkg/db"
|
||||
"backend/pkg/pg"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestService_GetOrNew(t *testing.T) {
|
||||
FocusConvey("Test GetOrNew", t, func() {
|
||||
// So(dbUtil.TruncateAllTables(context.TODO(), db, "users", "users_tenants"), ShouldBeNil)
|
||||
db, err := fixtures.GetDB()
|
||||
So(err, ShouldBeNil)
|
||||
defer db.Close()
|
||||
|
||||
Convey("Test GetOrNew", func() {
|
||||
svc := &Service{db: db}
|
||||
So(svc.Prepare(), ShouldBeNil)
|
||||
|
||||
user, err := svc.GetByOpenID(context.Background(), "hello")
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
So(user, ShouldNotBeNil)
|
||||
So(user.OpenID, ShouldEqual, "hello")
|
||||
})
|
||||
|
||||
FocusConvey("Test GetOrNew", func() {
|
||||
svc := &Service{db: db}
|
||||
So(svc.Prepare(), ShouldBeNil)
|
||||
|
||||
openid := "test_openid"
|
||||
authInfo := pg.UserOAuth{
|
||||
AccessToken: "test_access_token",
|
||||
}
|
||||
|
||||
user, err := svc.GetOrNew(context.Background(), 1, openid, authInfo)
|
||||
So(err, ShouldBeNil)
|
||||
So(user.OpenID, ShouldEqual, openid)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_CreateTenantUser(t *testing.T) {
|
||||
FocusConvey("Test CreateTenantUser", t, func() {
|
||||
db, err := fixtures.GetDB()
|
||||
So(err, ShouldBeNil)
|
||||
defer db.Close()
|
||||
|
||||
So(dbUtil.TruncateAllTables(context.TODO(), db, "users", "users_tenants"), ShouldBeNil)
|
||||
|
||||
FocusConvey("Test Create", func() {
|
||||
svc := &Service{db: db}
|
||||
So(svc.Prepare(), ShouldBeNil)
|
||||
|
||||
err := svc.CreateTenantUser(context.Background(), 1, 1)
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
err = svc.CreateTenantUser(context.Background(), 1, 1)
|
||||
So(err, ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user