257 lines
6.5 KiB
Go
257 lines
6.5 KiB
Go
package users
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"time"
|
|
|
|
"backend/app/events/publishers"
|
|
"backend/database/fields"
|
|
"backend/database/models/qvyun_v2/public/model"
|
|
"backend/database/models/qvyun_v2/public/table"
|
|
"backend/pkg/oauth"
|
|
"backend/providers/event"
|
|
"backend/providers/otel"
|
|
|
|
. "github.com/go-jet/jet/v2/postgres"
|
|
"github.com/pkg/errors"
|
|
"github.com/samber/lo"
|
|
log "github.com/sirupsen/logrus"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
semconv "go.opentelemetry.io/otel/semconv/v1.15.0"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// @provider:except
|
|
type Service struct {
|
|
db *sql.DB
|
|
event *event.PubSub
|
|
log *log.Entry `inject:"false"`
|
|
}
|
|
|
|
func (svc *Service) Prepare() error {
|
|
svc.log = log.WithField("module", "users.service")
|
|
_ = Int(1)
|
|
return nil
|
|
}
|
|
|
|
// GetUsersByOpenID Get user by open id
|
|
func (svc *Service) GetUserByOpenIDOfChannel(ctx context.Context, channel fields.AuthChannel, openID string) (*model.Users, error) {
|
|
_, span := otel.Start(ctx, "users.service.GetUsersByOpenID")
|
|
defer span.End()
|
|
|
|
userId, err := svc.GetUserIDByOpenID(ctx, channel, openID)
|
|
if err != nil {
|
|
// span 添加用户不存在事件
|
|
span.AddEvent("user not found")
|
|
return nil, err
|
|
}
|
|
|
|
tbl := table.Users
|
|
stmt := tbl.SELECT(tbl.AllColumns).WHERE(tbl.ID.EQ(Int64(userId)))
|
|
span.SetAttributes(semconv.DBStatementKey.String(stmt.DebugSql()))
|
|
|
|
var user model.Users
|
|
if err := stmt.QueryContext(ctx, svc.db, &user); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (svc *Service) GetUserIDByOpenID(ctx context.Context, channel fields.AuthChannel, openID string) (int64, error) {
|
|
_, span := otel.Start(ctx, "users.service.GetUserIDByOpenID")
|
|
defer span.End()
|
|
|
|
span.SetAttributes(
|
|
attribute.String("channel", channel.String()),
|
|
attribute.String("openID", openID),
|
|
)
|
|
|
|
tbl := table.UserOauths
|
|
|
|
stmt := tbl.
|
|
SELECT(tbl.UserID.AS("user_id")).
|
|
WHERE(
|
|
tbl.Channel.EQ(Int16(int16(channel))).
|
|
AND(tbl.OpenID.EQ(String(openID))),
|
|
)
|
|
span.SetAttributes(semconv.DBStatementKey.String(stmt.DebugSql()))
|
|
|
|
var result struct {
|
|
UserID int64
|
|
}
|
|
|
|
if err := stmt.QueryContext(ctx, svc.db, &result); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return result.UserID, nil
|
|
}
|
|
|
|
// CreateUser
|
|
func (svc *Service) CreateUser(ctx context.Context, user *model.Users) (*model.Users, error) {
|
|
_, span := otel.Start(ctx, "users.service.CreateUser")
|
|
defer span.End()
|
|
span.SetAttributes(
|
|
attribute.String("user.username", user.Username),
|
|
attribute.String("user.email", user.Email),
|
|
attribute.String("user.phone", user.Phone),
|
|
)
|
|
|
|
if user.CreatedAt.IsZero() {
|
|
user.CreatedAt = time.Now()
|
|
}
|
|
|
|
if user.UpdatedAt.IsZero() {
|
|
user.UpdatedAt = time.Now()
|
|
}
|
|
|
|
user.Status = fields.UserStatusPending
|
|
|
|
// use bcrypt to hash password
|
|
pwd, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
user.Password = string(pwd)
|
|
|
|
tbl := table.Users
|
|
stmt := tbl.INSERT(tbl.MutableColumns).MODEL(user).ON_CONFLICT(tbl.Email, tbl.Phone, tbl.Username).DO_NOTHING().RETURNING(tbl.AllColumns)
|
|
span.SetAttributes(semconv.DBStatementKey.String(stmt.DebugSql()))
|
|
|
|
var m model.Users
|
|
if err = stmt.QueryContext(ctx, svc.db, &m); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// if user created successfully, trigger event
|
|
span.AddEvent("user created")
|
|
if err := svc.event.Publish(&publishers.UserRegister{ID: m.ID}); err != nil {
|
|
return nil, errors.Wrapf(err, "failed to publish user register event %d", m.ID)
|
|
}
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
// GetUserByID
|
|
func (svc *Service) GetUserByID(ctx context.Context, userID int64) (*model.Users, error) {
|
|
_, span := otel.Start(ctx, "users.service.GetUserByID")
|
|
defer span.End()
|
|
span.SetAttributes(
|
|
attribute.Int64("user.id", userID),
|
|
)
|
|
|
|
tbl := table.Users
|
|
stmt := tbl.SELECT(tbl.AllColumns).WHERE(tbl.ID.EQ(Int64(userID)))
|
|
span.SetAttributes(semconv.DBStatementKey.String(stmt.DebugSql()))
|
|
|
|
var user model.Users
|
|
if err := stmt.QueryContext(ctx, svc.db, &user); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// AttachUserOAuth
|
|
func (svc *Service) AttachUserOAuth(ctx context.Context, user *model.Users, channel fields.AuthChannel, oauthInfo oauth.OAuthInfo) error {
|
|
_, span := otel.Start(ctx, "users.service.AttachUserOAuth")
|
|
defer span.End()
|
|
span.SetAttributes(
|
|
attribute.Int64("user.id", user.ID),
|
|
attribute.String("channel", channel.String()),
|
|
attribute.String("openID", oauthInfo.GetOpenID()),
|
|
)
|
|
|
|
m := &model.UserOauths{
|
|
ID: 0,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
DeletedAt: nil,
|
|
Channel: channel,
|
|
UserID: user.ID,
|
|
UnionID: lo.ToPtr(oauthInfo.GetUnionID()),
|
|
OpenID: oauthInfo.GetOpenID(),
|
|
AccessToken: oauthInfo.GetAccessToken(),
|
|
RefreshToken: oauthInfo.GetRefreshToken(),
|
|
ExpireAt: oauthInfo.GetExpiredAt(),
|
|
Meta: new(string),
|
|
}
|
|
|
|
tbl := table.UserOauths
|
|
stmt := tbl.
|
|
INSERT(tbl.MutableColumns).
|
|
MODEL(m).
|
|
ON_CONFLICT(tbl.Channel, tbl.UserID).
|
|
DO_UPDATE(
|
|
SET(
|
|
tbl.UnionID.SET(String(oauthInfo.GetUnionID())),
|
|
tbl.OpenID.SET(String(oauthInfo.GetOpenID())),
|
|
tbl.AccessToken.SET(String(oauthInfo.GetAccessToken())),
|
|
tbl.RefreshToken.SET(String(oauthInfo.GetRefreshToken())),
|
|
tbl.ExpireAt.SET(TimestampT(oauthInfo.GetExpiredAt())),
|
|
),
|
|
)
|
|
|
|
span.SetAttributes(semconv.DBStatementKey.String(stmt.DebugSql()))
|
|
|
|
if _, err := stmt.ExecContext(ctx, svc.db); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetUserStatusByID
|
|
func (svc *Service) SetUserStatusByID(ctx context.Context, userID int64, status fields.UserStatus) error {
|
|
_, span := otel.Start(ctx, "users.service.SetUserStatusByID")
|
|
defer span.End()
|
|
span.SetAttributes(
|
|
attribute.Int64("user.id", userID),
|
|
attribute.String("user.status", status.String()),
|
|
)
|
|
|
|
tbl := table.Users
|
|
stmt := tbl.
|
|
UPDATE().
|
|
SET(
|
|
tbl.Status.SET(Int16(int16(status))),
|
|
).
|
|
WHERE(
|
|
tbl.ID.EQ(Int64(userID)),
|
|
)
|
|
span.SetAttributes(semconv.DBStatementKey.String(stmt.DebugSql()))
|
|
|
|
if _, err := stmt.ExecContext(ctx, svc.db); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetUserOAuthChannels
|
|
func (svc *Service) GetUserOAuthChannels(ctx context.Context, userID int64) ([]model.UserOauths, error) {
|
|
_, span := otel.Start(ctx, "users.service.GetUserOAuthChannels")
|
|
defer span.End()
|
|
span.SetAttributes(
|
|
attribute.Int64("user.id", userID),
|
|
)
|
|
|
|
tbl := table.UserOauths
|
|
stmt := tbl.
|
|
SELECT(
|
|
tbl.Channel,
|
|
tbl.ExpireAt,
|
|
).
|
|
WHERE(
|
|
tbl.UserID.EQ(Int64(userID)),
|
|
)
|
|
span.SetAttributes(semconv.DBStatementKey.String(stmt.DebugSql()))
|
|
|
|
var oauths []model.UserOauths
|
|
if err := stmt.QueryContext(ctx, svc.db, &oauths); err != nil {
|
|
return nil, err
|
|
}
|
|
return oauths, nil
|
|
}
|