From 2d52491536997ca5b1faab1015bd630c2a9c3620 Mon Sep 17 00:00:00 2001 From: yanghao05 Date: Tue, 7 Feb 2023 11:45:23 +0800 Subject: [PATCH] fix issues --- database/seeders/migration.go | 20 +++++++++--------- go.mod | 2 +- modules/auth/controller/permission.go | 29 ++++++++++++++++++++------- modules/auth/routes/routes.go | 2 +- providers/jwt/jwt.go | 13 ++++++++---- providers/rbac/casbin.go | 4 ++++ providers/rbac/rbac.go | 1 + 7 files changed, 48 insertions(+), 23 deletions(-) diff --git a/database/seeders/migration.go b/database/seeders/migration.go index 0bb1611..e490436 100755 --- a/database/seeders/migration.go +++ b/database/seeders/migration.go @@ -25,16 +25,16 @@ func NewMigrationSeeder() contracts.Seeder { } func (s *MigrationSeeder) Run(faker *gofakeit.Faker, db *gorm.DB) { - times := 10 - for i := 0; i < times; i++ { - data := s.Generate(faker, i) - if i == 0 { - stmt := &gorm.Statement{DB: db} - _ = stmt.Parse(&data) - log.Printf("seeding %s for %d times", stmt.Schema.Table, times) - } - db.Create(&data) - } + // times := 10 + // for i := 0; i < times; i++ { + // data := s.Generate(faker, i) + // if i == 0 { + // stmt := &gorm.Statement{DB: db} + // _ = stmt.Parse(&data) + // log.Printf("seeding %s for %d times", stmt.Schema.Table, times) + // } + // db.Create(&data) + // } } func (s *MigrationSeeder) Generate(faker *gofakeit.Faker, idx int) models.Migration { diff --git a/go.mod b/go.mod index c6b9b8f..dfbe5f8 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( go.uber.org/zap v1.21.0 golang.org/x/crypto v0.5.0 golang.org/x/sync v0.1.0 + google.golang.org/protobuf v1.28.1 gorm.io/driver/mysql v1.4.1 gorm.io/gen v0.3.19 gorm.io/gorm v1.24.0 @@ -85,7 +86,6 @@ require ( golang.org/x/sys v0.4.0 // indirect golang.org/x/text v0.6.0 // indirect golang.org/x/tools v0.1.12 // indirect - google.golang.org/protobuf v1.28.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/modules/auth/controller/permission.go b/modules/auth/controller/permission.go index 558e645..b9c8413 100755 --- a/modules/auth/controller/permission.go +++ b/modules/auth/controller/permission.go @@ -1,23 +1,38 @@ package controller import ( - "atom/providers/config" + "atom/providers/jwt" + "atom/providers/rbac" "github.com/gin-gonic/gin" ) type PermissionController interface { - GetName(*gin.Context) (string, error) + Get(ctx *gin.Context) (string, error) } type permissionControllerImpl struct { - conf *config.Config + jwt *jwt.JWT + rbac rbac.IRbac } -func NewPermissionController(conf *config.Config) PermissionController { - return &permissionControllerImpl{conf: conf} +func NewPermissionController( + jwt *jwt.JWT, + rbac rbac.IRbac, +) PermissionController { + return &permissionControllerImpl{rbac: rbac, jwt: jwt} } -func (c *permissionControllerImpl) GetName(ctx *gin.Context) (string, error) { - return "Permission",nil +func (c *permissionControllerImpl) Get(ctx *gin.Context) (string, error) { + claims, err := c.jwt.GetClaims(ctx) + if err != nil { + return "", err + } + + perm, err := c.rbac.JsonPermissionsForUser(claims.Username) + if err != nil { + return "", err + } + + return perm, nil } diff --git a/modules/auth/routes/routes.go b/modules/auth/routes/routes.go index 2d03852..d3bd8d1 100755 --- a/modules/auth/routes/routes.go +++ b/modules/auth/routes/routes.go @@ -70,7 +70,7 @@ func (r *Route) Register() { permissionGroup := group.Group("permission") { - permissionGroup.GET("/permissions", gen.DataFunc(r.permission.GetName)) + permissionGroup.GET("/permissions", gen.DataFunc(r.permission.Get)) } } diff --git a/providers/jwt/jwt.go b/providers/jwt/jwt.go index fc0f57d..746995c 100644 --- a/providers/jwt/jwt.go +++ b/providers/jwt/jwt.go @@ -14,6 +14,11 @@ import ( "golang.org/x/sync/singleflight" ) +const ( + CtxKey = "claims" + HttpHeader = "Authorization" +) + func init() { if err := container.Container.Provide(NewJWT); err != nil { log.Fatal(err) @@ -118,7 +123,7 @@ func (j *JWT) ParseToken(tokenString string) (*CustomClaims, error) { } func (j *JWT) GetClaims(c *gin.Context) (*CustomClaims, error) { - token := c.Request.Header.Get("Authorization") + token := c.Request.Header.Get(HttpHeader) claims, err := j.ParseToken(token) if err != nil { log.Error("从Gin的Context中获取从jwt解析信息失败, 请检查请求头是否存在 Authorization 且 Claims 为规定结构") @@ -128,7 +133,7 @@ func (j *JWT) GetClaims(c *gin.Context) (*CustomClaims, error) { // GetUserID 从Gin的Context中获取从jwt解析出来的用户ID func (j *JWT) GetUserID(c *gin.Context) uint64 { - if claims, exists := c.Get("claims"); !exists { + if claims, exists := c.Get(CtxKey); !exists { if cl, err := j.GetClaims(c); err != nil { return 0 } else { @@ -142,7 +147,7 @@ func (j *JWT) GetUserID(c *gin.Context) uint64 { // GetUserUuid 从Gin的Context中获取从jwt解析出来的用户UUID func (j *JWT) GetUserUuid(c *gin.Context) string { - if claims, exists := c.Get("claims"); !exists { + if claims, exists := c.Get(CtxKey); !exists { if cl, err := j.GetClaims(c); err != nil { return uuid.UUID{}.String() } else { @@ -156,7 +161,7 @@ func (j *JWT) GetUserUuid(c *gin.Context) string { // GetUserAuthorityId 从Gin的Context中获取从jwt解析出来的用户角色id func (j *JWT) GetRoleId(c *gin.Context) uint64 { - if claims, exists := c.Get("claims"); !exists { + if claims, exists := c.Get(CtxKey); !exists { if cl, err := j.GetClaims(c); err != nil { return 0 } else { diff --git a/providers/rbac/casbin.go b/providers/rbac/casbin.go index 11f9249..d627ecc 100644 --- a/providers/rbac/casbin.go +++ b/providers/rbac/casbin.go @@ -71,6 +71,10 @@ func (cb *Casbin) Reload() error { return nil } +func (cb *Casbin) JsonPermissionsForUser(username string) (string, error) { + return casbin.CasbinJsGetPermissionForUser(cb.enforcer, username) +} + func (cb *Casbin) Update(roleID uint, infos []CasbinInfo) error { roleIdStr := strconv.Itoa(int(roleID)) cb.Clear(0, roleIdStr) diff --git a/providers/rbac/rbac.go b/providers/rbac/rbac.go index 1866fa6..d9e7cb0 100644 --- a/providers/rbac/rbac.go +++ b/providers/rbac/rbac.go @@ -2,5 +2,6 @@ package rbac type IRbac interface { Can(role, method, path string) bool + JsonPermissionsForUser(string) (string, error) Reload() error }