From 16f80cc297dad7d71f06f90e76ce0f0f08ecbdb4 Mon Sep 17 00:00:00 2001 From: yanghao05 Date: Mon, 30 Jan 2023 11:44:34 +0800 Subject: [PATCH] add cors middleware --- config.toml | 19 +++++++++ middleware/cors.go | 73 ++++++++++++++++++++++++++++++++ modules/system/dao/dao_test.go | 2 +- providers/config/section_http.go | 12 ++++++ 4 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 middleware/cors.go diff --git a/config.toml b/config.toml index 2908490..782ad35 100644 --- a/config.toml +++ b/config.toml @@ -8,6 +8,25 @@ HttpsCert = "" HttpKey = "" Port = 8088 +[Http.Cors] +# 跨域配置 +# 需要配合 server/initialize/router.go#L32 使用 +# 放行模式: Allow-all, 放行全部; whitelist, 白名单模式, 来自白名单内域名的请求添加 cors 头; strict-whitelist 严格白名单模式, 白名单外的请求一律拒绝 +Mode="strict-whitelist" +[[Http.Cors.Whitelist]] +AllowOrigin = "example1.com" +AllowHeaders = "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token,X-Token,X-User-Id" +AllowMethods = "POST, GET" +ExposeHeaders = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type" +AllowCredentials = true + +[[Http.Cors.Whitelist]] +AllowOrigin = "example2.com" +AllowHeaders = "content-type" +AllowMethods = "GET, POST" +ExposeHeaders = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type" +AllowCredentials = true + [Log] Level = "debug" diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..043f84f --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,73 @@ +package middleware + +import ( + "atom/providers/config" + "net/http" + + "github.com/gin-gonic/gin" +) + +// Cors 直接放行所有跨域请求并放行所有 OPTIONS 方法 +func Cors() gin.HandlerFunc { + return func(c *gin.Context) { + method := c.Request.Method + origin := c.Request.Header.Get("Origin") + c.Header("Access-Control-Allow-Origin", origin) + c.Header("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token,X-Token,X-User-Id") + c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS,DELETE,PUT") + c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type, New-Token, New-Expires-At") + c.Header("Access-Control-Allow-Credentials", "true") + + // 放行所有OPTIONS方法 + if method == "OPTIONS" { + c.AbortWithStatus(http.StatusNoContent) + } + // 处理请求 + c.Next() + } +} + +// CorsByRules 按照配置处理跨域请求 +func CorsByRules(config *config.Config) gin.HandlerFunc { + // 放行全部 + if config.Http.Cors.Mode == "allow-all" { + return Cors() + } + return func(c *gin.Context) { + whitelist := checkCors(config, c.GetHeader("origin")) + + // 通过检查, 添加请求头 + if whitelist != nil { + c.Header("Access-Control-Allow-Origin", whitelist.AllowOrigin) + c.Header("Access-Control-Allow-Headers", whitelist.AllowHeaders) + c.Header("Access-Control-Allow-Methods", whitelist.AllowMethods) + c.Header("Access-Control-Expose-Headers", whitelist.ExposeHeaders) + if whitelist.AllowCredentials { + c.Header("Access-Control-Allow-Credentials", "true") + } + } + + // 严格白名单模式且未通过检查,直接拒绝处理请求 + if whitelist == nil && config.Http.Cors.Mode == "strict-whitelist" && !(c.Request.Method == "GET" && c.Request.URL.Path == "/health") { + c.AbortWithStatus(http.StatusForbidden) + } else { + // 非严格白名单模式,无论是否通过检查均放行所有 OPTIONS 方法 + if c.Request.Method == http.MethodOptions { + c.AbortWithStatus(http.StatusNoContent) + } + } + + // 处理请求 + c.Next() + } +} + +func checkCors(conf *config.Config, currentOrigin string) *config.Whitelist { + for _, whitelist := range conf.Http.Cors.Whitelist { + // 遍历配置中的跨域头,寻找匹配项 + if currentOrigin == whitelist.AllowOrigin { + return &whitelist + } + } + return nil +} diff --git a/modules/system/dao/dao_test.go b/modules/system/dao/dao_test.go index b48ec5e..e52861d 100644 --- a/modules/system/dao/dao_test.go +++ b/modules/system/dao/dao_test.go @@ -8,7 +8,7 @@ import ( "atom/providers/config" _ "atom/providers/database" _ "atom/providers/http" - _ "atom/providers/logger" + _ "atom/providers/log" "go.uber.org/dig" "gorm.io/gorm" diff --git a/providers/config/section_http.go b/providers/config/section_http.go index d1bcd84..95b0c5b 100644 --- a/providers/config/section_http.go +++ b/providers/config/section_http.go @@ -9,6 +9,18 @@ type Http struct { Https bool HttpsCert string HttpKey string + Cors struct { + Mode string + Whitelist []Whitelist + } +} + +type Whitelist struct { + AllowOrigin string + AllowHeaders string + AllowMethods string + ExposeHeaders string + AllowCredentials bool } func (h Http) Address() string {