feat: complete login

This commit is contained in:
yanghao05
2025-04-15 21:20:04 +08:00
parent 45a0b6848a
commit ca08568e1a
23 changed files with 842 additions and 28 deletions

112
backend/app/http/auth.go Normal file
View File

@@ -0,0 +1,112 @@
package http
import (
"fmt"
"net/url"
"time"
"quyun/app/models"
"quyun/database/fields"
"quyun/database/schemas/public/model"
"quyun/providers/jwt"
"quyun/providers/wechat"
"github.com/go-jet/jet/v2/qrm"
"github.com/gofiber/fiber/v3"
gonanoid "github.com/matoous/go-nanoid/v2"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
const (
StatePrefix = "sns_basic_auth"
salt = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
// @provider
type auth struct {
wechat *wechat.Client
jwt *jwt.JWT
}
// @Router /auth/login [get]
// @Bind code query
// @Bind state query
// @Bind redirect query
func (ctl *auth) Login(ctx fiber.Ctx, code, state, redirect string) error {
log.Debugf("code: %s, state: %s", code, state)
// get the openid
token, err := ctl.wechat.AuthorizeCode2Token(code)
if err != nil {
return errors.Wrap(err, "failed to get openid")
}
log.Debugf("tokenInfo %+v", token)
authUserInfo, err := ctl.wechat.AuthorizeUserInfo(token.AccessToken, token.Openid)
if err != nil {
return errors.Wrap(err, "failed to get user info")
}
log.Debugf("Auth User Info: %+v", authUserInfo)
user, err := models.Users.GetUserByOpenID(ctx.Context(), token.Openid)
if err != nil {
if errors.Is(err, qrm.ErrNoRows) {
// Create User
model := &model.Users{
Status: fields.UserStatusOk,
OpenID: token.GetOpenID(),
Username: fmt.Sprintf("u_%s", gonanoid.MustGenerate(salt, 8)),
Avatar: nil,
}
if err := models.Users.Create(ctx.Context(), model); err != nil {
return errors.Wrap(err, "failed to create user")
}
} else {
return errors.Wrap(err, "failed to get user")
}
}
jwtToken, err := ctl.jwt.CreateToken(ctl.jwt.CreateClaims(jwt.BaseClaims{UserID: user.ID}))
if err != nil {
return errors.Wrap(err, "failed to create token")
}
ctx.Cookie(&fiber.Cookie{
Name: "token",
Value: jwtToken,
Expires: time.Now().Add(6 * time.Hour),
HTTPOnly: true,
})
return ctx.Redirect().To(redirect)
}
// @Router /auth/wechat [get]
// @Bind redirect query
func (ctl *auth) Wechat(ctx fiber.Ctx, redirect string) error {
log.Debugf("%s, query: %v", ctx.OriginalURL(), ctx.Queries())
// 添加 redirect 参数
u, err := url.Parse(string(ctx.Request().URI().FullURI()))
if err != nil {
return err
}
query := u.Query()
query.Set("redirect", redirect)
u.RawQuery = query.Encode()
u.Path = "/auth/login"
fullUrl := u.String()
log.Debug("redirect_uri: ", fullUrl)
to, err := ctl.wechat.ScopeAuthorizeURL(
wechat.ScopeAuthorizeURLWithRedirectURI(fullUrl),
)
if err != nil {
return errors.Wrap(err, "failed to get wechat auth url")
}
return ctx.Redirect().To(to.String())
}

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
)
// @provider
type pays struct {
wepay *wepay.Client
job *job.Job

View File

@@ -27,7 +27,8 @@ type posts struct {
// @Router /posts [get]
// @Bind pagination query
// @Bind query query
func (ctl *posts) List(ctx fiber.Ctx, pagination *requests.Pagination, query *ListQuery) (*requests.Pager, error) {
// @Bind user local
func (ctl *posts) List(ctx fiber.Ctx, pagination *requests.Pagination, query *ListQuery, user *model.Users) (*requests.Pager, error) {
cond := models.Posts.BuildConditionWithKey(query.Keyword)
return models.Posts.List(ctx.Context(), pagination, cond)
}

View File

@@ -1,6 +1,9 @@
package http
import (
"quyun/providers/job"
"quyun/providers/jwt"
"quyun/providers/wechat"
"quyun/providers/wepay"
"go.ipao.vip/atom"
@@ -10,6 +13,32 @@ import (
)
func Provide(opts ...opt.Option) error {
if err := container.Container.Provide(func(
jwt *jwt.JWT,
wechat *wechat.Client,
) (*auth, error) {
obj := &auth{
jwt: jwt,
wechat: wechat,
}
return obj, nil
}); err != nil {
return err
}
if err := container.Container.Provide(func(
job *job.Job,
wepay *wepay.Client,
) (*pays, error) {
obj := &pays{
job: job,
wepay: wepay,
}
return obj, nil
}); err != nil {
return err
}
if err := container.Container.Provide(func(
wepay *wepay.Client,
) (*posts, error) {
@@ -22,10 +51,12 @@ func Provide(opts ...opt.Option) error {
return err
}
if err := container.Container.Provide(func(
auth *auth,
pays *pays,
posts *posts,
) (contracts.HttpRoute, error) {
obj := &Routes{
auth: auth,
pays: pays,
posts: posts,
}

View File

@@ -9,11 +9,13 @@ import (
_ "go.ipao.vip/atom/contracts"
. "go.ipao.vip/atom/fen"
"quyun/app/requests"
"quyun/database/schemas/public/model"
)
// @provider contracts.HttpRoute atom.GroupRoutes
type Routes struct {
log *log.Entry `inject:"false"`
auth *auth
pays *pays
posts *posts
}
@@ -28,6 +30,19 @@ func (r *Routes) Name() string {
}
func (r *Routes) Register(router fiber.Router) {
// 注册路由组: auth
router.Get("/auth/login", Func3(
r.auth.Login,
QueryParam[string]("code"),
QueryParam[string]("state"),
QueryParam[string]("redirect"),
))
router.Get("/auth/wechat", Func1(
r.auth.Wechat,
QueryParam[string]("redirect"),
))
// 注册路由组: pays
router.Get("/pay/callback/:channel", Func1(
r.pays.Callback,
@@ -35,10 +50,11 @@ func (r *Routes) Register(router fiber.Router) {
))
// 注册路由组: posts
router.Get("/posts", DataFunc2(
router.Get("/posts", DataFunc3(
r.posts.List,
Query[requests.Pagination]("pagination"),
Query[ListQuery]("query"),
Local[*model.Users]("user"),
))
router.Get("/show/:id", DataFunc1(

View File

@@ -0,0 +1,57 @@
package middlewares
import (
"net/url"
"strings"
"quyun/app/models"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
)
func (f *Middlewares) Auth(ctx fiber.Ctx) error {
if strings.HasPrefix(ctx.Path(), "/admin/") {
return ctx.Next()
}
if strings.HasPrefix(ctx.Path(), "/auth/") {
return ctx.Next()
}
fullUrl := string(ctx.Request().URI().FullURI())
u, err := url.Parse(fullUrl)
if err != nil {
return err
}
query := u.Query()
query.Set("redirect", fullUrl)
u.RawQuery = query.Encode()
u.Path = "/auth/wechat"
fullUrl = u.String()
// check cookie exists
cookie := ctx.Cookies("token")
log.Infof("cookie: %s", cookie)
if cookie == "" {
log.Infof("auth redirect_uri: %s", fullUrl)
return ctx.Redirect().To(fullUrl)
}
jwt, err := f.jwt.Parse(cookie)
if err != nil {
// remove cookie
ctx.ClearCookie("token")
return ctx.Redirect().To(fullUrl)
}
user, err := models.Users.GetByID(ctx.Context(), jwt.UserID)
if err != nil {
// remove cookie
ctx.ClearCookie("token")
return ctx.Redirect().To(fullUrl)
}
ctx.Locals("user", user)
return ctx.Next()
}

View File

@@ -2,8 +2,13 @@ package middlewares
import (
"github.com/gofiber/fiber/v3"
log "github.com/sirupsen/logrus"
)
func (f *Middlewares) DebugMode(c fiber.Ctx) error {
return c.Next()
func (f *Middlewares) DebugMode(ctx fiber.Ctx) error {
log.Infof("c.Path: %s", ctx.Path())
log.Infof("Request Method: %s", ctx.Method())
log.Infof("FullURL: %s", ctx.Request().URI().FullURI())
return ctx.Next()
}

View File

@@ -1,12 +1,15 @@
package middlewares
import (
"quyun/providers/jwt"
log "github.com/sirupsen/logrus"
)
// @provider
type Middlewares struct {
log *log.Entry `inject:"false"`
jwt *jwt.JWT
}
func (f *Middlewares) Prepare() error {

View File

@@ -1,13 +1,19 @@
package middlewares
import (
"quyun/providers/jwt"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
func Provide(opts ...opt.Option) error {
if err := container.Container.Provide(func() (*Middlewares, error) {
obj := &Middlewares{}
if err := container.Container.Provide(func(
jwt *jwt.JWT,
) (*Middlewares, error) {
obj := &Middlewares{
jwt: jwt,
}
if err := obj.Prepare(); err != nil {
return nil, err
}

View File

@@ -227,3 +227,23 @@ func (m *usersModel) PostList(ctx context.Context, userId int64, pagination *req
Pagination: *pagination,
}, nil
}
// GetUserIDByOpenID
func (m *usersModel) GetUserByOpenID(ctx context.Context, openID string) (*model.Users, error) {
tbl := table.Users
stmt := tbl.
SELECT(tbl.AllColumns).
WHERE(
tbl.OpenID.EQ(String(openID)),
)
m.log.Infof("sql: %s", stmt.DebugSql())
var user model.Users
if err := stmt.QueryContext(ctx, db, &user); err != nil {
m.log.Errorf("error querying user by OpenID: %v", err)
return nil, err
}
return &user, nil
}

View File

@@ -6,6 +6,7 @@ import (
"quyun/app/errorx"
appHttp "quyun/app/http"
"quyun/app/jobs"
"quyun/app/middlewares"
"quyun/app/models"
"quyun/app/service"
_ "quyun/docs"
@@ -32,8 +33,8 @@ import (
func defaultProviders() container.Providers {
return service.Default(container.Providers{
wechat.DefaultProvider(),
ali.DefaultProvider(),
wechat.DefaultProvider(),
wepay.DefaultProvider(),
http.DefaultProvider(),
postgres.DefaultProvider(),
@@ -53,6 +54,7 @@ func Command() atom.Option {
With(
jobs.Provide,
models.Provide,
middlewares.Provide,
).
WithProviders(
appHttp.Providers(),
@@ -66,10 +68,11 @@ type Service struct {
Initials []contracts.Initial `group:"initials"`
App *app.Config
Job *job.Job
Http *http.Service
Routes []contracts.HttpRoute `group:"routes"`
App *app.Config
Job *job.Job
Http *http.Service
Middlewares *middlewares.Middlewares
Routes []contracts.HttpRoute `group:"routes"`
}
func Serve(cmd *cobra.Command, args []string) error {
@@ -82,6 +85,9 @@ func Serve(cmd *cobra.Command, args []string) error {
svc.Http.Engine.Get("/swagger/*", swagger.HandlerDefault)
}
svc.Http.Engine.Use(errorx.Middleware)
svc.Http.Engine.Use(svc.Middlewares.DebugMode)
svc.Http.Engine.Use(svc.Middlewares.Auth)
svc.Http.Engine.Use(favicon.New(favicon.Config{
Data: []byte{},
}))