211 lines
4.9 KiB
Go
211 lines
4.9 KiB
Go
package middlewares
|
|
|
|
import (
|
|
"errors"
|
|
"strings"
|
|
|
|
"quyun/v2/app/errorx"
|
|
"quyun/v2/app/services"
|
|
"quyun/v2/database/models"
|
|
"quyun/v2/pkg/consts"
|
|
"quyun/v2/providers/jwt"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
log "github.com/sirupsen/logrus"
|
|
"go.ipao.vip/gen/types"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Middlewares provides reusable Fiber middlewares shared across modules.
|
|
//
|
|
// @provider
|
|
type Middlewares struct {
|
|
// log is the module logger injected by the framework.
|
|
log *log.Entry `inject:"false"`
|
|
// jwt is the JWT provider used by auth-related middlewares.
|
|
jwt *jwt.JWT
|
|
}
|
|
|
|
func (f *Middlewares) Prepare() error {
|
|
f.log = log.WithField("module", "middleware")
|
|
return nil
|
|
}
|
|
|
|
func (m *Middlewares) AuthOptional(ctx fiber.Ctx) error {
|
|
return m.authenticate(ctx, false)
|
|
}
|
|
|
|
func (m *Middlewares) AuthRequired(ctx fiber.Ctx) error {
|
|
if isPublicRoute(ctx) {
|
|
return m.AuthOptional(ctx)
|
|
}
|
|
return m.authenticate(ctx, true)
|
|
}
|
|
|
|
func (m *Middlewares) authenticate(ctx fiber.Ctx, requireToken bool) error {
|
|
authHeader := ctx.Get("Authorization")
|
|
if authHeader == "" {
|
|
if requireToken {
|
|
return errorx.ErrUnauthorized.WithMsg("Missing token")
|
|
}
|
|
return ctx.Next()
|
|
}
|
|
|
|
claims, err := m.jwt.Parse(authHeader)
|
|
if err != nil {
|
|
return errorx.ErrUnauthorized.WithCause(err).WithMsg("Invalid token")
|
|
}
|
|
|
|
// 获取用户信息,确保 token 与账号状态一致。
|
|
user, err := services.User.GetModelByID(ctx, claims.UserID)
|
|
if err != nil {
|
|
return errorx.ErrUnauthorized.WithCause(err).WithMsg("UserNotFound")
|
|
}
|
|
if user.Status == consts.UserStatusBanned {
|
|
return errorx.ErrAccountDisabled
|
|
}
|
|
|
|
// Set Context
|
|
ctx.Locals(consts.CtxKeyUser, user)
|
|
|
|
if tenant := ctx.Locals(consts.CtxKeyTenant); tenant != nil {
|
|
if model, ok := tenant.(*models.Tenant); ok && claims.TenantID > 0 && model.ID != claims.TenantID {
|
|
return errorx.ErrForbidden.WithMsg("租户不匹配")
|
|
}
|
|
} else if claims.TenantID > 0 {
|
|
tenantModel, err := services.Tenant.GetModelByID(ctx, claims.TenantID)
|
|
if err != nil {
|
|
return errorx.ErrUnauthorized.WithCause(err).WithMsg("TenantNotFound")
|
|
}
|
|
ctx.Locals(consts.CtxKeyTenant, tenantModel)
|
|
}
|
|
|
|
return ctx.Next()
|
|
}
|
|
|
|
func (m *Middlewares) SuperAuth(ctx fiber.Ctx) error {
|
|
if isSuperPublicRoute(ctx) {
|
|
return ctx.Next()
|
|
}
|
|
authHeader := ctx.Get("Authorization")
|
|
if authHeader == "" {
|
|
return errorx.ErrUnauthorized.WithMsg("Missing token")
|
|
}
|
|
|
|
claims, err := m.jwt.Parse(authHeader)
|
|
if err != nil {
|
|
return errorx.ErrUnauthorized.WithCause(err).WithMsg("Invalid token")
|
|
}
|
|
|
|
user, err := services.User.GetModelByID(ctx, claims.UserID)
|
|
if err != nil {
|
|
return errorx.ErrUnauthorized.WithCause(err).WithMsg("UserNotFound")
|
|
}
|
|
if user.Status == consts.UserStatusBanned {
|
|
return errorx.ErrAccountDisabled
|
|
}
|
|
|
|
if !hasRole(user.Roles, consts.RoleSuperAdmin) {
|
|
return errorx.ErrForbidden.WithMsg("无权限访问")
|
|
}
|
|
|
|
ctx.Locals(consts.CtxKeyUser, user)
|
|
return ctx.Next()
|
|
}
|
|
|
|
func (m *Middlewares) TenantResolver(ctx fiber.Ctx) error {
|
|
tenantCode := strings.TrimSpace(ctx.Params("tenantCode"))
|
|
if tenantCode == "" {
|
|
return errorx.ErrMissingParameter.WithMsg("缺少租户编码")
|
|
}
|
|
|
|
tbl, q := models.TenantQuery.QueryContext(ctx)
|
|
tenant, err := q.Where(tbl.Code.Eq(tenantCode)).First()
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return errorx.ErrRecordNotFound.WithMsg("租户不存在")
|
|
}
|
|
return errorx.ErrDatabaseError.WithCause(err)
|
|
}
|
|
|
|
ctx.Locals(consts.CtxKeyTenant, tenant)
|
|
return ctx.Next()
|
|
}
|
|
|
|
func hasRole(roles types.Array[consts.Role], role consts.Role) bool {
|
|
for _, r := range roles {
|
|
if r == role {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func isPublicRoute(ctx fiber.Ctx) bool {
|
|
path := normalizeTenantPath(ctx.Path())
|
|
method := ctx.Method()
|
|
|
|
if method == fiber.MethodGet {
|
|
switch path {
|
|
case "/v1/common/options", "/v1/contents", "/v1/topics", "/v1/tenants":
|
|
return true
|
|
}
|
|
if strings.HasPrefix(path, "/v1/contents/") {
|
|
return true
|
|
}
|
|
if strings.HasPrefix(path, "/v1/creators/") && strings.HasSuffix(path, "/contents") {
|
|
return true
|
|
}
|
|
if strings.HasPrefix(path, "/v1/tenants/") {
|
|
return true
|
|
}
|
|
if strings.HasPrefix(path, "/v1/storage/") {
|
|
return true
|
|
}
|
|
}
|
|
|
|
if method == fiber.MethodPost && path == "/v1/webhook/payment/notify" {
|
|
return true
|
|
}
|
|
|
|
if method == fiber.MethodPost {
|
|
switch path {
|
|
case "/v1/auth/otp", "/v1/auth/login":
|
|
return true
|
|
}
|
|
}
|
|
|
|
if method == fiber.MethodPut && strings.HasPrefix(path, "/v1/storage/") {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isSuperPublicRoute(ctx fiber.Ctx) bool {
|
|
path := ctx.Path()
|
|
method := ctx.Method()
|
|
|
|
if method == fiber.MethodPost && path == "/super/v1/auth/login" {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// normalizeTenantPath strips a leading "/t/{segment}" from tenant-scoped paths
// so they can be matched against the shared "/v1..." rule table. The segment
// is presumably the tenant code; confirm against the route definitions.
//
// The result is everything from the slash after that segment onward, but only
// when it starts with "/v1". The path is returned unchanged in three cases:
//   - it lacks the "/t/" prefix;
//   - it has no slash after the segment;
//   - the remainder does not start with "/v1".
func normalizeTenantPath(path string) string {
	const tenantPrefix = "/t/"
	if !strings.HasPrefix(path, tenantPrefix) {
		return path
	}

	afterPrefix := path[len(tenantPrefix):]
	cut := strings.IndexByte(afterPrefix, '/')
	if cut < 0 {
		return path
	}

	if remainder := afterPrefix[cut:]; strings.HasPrefix(remainder, "/v1") {
		return remainder
	}
	return path
}
|