fix: provider

This commit is contained in:
Rogee
2024-12-19 17:57:23 +08:00
parent 4702873975
commit e007535972
4 changed files with 88 additions and 18 deletions

20
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,20 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Launch Package",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}",
"args": [
"gen",
"provider",
"/projects/mp-qvyun/backend",
]
}
]
}

View File

@@ -370,8 +370,15 @@ func renderFile(filename string, conf []Provider) error {
for k, v := range item.Imports { for k, v := range item.Imports {
// 如果是当前包的引用,直接使用包名 // 如果是当前包的引用,直接使用包名
if strings.HasSuffix(k, "/"+v) { if strings.HasSuffix(k, "/"+v) {
v = "" imports[k] = ""
continue
} }
if gomod.GetPackageModuleName(k) == v {
imports[k] = ""
continue
}
imports[k] = v imports[k] = v
} }
}) })

View File

@@ -1,6 +1,8 @@
package gomod package gomod
import ( import (
"bufio"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -40,7 +42,7 @@ func Parse(modPath string) error {
goMod = &GoMod{file: f, modules: make(map[string]ModuleInfo)} goMod = &GoMod{file: f, modules: make(map[string]ModuleInfo)}
for _, require := range f.Require { for _, require := range f.Require {
if !require.Indirect { if require.Indirect {
continue continue
} }
@@ -94,22 +96,33 @@ func getPackageName(pkg, version string) (string, error) {
return "", err return "", err
} }
packagePattern := regexp.MustCompile(`package\s+(\w+)`) packagePattern := regexp.MustCompile(`^package\s+(\w+)$`)
getFilePackageName := func(file string) (string, error) {
// 读取文件内容
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if matches := packagePattern.FindStringSubmatch(line); matches != nil {
return matches[1], nil
}
}
return "", errors.New("no match")
}
if len(files) > 0 { if len(files) > 0 {
for _, file := range files { for _, file := range files {
if strings.HasSuffix(file, "_test.go") { if strings.HasSuffix(file, "_test.go") {
continue continue
} }
// 读取文件内容
content, err := os.ReadFile(file) if name, err := getFilePackageName(file); err == nil {
if err != nil { return name, nil
return "", err
}
packageName := packagePattern.FindStringSubmatch(string(content))
if len(packageName) == 2 {
return packageName[1], nil
} }
} }
} }

View File

@@ -1,6 +1,9 @@
package gomod package gomod
import ( import (
"bufio"
"os"
"regexp"
"testing" "testing"
"github.com/rogeecn/fabfile" "github.com/rogeecn/fabfile"
@@ -21,12 +24,39 @@ func Test_ParseGoMod(t *testing.T) {
func Test_getPackageName(t *testing.T) { func Test_getPackageName(t *testing.T) {
Convey("Test getPackageName", t, func() { Convey("Test getPackageName", t, func() {
Convey("", func() {
Convey("github.com/redis/go-redis/v9@v9.7.0", func() { Convey("github.com/redis/go-redis/v9@v9.7.0", func() {
name, err := getPackageName("github.com/redis/go-redis/v9", "v9.7.0") name, err := getPackageName("github.com/redis/go-redis/v9", "v9.7.0")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(name, ShouldEqual, "redis") So(name, ShouldEqual, "redis")
}) })
Convey("github.com/pkg/errors@v0.9.1", func() {
name, err := getPackageName("github.com/pkg/errors", "v0.9.1")
So(err, ShouldBeNil)
So(name, ShouldEqual, "errors")
})
})
}
func Test_file(t *testing.T) {
Convey("Test file", t, func() {
Convey("Test file", func() {
packagePattern := regexp.MustCompile(`^package\s+(\w+)$`)
file := "/root/go/pkg/mod/github.com/redis/go-redis/v9@v9.7.0/acl_commands.go"
// read file line by line
f, err := os.Open(file)
So(err, ShouldBeNil)
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if matches := packagePattern.FindStringSubmatch(line); matches != nil {
t.Logf("Matched package name: %s", matches[1])
}
}
So(scanner.Err(), ShouldBeNil)
}) })
}) })
} }