- Add middlewares_test.go with tests for AuthOptional, AuthRequired, SuperAuth, and TenantResolver.
- Update todo_list.md with test specifications, coverage status, and pending test cases (T1–T4).
319 lines
9.0 KiB
Go
319 lines
9.0 KiB
Go
package middlewares
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"quyun/v2/app/commands/testx"
|
|
"quyun/v2/app/errorx"
|
|
"quyun/v2/app/services"
|
|
"quyun/v2/database"
|
|
"quyun/v2/database/models"
|
|
"quyun/v2/pkg/consts"
|
|
"quyun/v2/providers/jwt"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
jwtv4 "github.com/golang-jwt/jwt/v4"
|
|
. "github.com/smartystreets/goconvey/convey"
|
|
"github.com/stretchr/testify/suite"
|
|
"go.ipao.vip/atom/contracts"
|
|
"go.ipao.vip/gen/types"
|
|
"go.uber.org/dig"
|
|
)
|
|
|
|
// MiddlewaresTestSuiteInjectParams lists the dependencies that the dig
// container injects into the middleware test suite (see Test_Middlewares).
type MiddlewaresTestSuiteInjectParams struct {
	dig.In

	// DB is the raw database handle the tests pass to database.Truncate.
	DB *sql.DB
	// Initials collects every provider registered in the "initials" group.
	// NOTE(review): no test in this file reads it; presumably it is requested
	// so the container runs those initializers before the suite — confirm.
	Initials []contracts.Initial `group:"initials"`
	// JWT builds and signs the bearer tokens used by the authenticated cases.
	JWT *jwt.JWT
	// Middlewares is the instance under test.
	Middlewares *Middlewares
}
|
|
|
|
// MiddlewaresTestSuite is the testify suite for this package's middlewares.
// Embedding the inject params exposes s.DB, s.JWT and s.Middlewares directly
// on the suite receiver.
type MiddlewaresTestSuite struct {
	suite.Suite

	MiddlewaresTestSuiteInjectParams
}
|
|
|
|
func Test_Middlewares(t *testing.T) {
|
|
providers := testx.Default().With(Provide)
|
|
|
|
testx.Serve(providers, t, func(p MiddlewaresTestSuiteInjectParams) {
|
|
suite.Run(t, &MiddlewaresTestSuite{MiddlewaresTestSuiteInjectParams: p})
|
|
})
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) newTestApp() *fiber.App {
|
|
handler := errorx.NewErrorHandler()
|
|
return fiber.New(fiber.Config{
|
|
ErrorHandler: func(c fiber.Ctx, err error) error {
|
|
appErr := handler.Handle(err)
|
|
return c.Status(appErr.StatusCode).JSON(fiber.Map{
|
|
"code": appErr.Code,
|
|
"message": appErr.Message,
|
|
})
|
|
},
|
|
})
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) createTestUser(ctx context.Context, phone string, roles types.Array[consts.Role]) *models.User {
|
|
user := &models.User{
|
|
Phone: phone,
|
|
Roles: roles,
|
|
Status: consts.UserStatusVerified,
|
|
}
|
|
_ = models.UserQuery.WithContext(ctx).Create(user)
|
|
return user
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) createToken(userID, tenantID int64) string {
|
|
claims := s.JWT.CreateClaims(jwt.BaseClaims{
|
|
UserID: userID,
|
|
TenantID: tenantID,
|
|
})
|
|
token, _ := s.JWT.CreateToken(claims)
|
|
return "Bearer " + token
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) createExpiredToken(userID int64) string {
|
|
claims := &jwt.Claims{
|
|
BaseClaims: jwt.BaseClaims{
|
|
UserID: userID,
|
|
},
|
|
}
|
|
claims.ExpiresAt = jwtv4.NewNumericDate(time.Now().Add(-time.Hour))
|
|
claims.NotBefore = jwtv4.NewNumericDate(time.Now().Add(-2 * time.Hour))
|
|
|
|
token := jwtv4.NewWithClaims(jwtv4.SigningMethodHS256, claims)
|
|
tokenString, _ := token.SignedString(s.JWT.SigningKey)
|
|
return "Bearer " + tokenString
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) Test_AuthOptional() {
|
|
Convey("AuthOptional", s.T(), func() {
|
|
ctx := s.T().Context()
|
|
database.Truncate(ctx, s.DB, models.TableNameUser)
|
|
|
|
Convey("should pass without token and ctx.Locals has no user", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.AuthOptional)
|
|
app.Get("/test", func(c fiber.Ctx) error {
|
|
user := c.Locals(consts.CtxKeyUser)
|
|
if user == nil {
|
|
return c.SendString("no_user")
|
|
}
|
|
return c.SendString("has_user")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
So(string(body), ShouldEqual, "no_user")
|
|
})
|
|
|
|
Convey("should pass with valid token and ctx.Locals has user", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.AuthOptional)
|
|
app.Get("/test", func(c fiber.Ctx) error {
|
|
user := c.Locals(consts.CtxKeyUser)
|
|
if user == nil {
|
|
return c.SendString("no_user")
|
|
}
|
|
return c.SendString("has_user")
|
|
})
|
|
|
|
user := s.createTestUser(ctx, "13800000001", types.Array[consts.Role]{consts.RoleUser})
|
|
token := s.createToken(user.ID, 0)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.Header.Set("Authorization", token)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
So(string(body), ShouldEqual, "has_user")
|
|
})
|
|
})
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) Test_AuthRequired() {
|
|
Convey("AuthRequired", s.T(), func() {
|
|
ctx := s.T().Context()
|
|
database.Truncate(ctx, s.DB, models.TableNameUser)
|
|
|
|
Convey("should return 401 without token", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.AuthRequired)
|
|
app.Get("/protected", func(c fiber.Ctx) error {
|
|
return c.SendString("ok")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized)
|
|
})
|
|
|
|
Convey("should return 401 with invalid token", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.AuthRequired)
|
|
app.Get("/protected", func(c fiber.Ctx) error {
|
|
return c.SendString("ok")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
req.Header.Set("Authorization", "Bearer invalid_token")
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized)
|
|
})
|
|
|
|
Convey("should pass with valid token", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.AuthRequired)
|
|
app.Get("/protected", func(c fiber.Ctx) error {
|
|
return c.SendString("ok")
|
|
})
|
|
|
|
user := s.createTestUser(ctx, "13800000002", types.Array[consts.Role]{consts.RoleUser})
|
|
token := s.createToken(user.ID, 0)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
req.Header.Set("Authorization", token)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
|
})
|
|
})
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) Test_SuperAuth() {
|
|
Convey("SuperAuth", s.T(), func() {
|
|
ctx := s.T().Context()
|
|
database.Truncate(ctx, s.DB, models.TableNameUser)
|
|
|
|
Convey("should return 401 without token", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.SuperAuth)
|
|
app.Get("/super/v1/tenants", func(c fiber.Ctx) error {
|
|
return c.SendString("ok")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/super/v1/tenants", nil)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized)
|
|
})
|
|
|
|
Convey("should return 403 when user is not super_admin", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.SuperAuth)
|
|
app.Get("/super/v1/tenants", func(c fiber.Ctx) error {
|
|
return c.SendString("ok")
|
|
})
|
|
|
|
user := s.createTestUser(ctx, "13800000003", types.Array[consts.Role]{consts.RoleUser})
|
|
token := s.createToken(user.ID, 0)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/super/v1/tenants", nil)
|
|
req.Header.Set("Authorization", token)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusForbidden)
|
|
})
|
|
|
|
Convey("should pass when user is super_admin", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.SuperAuth)
|
|
app.Get("/super/v1/tenants", func(c fiber.Ctx) error {
|
|
return c.SendString("ok")
|
|
})
|
|
|
|
user := s.createTestUser(ctx, "13800000004", types.Array[consts.Role]{consts.RoleSuperAdmin})
|
|
token := s.createToken(user.ID, 0)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/super/v1/tenants", nil)
|
|
req.Header.Set("Authorization", token)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
|
})
|
|
|
|
Convey("should allow public routes without auth", func() {
|
|
app := s.newTestApp()
|
|
app.Use(s.Middlewares.SuperAuth)
|
|
app.Post("/super/v1/auth/login", func(c fiber.Ctx) error {
|
|
return c.SendString("login")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/super/v1/auth/login", nil)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
|
})
|
|
})
|
|
}
|
|
|
|
func (s *MiddlewaresTestSuite) Test_TenantResolver() {
|
|
Convey("TenantResolver", s.T(), func() {
|
|
ctx := s.T().Context()
|
|
database.Truncate(ctx, s.DB, models.TableNameTenant, models.TableNameUser)
|
|
|
|
Convey("should return 404 when tenant not found", func() {
|
|
app := s.newTestApp()
|
|
app.Get("/t/:tenantCode/v1/test", s.Middlewares.TenantResolver, func(c fiber.Ctx) error {
|
|
return c.SendString("ok")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/t/nonexistent/v1/test", nil)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusNotFound)
|
|
})
|
|
|
|
Convey("should set tenant in ctx.Locals when found", func() {
|
|
owner := &models.User{Phone: "13800000005", Status: consts.UserStatusVerified}
|
|
_ = models.UserQuery.WithContext(ctx).Create(owner)
|
|
|
|
tenant := &models.Tenant{
|
|
Name: "Test Tenant",
|
|
Code: "test_tenant",
|
|
UserID: owner.ID,
|
|
Status: consts.TenantStatusVerified,
|
|
}
|
|
_ = models.TenantQuery.WithContext(ctx).Create(tenant)
|
|
|
|
app := s.newTestApp()
|
|
app.Get("/t/:tenantCode/v1/test", s.Middlewares.TenantResolver, func(c fiber.Ctx) error {
|
|
t := c.Locals(consts.CtxKeyTenant)
|
|
if t == nil {
|
|
return c.SendString("no_tenant")
|
|
}
|
|
if model, ok := t.(*models.Tenant); ok {
|
|
return c.SendString(model.Code)
|
|
}
|
|
return c.SendString("invalid_tenant")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/t/test_tenant/v1/test", nil)
|
|
resp, err := app.Test(req)
|
|
So(err, ShouldBeNil)
|
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
So(string(body), ShouldEqual, "test_tenant")
|
|
})
|
|
})
|
|
}
|
|
|
|
// Reference services.User so the "quyun/v2/app/services" import stays used.
// NOTE(review): nothing else in this file uses the services package; this
// presumably exists only to keep the import (and any init side effects it
// has). Confirm, and consider a blank import instead.
var _ = services.User
|