package middlewares import ( "context" "database/sql" "io" "net/http" "net/http/httptest" "testing" "time" "quyun/v2/app/commands/testx" "quyun/v2/app/errorx" "quyun/v2/app/services" "quyun/v2/database" "quyun/v2/database/models" "quyun/v2/pkg/consts" "quyun/v2/providers/jwt" "github.com/gofiber/fiber/v3" jwtv4 "github.com/golang-jwt/jwt/v4" . "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/suite" "go.ipao.vip/atom/contracts" "go.ipao.vip/gen/types" "go.uber.org/dig" ) type MiddlewaresTestSuiteInjectParams struct { dig.In DB *sql.DB Initials []contracts.Initial `group:"initials"` JWT *jwt.JWT Middlewares *Middlewares } type MiddlewaresTestSuite struct { suite.Suite MiddlewaresTestSuiteInjectParams } func Test_Middlewares(t *testing.T) { providers := testx.Default().With(Provide) testx.Serve(providers, t, func(p MiddlewaresTestSuiteInjectParams) { suite.Run(t, &MiddlewaresTestSuite{MiddlewaresTestSuiteInjectParams: p}) }) } func (s *MiddlewaresTestSuite) newTestApp() *fiber.App { handler := errorx.NewErrorHandler() return fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { appErr := handler.Handle(err) return c.Status(appErr.StatusCode).JSON(fiber.Map{ "code": appErr.Code, "message": appErr.Message, }) }, }) } func (s *MiddlewaresTestSuite) createTestUser(ctx context.Context, phone string, roles types.Array[consts.Role]) *models.User { user := &models.User{ Phone: phone, Roles: roles, Status: consts.UserStatusVerified, } _ = models.UserQuery.WithContext(ctx).Create(user) return user } func (s *MiddlewaresTestSuite) createToken(userID, tenantID int64) string { claims := s.JWT.CreateClaims(jwt.BaseClaims{ UserID: userID, TenantID: tenantID, }) token, _ := s.JWT.CreateToken(claims) return "Bearer " + token } func (s *MiddlewaresTestSuite) createExpiredToken(userID int64) string { claims := &jwt.Claims{ BaseClaims: jwt.BaseClaims{ UserID: userID, }, } claims.ExpiresAt = jwtv4.NewNumericDate(time.Now().Add(-time.Hour)) claims.NotBefore = jwtv4.NewNumericDate(time.Now().Add(-2 * time.Hour)) token := jwtv4.NewWithClaims(jwtv4.SigningMethodHS256, claims) tokenString, _ := token.SignedString(s.JWT.SigningKey) return "Bearer " + tokenString } func (s *MiddlewaresTestSuite) Test_AuthOptional() { Convey("AuthOptional", s.T(), func() { ctx := s.T().Context() database.Truncate(ctx, s.DB, models.TableNameUser) Convey("should pass without token and ctx.Locals has no user", func() { app := s.newTestApp() app.Use(s.Middlewares.AuthOptional) app.Get("/test", func(c fiber.Ctx) error { user := c.Locals(consts.CtxKeyUser) if user == nil { return c.SendString("no_user") } return c.SendString("has_user") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusOK) body, _ := io.ReadAll(resp.Body) So(string(body), ShouldEqual, "no_user") }) Convey("should pass with valid token and ctx.Locals has user", func() { app := s.newTestApp() app.Use(s.Middlewares.AuthOptional) app.Get("/test", func(c fiber.Ctx) error { user := c.Locals(consts.CtxKeyUser) if user == nil { return c.SendString("no_user") } return c.SendString("has_user") }) user := s.createTestUser(ctx, "13800000001", types.Array[consts.Role]{consts.RoleUser}) token := s.createToken(user.ID, 0) req := httptest.NewRequest(http.MethodGet, "/test", nil) req.Header.Set("Authorization", token) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusOK) body, _ := io.ReadAll(resp.Body) So(string(body), ShouldEqual, "has_user") }) }) } func (s *MiddlewaresTestSuite) Test_AuthRequired() { Convey("AuthRequired", s.T(), func() { ctx := s.T().Context() database.Truncate(ctx, s.DB, models.TableNameUser) Convey("should return 401 without token", func() { app := s.newTestApp() app.Use(s.Middlewares.AuthRequired) app.Get("/protected", func(c fiber.Ctx) error { return c.SendString("ok") }) req := httptest.NewRequest(http.MethodGet, "/protected", nil) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) }) Convey("should return 401 with invalid token", func() { app := s.newTestApp() app.Use(s.Middlewares.AuthRequired) app.Get("/protected", func(c fiber.Ctx) error { return c.SendString("ok") }) req := httptest.NewRequest(http.MethodGet, "/protected", nil) req.Header.Set("Authorization", "Bearer invalid_token") resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) }) Convey("should pass with valid token", func() { app := s.newTestApp() app.Use(s.Middlewares.AuthRequired) app.Get("/protected", func(c fiber.Ctx) error { return c.SendString("ok") }) user := s.createTestUser(ctx, "13800000002", types.Array[consts.Role]{consts.RoleUser}) token := s.createToken(user.ID, 0) req := httptest.NewRequest(http.MethodGet, "/protected", nil) req.Header.Set("Authorization", token) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusOK) }) }) } func (s *MiddlewaresTestSuite) Test_SuperAuth() { Convey("SuperAuth", s.T(), func() { ctx := s.T().Context() database.Truncate(ctx, s.DB, models.TableNameUser) Convey("should return 401 without token", func() { app := s.newTestApp() app.Use(s.Middlewares.SuperAuth) app.Get("/super/v1/tenants", func(c fiber.Ctx) error { return c.SendString("ok") }) req := httptest.NewRequest(http.MethodGet, "/super/v1/tenants", nil) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) }) Convey("should return 403 when user is not super_admin", func() { app := s.newTestApp() app.Use(s.Middlewares.SuperAuth) app.Get("/super/v1/tenants", func(c fiber.Ctx) error { return c.SendString("ok") }) user := s.createTestUser(ctx, "13800000003", types.Array[consts.Role]{consts.RoleUser}) token := s.createToken(user.ID, 0) req := httptest.NewRequest(http.MethodGet, "/super/v1/tenants", nil) req.Header.Set("Authorization", token) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusForbidden) }) Convey("should pass when user is super_admin", func() { app := s.newTestApp() app.Use(s.Middlewares.SuperAuth) app.Get("/super/v1/tenants", func(c fiber.Ctx) error { return c.SendString("ok") }) user := s.createTestUser(ctx, "13800000004", types.Array[consts.Role]{consts.RoleSuperAdmin}) token := s.createToken(user.ID, 0) req := httptest.NewRequest(http.MethodGet, "/super/v1/tenants", nil) req.Header.Set("Authorization", token) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusOK) }) Convey("should allow public routes without auth", func() { app := s.newTestApp() app.Use(s.Middlewares.SuperAuth) app.Post("/super/v1/auth/login", func(c fiber.Ctx) error { return c.SendString("login") }) req := httptest.NewRequest(http.MethodPost, "/super/v1/auth/login", nil) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusOK) }) }) } func (s *MiddlewaresTestSuite) Test_TenantResolver() { Convey("TenantResolver", s.T(), func() { ctx := s.T().Context() database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameUser) Convey("should return 404 when tenant not found", func() { app := s.newTestApp() app.Get("/t/:tenantCode/v1/test", s.Middlewares.TenantResolver, func(c fiber.Ctx) error { return c.SendString("ok") }) req := httptest.NewRequest(http.MethodGet, "/t/nonexistent/v1/test", nil) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusNotFound) }) Convey("should set tenant in ctx.Locals when found", func() { owner := &models.User{Phone: "13800000005", Status: consts.UserStatusVerified} _ = models.UserQuery.WithContext(ctx).Create(owner) tenant := &models.Tenant{ Name: "Test Tenant", Code: "test_tenant", UserID: owner.ID, Status: consts.TenantStatusVerified, } _ = models.TenantQuery.WithContext(ctx).Create(tenant) app := s.newTestApp() app.Get("/t/:tenantCode/v1/test", s.Middlewares.TenantResolver, func(c fiber.Ctx) error { t := c.Locals(consts.CtxKeyTenant) if t == nil { return c.SendString("no_tenant") } if model, ok := t.(*models.Tenant); ok { return c.SendString(model.Code) } return c.SendString("invalid_tenant") }) req := httptest.NewRequest(http.MethodGet, "/t/test_tenant/v1/test", nil) resp, err := app.Test(req) So(err, ShouldBeNil) So(resp.StatusCode, ShouldEqual, http.StatusOK) body, _ := io.ReadAll(resp.Body) So(string(body), ShouldEqual, "test_tenant") }) }) } var _ = services.User