// Package tenancy resolves the tenant addressed by a request's "tenant_code"
// route parameter and exposes it to downstream handlers through fiber locals.
package tenancy

import (
	"database/sql"
	"errors"
	"regexp"
	"strings"

	"quyun/v2/app/errorx"

	"github.com/gofiber/fiber/v3"
	"github.com/google/uuid"
)

// Keys under which Middleware stores the resolved tenant in fiber locals.
const (
	// LocalTenantCode is the locals key holding the normalized tenant code (string).
	LocalTenantCode = "tenant_code"
	// LocalTenantID is the locals key holding the tenant's numeric ID (int64).
	LocalTenantID = "tenant_id"
	// LocalTenantUUID is the locals key holding the tenant UUID in string form.
	LocalTenantUUID = "tenant_uuid"
)

// tenantCodeRe is the whitelist for an already-lowercased tenant code:
// one or more of a-z, 0-9, underscore or hyphen.
var tenantCodeRe = regexp.MustCompile(`^[a-z0-9_-]+$`)

// Tenant is the subset of a tenants row that request handling needs.
type Tenant struct {
	ID   int64
	Code string
	UUID uuid.UUID
}

// ResolveTenant looks up the tenant named by the "tenant_code" route parameter.
//
// The parameter is trimmed and lowercased, then validated against
// tenantCodeRe before it is used in the query. The lookup compares against
// lower(tenant_code), so matching is case-insensitive.
//
// Errors:
//   - errorx.ErrInvalidParameter when the code is empty or contains characters
//     outside the whitelist;
//   - fiber.ErrNotFound when no row matches or the row's status is non-zero;
//   - errorx.ErrDatabaseError for any other query/scan failure.
//
// On success the returned Tenant carries the normalized (lowercased) code.
func ResolveTenant(c fiber.Ctx, db *sql.DB) (*Tenant, error) {
	raw := strings.TrimSpace(c.Params("tenant_code"))
	code := strings.ToLower(raw)
	if code == "" || !tenantCodeRe.MatchString(code) {
		return nil, errorx.ErrInvalidParameter.WithMsg("invalid tenant_code")
	}

	var (
		id         int64
		tenantUUID uuid.UUID
		status     int16
	)
	err := db.QueryRowContext(
		c.Context(),
		`SELECT id, tenant_uuid, status FROM tenants WHERE lower(tenant_code) = $1 LIMIT 1`,
		code,
	).Scan(&id, &tenantUUID, &status)
	if err != nil {
		// errors.Is rather than == so that a wrapped sql.ErrNoRows (for
		// example from an instrumented driver) is still reported as 404
		// instead of being misclassified as a database failure.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fiber.ErrNotFound
		}
		return nil, errorx.ErrDatabaseError.WithMsg("database error").WithParams(err.Error())
	}

	// status: 0 means enabled (the default). Any other status is reported as
	// not found, the same response as for a missing tenant.
	if status != 0 {
		return nil, fiber.ErrNotFound
	}

	return &Tenant{ID: id, Code: code, UUID: tenantUUID}, nil
}

// Middleware returns a fiber handler that resolves the tenant for each request
// via ResolveTenant. On failure the error is returned and the chain stops; on
// success the tenant's code, ID and UUID (as a string) are stored in locals
// under LocalTenantCode, LocalTenantID and LocalTenantUUID before the next
// handler runs.
func Middleware(db *sql.DB) fiber.Handler {
	return func(c fiber.Ctx) error {
		tenant, err := ResolveTenant(c, db)
		if err != nil {
			return err
		}

		c.Locals(LocalTenantCode, tenant.Code)
		c.Locals(LocalTenantID, tenant.ID)
		c.Locals(LocalTenantUUID, tenant.UUID.String())
		return c.Next()
	}
}