package middlewares import ( "errors" "strings" "quyun/v2/app/errorx" "quyun/v2/app/services" "quyun/v2/database/models" "quyun/v2/pkg/consts" "quyun/v2/providers/jwt" "github.com/gofiber/fiber/v3" log "github.com/sirupsen/logrus" "go.ipao.vip/gen/types" "gorm.io/gorm" ) // Middlewares provides reusable Fiber middlewares shared across modules. // // @provider type Middlewares struct { // log is the module logger injected by the framework. log *log.Entry `inject:"false"` // jwt is the JWT provider used by auth-related middlewares. jwt *jwt.JWT } func (f *Middlewares) Prepare() error { f.log = log.WithField("module", "middleware") return nil } func (m *Middlewares) Auth(ctx fiber.Ctx) error { if isPublicRoute(ctx) { return ctx.Next() } authHeader := ctx.Get("Authorization") if authHeader == "" { return errorx.ErrUnauthorized.WithMsg("Missing token") } claims, err := m.jwt.Parse(authHeader) if err != nil { return errorx.ErrUnauthorized.WithCause(err).WithMsg("Invalid token") } // get user model user, err := services.User.GetModelByID(ctx, claims.UserID) if err != nil { return errorx.ErrUnauthorized.WithCause(err).WithMsg("UserNotFound") } // Set Context ctx.Locals(consts.CtxKeyUser, user) if tenant := ctx.Locals(consts.CtxKeyTenant); tenant != nil { if model, ok := tenant.(*models.Tenant); ok && claims.TenantID > 0 && model.ID != claims.TenantID { return errorx.ErrForbidden.WithMsg("租户不匹配") } } else if claims.TenantID > 0 { tenantModel, err := services.Tenant.GetModelByID(ctx, claims.TenantID) if err != nil { return errorx.ErrUnauthorized.WithCause(err).WithMsg("TenantNotFound") } ctx.Locals(consts.CtxKeyTenant, tenantModel) } return ctx.Next() } func (m *Middlewares) SuperAuth(ctx fiber.Ctx) error { if isSuperPublicRoute(ctx) { return ctx.Next() } authHeader := ctx.Get("Authorization") if authHeader == "" { return errorx.ErrUnauthorized.WithMsg("Missing token") } claims, err := m.jwt.Parse(authHeader) if err != nil { return errorx.ErrUnauthorized.WithCause(err).WithMsg("Invalid token") } user, err := services.User.GetModelByID(ctx, claims.UserID) if err != nil { return errorx.ErrUnauthorized.WithCause(err).WithMsg("UserNotFound") } if !hasRole(user.Roles, consts.RoleSuperAdmin) { return errorx.ErrForbidden.WithMsg("无权限访问") } ctx.Locals(consts.CtxKeyUser, user) return ctx.Next() } func (m *Middlewares) TenantResolver(ctx fiber.Ctx) error { tenantCode := strings.TrimSpace(ctx.Params("tenantCode")) if tenantCode == "" { return errorx.ErrMissingParameter.WithMsg("缺少租户编码") } tbl, q := models.TenantQuery.QueryContext(ctx) tenant, err := q.Where(tbl.Code.Eq(tenantCode)).First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errorx.ErrRecordNotFound.WithMsg("租户不存在") } return errorx.ErrDatabaseError.WithCause(err) } ctx.Locals(consts.CtxKeyTenant, tenant) return ctx.Next() } func hasRole(roles types.Array[consts.Role], role consts.Role) bool { for _, r := range roles { if r == role { return true } } return false } func isPublicRoute(ctx fiber.Ctx) bool { path := normalizeTenantPath(ctx.Path()) method := ctx.Method() if method == fiber.MethodGet { switch path { case "/v1/common/options", "/v1/contents", "/v1/topics", "/v1/tenants": return true } if strings.HasPrefix(path, "/v1/contents/") { return true } if strings.HasPrefix(path, "/v1/creators/") && strings.HasSuffix(path, "/contents") { return true } if strings.HasPrefix(path, "/v1/tenants/") { return true } if strings.HasPrefix(path, "/v1/orders/") && strings.HasSuffix(path, "/status") { return true } if strings.HasPrefix(path, "/v1/storage/") { return true } } if method == fiber.MethodPost && path == "/v1/webhook/payment/notify" { return true } if method == fiber.MethodPost { switch path { case "/v1/auth/otp", "/v1/auth/login": return true } } if method == fiber.MethodPut && strings.HasPrefix(path, "/v1/storage/") { return true } return false } func isSuperPublicRoute(ctx fiber.Ctx) bool { path := ctx.Path() method := ctx.Method() if method == fiber.MethodPost && path == "/super/v1/auth/login" { return true } return false } func normalizeTenantPath(path string) string { if !strings.HasPrefix(path, "/t/") { return path } rest := strings.TrimPrefix(path, "/t/") slash := strings.Index(rest, "/") if slash == -1 { return path } rest = rest[slash:] if strings.HasPrefix(rest, "/v1") { return rest } return path }