feat: complete login
This commit is contained in:
112
backend/app/http/auth.go
Normal file
112
backend/app/http/auth.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"quyun/app/models"
|
||||
"quyun/database/fields"
|
||||
"quyun/database/schemas/public/model"
|
||||
"quyun/providers/jwt"
|
||||
"quyun/providers/wechat"
|
||||
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
StatePrefix = "sns_basic_auth"
|
||||
salt = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
)
|
||||
|
||||
// @provider
|
||||
type auth struct {
|
||||
wechat *wechat.Client
|
||||
jwt *jwt.JWT
|
||||
}
|
||||
|
||||
// @Router /auth/login [get]
|
||||
// @Bind code query
|
||||
// @Bind state query
|
||||
// @Bind redirect query
|
||||
func (ctl *auth) Login(ctx fiber.Ctx, code, state, redirect string) error {
|
||||
log.Debugf("code: %s, state: %s", code, state)
|
||||
|
||||
// get the openid
|
||||
token, err := ctl.wechat.AuthorizeCode2Token(code)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get openid")
|
||||
}
|
||||
log.Debugf("tokenInfo %+v", token)
|
||||
|
||||
authUserInfo, err := ctl.wechat.AuthorizeUserInfo(token.AccessToken, token.Openid)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get user info")
|
||||
}
|
||||
|
||||
log.Debugf("Auth User Info: %+v", authUserInfo)
|
||||
|
||||
user, err := models.Users.GetUserByOpenID(ctx.Context(), token.Openid)
|
||||
if err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
// Create User
|
||||
model := &model.Users{
|
||||
Status: fields.UserStatusOk,
|
||||
OpenID: token.GetOpenID(),
|
||||
Username: fmt.Sprintf("u_%s", gonanoid.MustGenerate(salt, 8)),
|
||||
Avatar: nil,
|
||||
}
|
||||
if err := models.Users.Create(ctx.Context(), model); err != nil {
|
||||
return errors.Wrap(err, "failed to create user")
|
||||
}
|
||||
} else {
|
||||
return errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
}
|
||||
|
||||
jwtToken, err := ctl.jwt.CreateToken(ctl.jwt.CreateClaims(jwt.BaseClaims{UserID: user.ID}))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create token")
|
||||
}
|
||||
|
||||
ctx.Cookie(&fiber.Cookie{
|
||||
Name: "token",
|
||||
Value: jwtToken,
|
||||
Expires: time.Now().Add(6 * time.Hour),
|
||||
HTTPOnly: true,
|
||||
})
|
||||
|
||||
return ctx.Redirect().To(redirect)
|
||||
}
|
||||
|
||||
// @Router /auth/wechat [get]
|
||||
// @Bind redirect query
|
||||
func (ctl *auth) Wechat(ctx fiber.Ctx, redirect string) error {
|
||||
log.Debugf("%s, query: %v", ctx.OriginalURL(), ctx.Queries())
|
||||
|
||||
// 添加 redirect 参数
|
||||
u, err := url.Parse(string(ctx.Request().URI().FullURI()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query := u.Query()
|
||||
query.Set("redirect", redirect)
|
||||
u.RawQuery = query.Encode()
|
||||
u.Path = "/auth/login"
|
||||
fullUrl := u.String()
|
||||
|
||||
log.Debug("redirect_uri: ", fullUrl)
|
||||
|
||||
to, err := ctl.wechat.ScopeAuthorizeURL(
|
||||
wechat.ScopeAuthorizeURLWithRedirectURI(fullUrl),
|
||||
)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get wechat auth url")
|
||||
}
|
||||
|
||||
return ctx.Redirect().To(to.String())
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// @provider
|
||||
type pays struct {
|
||||
wepay *wepay.Client
|
||||
job *job.Job
|
||||
|
||||
@@ -27,7 +27,8 @@ type posts struct {
|
||||
// @Router /posts [get]
|
||||
// @Bind pagination query
|
||||
// @Bind query query
|
||||
func (ctl *posts) List(ctx fiber.Ctx, pagination *requests.Pagination, query *ListQuery) (*requests.Pager, error) {
|
||||
// @Bind user local
|
||||
func (ctl *posts) List(ctx fiber.Ctx, pagination *requests.Pagination, query *ListQuery, user *model.Users) (*requests.Pager, error) {
|
||||
cond := models.Posts.BuildConditionWithKey(query.Keyword)
|
||||
return models.Posts.List(ctx.Context(), pagination, cond)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"quyun/providers/job"
|
||||
"quyun/providers/jwt"
|
||||
"quyun/providers/wechat"
|
||||
"quyun/providers/wepay"
|
||||
|
||||
"go.ipao.vip/atom"
|
||||
@@ -10,6 +13,32 @@ import (
|
||||
)
|
||||
|
||||
func Provide(opts ...opt.Option) error {
|
||||
if err := container.Container.Provide(func(
|
||||
jwt *jwt.JWT,
|
||||
wechat *wechat.Client,
|
||||
) (*auth, error) {
|
||||
obj := &auth{
|
||||
jwt: jwt,
|
||||
wechat: wechat,
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := container.Container.Provide(func(
|
||||
job *job.Job,
|
||||
wepay *wepay.Client,
|
||||
) (*pays, error) {
|
||||
obj := &pays{
|
||||
job: job,
|
||||
wepay: wepay,
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := container.Container.Provide(func(
|
||||
wepay *wepay.Client,
|
||||
) (*posts, error) {
|
||||
@@ -22,10 +51,12 @@ func Provide(opts ...opt.Option) error {
|
||||
return err
|
||||
}
|
||||
if err := container.Container.Provide(func(
|
||||
auth *auth,
|
||||
pays *pays,
|
||||
posts *posts,
|
||||
) (contracts.HttpRoute, error) {
|
||||
obj := &Routes{
|
||||
auth: auth,
|
||||
pays: pays,
|
||||
posts: posts,
|
||||
}
|
||||
|
||||
@@ -9,11 +9,13 @@ import (
|
||||
_ "go.ipao.vip/atom/contracts"
|
||||
. "go.ipao.vip/atom/fen"
|
||||
"quyun/app/requests"
|
||||
"quyun/database/schemas/public/model"
|
||||
)
|
||||
|
||||
// @provider contracts.HttpRoute atom.GroupRoutes
|
||||
type Routes struct {
|
||||
log *log.Entry `inject:"false"`
|
||||
auth *auth
|
||||
pays *pays
|
||||
posts *posts
|
||||
}
|
||||
@@ -28,6 +30,19 @@ func (r *Routes) Name() string {
|
||||
}
|
||||
|
||||
func (r *Routes) Register(router fiber.Router) {
|
||||
// 注册路由组: auth
|
||||
router.Get("/auth/login", Func3(
|
||||
r.auth.Login,
|
||||
QueryParam[string]("code"),
|
||||
QueryParam[string]("state"),
|
||||
QueryParam[string]("redirect"),
|
||||
))
|
||||
|
||||
router.Get("/auth/wechat", Func1(
|
||||
r.auth.Wechat,
|
||||
QueryParam[string]("redirect"),
|
||||
))
|
||||
|
||||
// 注册路由组: pays
|
||||
router.Get("/pay/callback/:channel", Func1(
|
||||
r.pays.Callback,
|
||||
@@ -35,10 +50,11 @@ func (r *Routes) Register(router fiber.Router) {
|
||||
))
|
||||
|
||||
// 注册路由组: posts
|
||||
router.Get("/posts", DataFunc2(
|
||||
router.Get("/posts", DataFunc3(
|
||||
r.posts.List,
|
||||
Query[requests.Pagination]("pagination"),
|
||||
Query[ListQuery]("query"),
|
||||
Local[*model.Users]("user"),
|
||||
))
|
||||
|
||||
router.Get("/show/:id", DataFunc1(
|
||||
|
||||
Reference in New Issue
Block a user