From 80ab62534d669219cc7607b84511521130fe58b6 Mon Sep 17 00:00:00 2001 From: Rogee Date: Thu, 19 Dec 2024 15:00:58 +0800 Subject: [PATCH] feat: add gen enum command --- cmd/gen.go | 1 + cmd/gen_enum.go | 110 +++ go.mod | 17 +- go.sum | 27 +- pkg/utils/fs.go | 23 + pkg/utils/generate.go | 1 + pkg/utils/generator/embedded_1.16.go | 16 + pkg/utils/generator/enum.tmpl | 367 ++++++++++ pkg/utils/generator/enum_string.tmpl | 384 +++++++++++ pkg/utils/generator/example_1.18_test.go | 22 + pkg/utils/generator/example_test.go | 151 +++++ pkg/utils/generator/generator.go | 745 +++++++++++++++++++++ pkg/utils/generator/generator_1.18_test.go | 321 +++++++++ pkg/utils/generator/generator_test.go | 368 ++++++++++ pkg/utils/generator/template_funcs.go | 126 ++++ 15 files changed, 2674 insertions(+), 5 deletions(-) create mode 100644 cmd/gen_enum.go create mode 100644 pkg/utils/fs.go create mode 100644 pkg/utils/generate.go create mode 100644 pkg/utils/generator/embedded_1.16.go create mode 100644 pkg/utils/generator/enum.tmpl create mode 100644 pkg/utils/generator/enum_string.tmpl create mode 100644 pkg/utils/generator/example_1.18_test.go create mode 100644 pkg/utils/generator/example_test.go create mode 100644 pkg/utils/generator/generator.go create mode 100644 pkg/utils/generator/generator_1.18_test.go create mode 100644 pkg/utils/generator/generator_test.go create mode 100644 pkg/utils/generator/template_funcs.go diff --git a/cmd/gen.go b/cmd/gen.go index e3d6031..b4d7c31 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -11,6 +11,7 @@ func CommandGen(root *cobra.Command) { cmds := []func(*cobra.Command){ CommandGenModel, + CommandGenEnum, } for _, c := range cmds { diff --git a/cmd/gen_enum.go b/cmd/gen_enum.go new file mode 100644 index 0000000..837bce1 --- /dev/null +++ b/cmd/gen_enum.go @@ -0,0 +1,110 @@ +package cmd + +import ( + "fmt" + "io/fs" + "log" + "os" + "path/filepath" + "strings" + + "git.ipao.vip/rogeecn/atomctl/pkg/utils" + "git.ipao.vip/rogeecn/atomctl/pkg/utils/generator" + _ "github.com/lib/pq" + "github.com/spf13/cobra" +) + +func CommandGenEnum(root *cobra.Command) { + cmd := &cobra.Command{ + Use: "enum", + Short: "Generate enums", + RunE: commandGenEnumE, + } + + cmd.Flags().BoolP("flag", "f", true, "Flag enum values") + cmd.Flags().BoolP("marshal", "m", false, "Marshal enum values") + cmd.Flags().BoolP("sql", "s", true, "SQL driver enum values") + + root.AddCommand(cmd) +} + +func commandGenEnumE(cmd *cobra.Command, args []string) error { + var filenames []string + + pwd, err := os.Getwd() + if err != nil { + return err + } + err = filepath.Walk(pwd, func(path string, info fs.FileInfo, err error) error { + if utils.IsDir(path) { + return nil + } + + if !strings.HasSuffix(path, ".go") { + return nil + } + + content, err := os.ReadFile(path) + if err != nil { + return err + } + + if strings.Contains(string(content), "ENUM(") && strings.Contains(string(content), "swagger:enum") { + filenames = append(filenames, path) + } + return nil + }) + if err != nil { + return err + } + + if len(filenames) == 0 { + return fmt.Errorf("no enum files found in %s", pwd) + } + + g := generator.NewGenerator() + + if marshal, _ := cmd.Flags().GetBool("marshal"); marshal { + g.WithMarshal() + } + + if flag, _ := cmd.Flags().GetBool("flag"); flag { + g.WithFlag() + } + + if sql, _ := cmd.Flags().GetBool("sql"); sql { + g.WithSQLDriver() + g.WithSQLInt() + g.WithSQLNullInt() + g.WithSQLNullStr() + } + + g.WithNames() + g.WithValues() + + for _, fileName := range filenames { + log.Printf("Generating enums for %s", fileName) + + fileName, _ = filepath.Abs(fileName) + outFilePath := fmt.Sprintf("%s.gen.go", strings.TrimSuffix(fileName, filepath.Ext(fileName))) + + // Parse the file given in arguments + raw, err := g.GenerateFromFile(fileName) + if err != nil { + return fmt.Errorf("failed generating enums\nInputFile=%s\nError=%s", fileName, err) + } + + // Nothing was generated, ignore the output and don't create a file. + if len(raw) < 1 { + continue + } + + mode := int(0o644) + err = os.WriteFile(outFilePath, raw, os.FileMode(mode)) + if err != nil { + return fmt.Errorf("failed writing to file %s: %s", outFilePath, err) + } + } + + return nil +} diff --git a/go.mod b/go.mod index d8137f5..6bf32f4 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module git.ipao.vip/rogeecn/atomctl go 1.23.2 require ( + github.com/Masterminds/sprig/v3 v3.3.0 + github.com/bradleyjkemp/cupaloy/v2 v2.8.0 github.com/go-jet/jet/v2 v2.12.0 github.com/lib/pq v1.10.9 github.com/pkg/errors v0.9.1 @@ -11,13 +13,20 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 + golang.org/x/text v0.21.0 + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d ) require ( + dario.cat/mergo v1.0.1 // indirect + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver/v3 v3.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect @@ -28,24 +37,26 @@ require ( github.com/jackc/pgtype v1.14.4 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mfridman/interpolate v0.0.2 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/cast v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/testify v1.9.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect + golang.org/x/mod v0.17.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect - golang.org/x/text v0.21.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 32ce638..c48067d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,15 @@ +dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= +dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0= +github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs= +github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -31,6 +41,8 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= +github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= @@ -112,8 +124,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= @@ -146,6 +162,8 @@ github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -154,8 +172,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= -github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= -github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -172,6 +190,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -217,6 +236,8 @@ golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKG golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -279,6 +300,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/utils/fs.go b/pkg/utils/fs.go new file mode 100644 index 0000000..6ddcf19 --- /dev/null +++ b/pkg/utils/fs.go @@ -0,0 +1,23 @@ +package utils + +import ( + "os" +) + +func IsDir(dir string) bool { + f, err := os.Stat(dir) + if err != nil { + return false + } + + return f.IsDir() +} + +func IsFile(path string) bool { + f, err := os.Stat(path) + if err != nil { + return false + } + + return f.Mode().IsRegular() +} diff --git a/pkg/utils/generate.go b/pkg/utils/generate.go new file mode 100644 index 0000000..d4b585b --- /dev/null +++ b/pkg/utils/generate.go @@ -0,0 +1 @@ +package utils diff --git a/pkg/utils/generator/embedded_1.16.go b/pkg/utils/generator/embedded_1.16.go new file mode 100644 index 0000000..ea64de9 --- /dev/null +++ b/pkg/utils/generator/embedded_1.16.go @@ -0,0 +1,16 @@ +//go:build go1.16 +// +build go1.16 + +package generator + +import ( + "embed" + "text/template" +) + +//go:embed enum.tmpl enum_string.tmpl +var content embed.FS + +func (g *Generator) addEmbeddedTemplates() { + g.t = template.Must(g.t.ParseFS(content, "*.tmpl")) +} diff --git a/pkg/utils/generator/enum.tmpl b/pkg/utils/generator/enum.tmpl new file mode 100644 index 0000000..f24f92e --- /dev/null +++ b/pkg/utils/generator/enum.tmpl @@ -0,0 +1,367 @@ +{{- define "header"}} +// Code generated by go-enum DO NOT EDIT. +// Version: {{ .version }} +// Revision: {{ .revision }} +// Build Date: {{ .buildDate }} +// Built By: {{ .builtBy }} +{{ range $idx, $tag := .buildTags }} +//go:build {{$tag}} +// +build {{$tag}} +{{- end }} + +package {{.package}} + +import ( + "fmt" +) +{{end -}} + +{{- define "enum"}} +const ( +{{- $enumName := .enum.Name -}} +{{- $enumType := .enum.Type -}} +{{- $noComments := .nocomments -}} +{{- $vars := dict "lastoffset" "0" -}} +{{ range $rIndex, $value := .enum.Values }} + {{- $lastOffset := pluck "lastoffset" $vars | first }}{{ $offset := offset $rIndex $enumType $value }} + {{- if $noComments }}{{else}} + {{ if eq $value.Name "_"}}// Skipped value.{{else}}// {{$value.PrefixedName}} is a {{$enumName}} of type {{$value.Name}}.{{end}}{{end}} + {{- if $value.Comment}} + // {{$value.Comment}} + {{- end}} + {{$value.PrefixedName}} {{ if eq $rIndex 0 }}{{$enumName}} = iota{{ if ne "0" $offset }} + {{ $offset }}{{end}}{{else if ne $lastOffset $offset }}{{$enumName}} = iota + {{ $offset }}{{end}}{{$_ := set $vars "lastoffset" $offset}} +{{- end}} +) +{{if .names -}} +var ErrInvalid{{.enum.Name}} = fmt.Errorf("not a valid {{.enum.Name}}, try [%s]", strings.Join(_{{.enum.Name}}Names, ", ")) +{{- else -}} +var ErrInvalid{{.enum.Name}} = errors.New("not a valid {{.enum.Name}}") +{{- end}} + +{{ template "stringer" . }} + +var _{{.enum.Name}}Map = {{ mapify .enum }} + +// String implements the Stringer interface. +func (x {{.enum.Name}}) String() string { + if str, ok := _{{.enum.Name}}Map[x]; ok { + return str + } + return fmt.Sprintf("{{.enum.Name}}(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x {{.enum.Name}}) IsValid() bool { + _, ok := _{{.enum.Name}}Map[x] + return ok +} + +var _{{.enum.Name}}Value = {{ unmapify .enum .lowercase }} + +// Parse{{.enum.Name}} attempts to convert a string to a {{.enum.Name}}. +func Parse{{.enum.Name}}(name string) ({{.enum.Name}}, error) { + if x, ok := _{{.enum.Name}}Value[name]; ok { + return x, nil + }{{if .nocase }} + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _{{.enum.Name}}Value[strings.ToLower(name)]; ok { + return x, nil + }{{- end}} + return {{.enum.Name}}(0), fmt.Errorf("%s is %w", name, ErrInvalid{{.enum.Name}}) +} + +{{ if .mustparse }} +// MustParse{{.enum.Name}} converts a string to a {{.enum.Name}}, and panics if is not valid. +func MustParse{{.enum.Name}}(name string) {{.enum.Name}} { + val, err := Parse{{.enum.Name}}(name) + if err != nil { + panic(err) + } + return val +} +{{end}} + +{{ if .ptr }} +func (x {{.enum.Name}}) Ptr() *{{.enum.Name}} { + return &x +} +{{end}} + +{{ if .marshal }} +// MarshalText implements the text marshaller method. +func (x {{.enum.Name}}) MarshalText() ([]byte, error) { + return []byte(x.String()), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *{{.enum.Name}}) UnmarshalText(text []byte) error { + name := string(text) + tmp, err := Parse{{.enum.Name}}(name) + if err != nil { + return err + } + *x = tmp + return nil +} +{{end}} + +{{ if or .sql .sqlnullint .sqlnullstr}} +var err{{.enum.Name}}NilPtr = errors.New("value pointer is nil") // one per type for package clashes + +// Scan implements the Scanner interface. +func (x *{{.enum.Name}}) Scan(value interface{}) (err error) { + if value == nil { + *x = {{.enum.Name}}(0) + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case int64: + *x = {{.enum.Name}}(v) + case string: + *x, err = Parse{{.enum.Name}}(v){{if .sqlnullint }} + if err != nil { + // try parsing the integer value as a string + if val, verr := strconv.Atoi(v); verr == nil { + *x, err = {{.enum.Name}}(val), nil + } + }{{end}} + case []byte: + *x, err = Parse{{.enum.Name}}(string(v)){{if .sqlnullint }} + if err != nil { + // try parsing the integer value as a string + if val, verr := strconv.Atoi(string(v)); verr == nil { + *x, err = {{.enum.Name}}(val), nil + } + }{{end}} + case {{.enum.Name}}: + *x = v + case int: + *x = {{.enum.Name}}(v) + case *{{.enum.Name}}: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = *v + case uint: + *x = {{.enum.Name}}(v) + case uint64: + *x = {{.enum.Name}}(v) + case *int: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = {{.enum.Name}}(*v) + case *int64: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = {{.enum.Name}}(*v) + case float64: // json marshals everything as a float64 if it's a number + *x = {{.enum.Name}}(v) + case *float64: // json marshals everything as a float64 if it's a number + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = {{.enum.Name}}(*v) + case *uint: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = {{.enum.Name}}(*v) + case *uint64: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = {{.enum.Name}}(*v) + case *string: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = Parse{{.enum.Name}}(*v){{if .sqlnullint }} + if err != nil { + // try parsing the integer value as a string + if val, verr := strconv.Atoi(*v); verr == nil { + *x, err = {{.enum.Name}}(val), nil + } + }{{end}} + } + + return +} + +{{ if or .sql .sqlnullstr }} +// Value implements the driver Valuer interface. +func (x {{.enum.Name}}) Value() (driver.Value, error) { + return {{.enum.Type }}(x), nil +} +{{ else }} +// Value implements the driver Valuer interface. +func (x {{.enum.Name}}) Value() (driver.Value, error) { + return int64(x), nil +} +{{end}} + +{{end}} + + +{{ if .flag }} +// Set implements the Golang flag.Value interface func. +func (x *{{.enum.Name}}) Set(val string) error { + v, err := Parse{{.enum.Name}}(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *{{.enum.Name}}) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *{{.enum.Name}}) Type() string { + return "{{.enum.Name}}" +} +{{end}} + +{{ if or .sqlnullint .sqlnullstr }} +type Null{{.enum.Name}} struct{ + {{.enum.Name}} {{.enum.Name}} + Valid bool{{/* Add some info as to whether this value was set during unmarshalling or not */}}{{if .marshal }} + Set bool{{ end }} +} + +func NewNull{{.enum.Name}}(val interface{}) (x Null{{.enum.Name}}) { + x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + return +} + +// Scan implements the Scanner interface. +func (x *Null{{.enum.Name}}) Scan(value interface{}) (err error) { + {{- if .marshal }}x.Set = true{{ end }} + if value == nil { + x.{{.enum.Name}}, x.Valid = {{.enum.Name}}(0), false + return + } + + err = x.{{.enum.Name}}.Scan(value) + x.Valid = (err == nil) + return +} + +{{ if .sqlnullint }} +// Value implements the driver Valuer interface. +func (x Null{{.enum.Name}}) Value() (driver.Value, error) { + if !x.Valid{ + return nil, nil + } + // driver.Value accepts int64 for int values. + return int64(x.{{.enum.Name}}), nil +} +{{ else }} +// Value implements the driver Valuer interface. +func (x Null{{.enum.Name}}) Value() (driver.Value, error) { + if !x.Valid{ + return nil, nil + } + return x.{{.enum.Name}}.String(), nil +} +{{ end }} + +{{ if .marshal }} +// MarshalJSON correctly serializes a Null{{.enum.Name}} to JSON. +func (n Null{{.enum.Name}}) MarshalJSON() ([]byte, error) { + const nullStr = "null" + if n.Valid { + return json.Marshal(n.{{.enum.Name}}) + } + return []byte(nullStr), nil +} + +// UnmarshalJSON correctly deserializes a Null{{.enum.Name}} from JSON. +func (n *Null{{.enum.Name}}) UnmarshalJSON(b []byte) error { + n.Set = true + var x interface{} + err := json.Unmarshal(b, &x) + if err != nil{ + return err + } + err = n.Scan(x) + return err +} +{{ end }} + +{{ end }} + +{{ if and .sqlnullint .sqlnullstr }} +type Null{{.enum.Name}}Str struct { + Null{{.enum.Name}} +} + +func NewNull{{.enum.Name}}Str(val interface{}) (x Null{{.enum.Name}}Str) { + x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + return +} + +// Value implements the driver Valuer interface. +func (x Null{{.enum.Name}}Str) Value() (driver.Value, error) { + if !x.Valid{ + return nil, nil + } + return x.{{.enum.Name}}.String(), nil +} +{{ if .marshal }} +// MarshalJSON correctly serializes a Null{{.enum.Name}} to JSON. +func (n Null{{.enum.Name}}Str) MarshalJSON() ([]byte, error) { + const nullStr = "null" + if n.Valid { + return json.Marshal(n.{{.enum.Name}}) + } + return []byte(nullStr), nil +} + +// UnmarshalJSON correctly deserializes a Null{{.enum.Name}} from JSON. +func (n *Null{{.enum.Name}}Str) UnmarshalJSON(b []byte) error { + n.Set = true + var x interface{} + err := json.Unmarshal(b, &x) + if err != nil{ + return err + } + err = n.Scan(x) + return err +} +{{ end }} +{{ end }} + +{{end}} + + +{{- define "stringer"}} +const _{{.enum.Name}}Name = "{{ stringify .enum .forcelower }}" + +{{ if .names }}var _{{.enum.Name}}Names = {{namify .enum}} + +// {{.enum.Name}}Names returns a list of possible string values of {{.enum.Name}}. +func {{.enum.Name}}Names() []string { + tmp := make([]string, len(_{{.enum.Name}}Names)) + copy(tmp, _{{.enum.Name}}Names) + return tmp +} +{{ end -}} + +{{ if .values }} + +// {{.enum.Name}}Values returns a list of the values for {{.enum.Name}} +func {{.enum.Name}}Values() []{{.enum.Name}} { + return []{{.enum.Name}}{ {{ range $rIndex, $value := .enum.Values }}{{ if ne $value.Name "_"}} + {{$value.PrefixedName}},{{ end }} +{{- end}} + } +} +{{ end -}} + +{{end}} diff --git a/pkg/utils/generator/enum_string.tmpl b/pkg/utils/generator/enum_string.tmpl new file mode 100644 index 0000000..21ccf2d --- /dev/null +++ b/pkg/utils/generator/enum_string.tmpl @@ -0,0 +1,384 @@ +{{- define "enum_string"}} +const ( +{{- $enumName := .enum.Name -}} +{{- $enumType := .enum.Type -}} +{{- $noComments := .nocomments -}} +{{- $vars := dict "lastoffset" "0" -}} +{{ range $rIndex, $value := .enum.Values }} + {{- if $noComments }}{{else}} + {{ if eq $value.Name "_"}}// Skipped value.{{else}}// {{$value.PrefixedName}} is a {{$enumName}} of type {{$value.RawName}}.{{end}}{{end}} + {{- if $value.Comment}} + // {{$value.Comment}} + {{- end}} + {{$value.PrefixedName}} {{$enumName}} = "{{$value.ValueStr}}" +{{- end}} +) +{{if .names -}} +var ErrInvalid{{.enum.Name}} = fmt.Errorf("not a valid {{.enum.Name}}, try [%s]", strings.Join(_{{.enum.Name}}Names, ", ")) +{{- else -}} +var ErrInvalid{{.enum.Name}} = errors.New("not a valid {{.enum.Name}}") +{{- end}} + +{{ if .names }}var _{{.enum.Name}}Names = {{namify .enum}} + +// {{.enum.Name}}Names returns a list of possible string values of {{.enum.Name}}. +func {{.enum.Name}}Names() []string { + tmp := make([]string, len(_{{.enum.Name}}Names)) + copy(tmp, _{{.enum.Name}}Names) + return tmp +} +{{ end -}} + + +{{ if .values }} + +// {{.enum.Name}}Values returns a list of the values for {{.enum.Name}} +func {{.enum.Name}}Values() []{{.enum.Name}} { + return []{{.enum.Name}}{ {{ range $rIndex, $value := .enum.Values }}{{ if ne $value.Name "_"}} + {{$value.PrefixedName}},{{ end }} +{{- end}} + } +} +{{ end -}} + +// String implements the Stringer interface. +func (x {{.enum.Name}}) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x {{.enum.Name}}) IsValid() bool { + _, err := Parse{{.enum.Name}}(string(x)) + return err == nil +} + +var _{{.enum.Name}}Value = {{ unmapify .enum .lowercase }} + +// Parse{{.enum.Name}} attempts to convert a string to a {{.enum.Name}}. +func Parse{{.enum.Name}}(name string) ({{.enum.Name}}, error) { + if x, ok := _{{.enum.Name}}Value[name]; ok { + return x, nil + }{{if .nocase }} + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _{{.enum.Name}}Value[strings.ToLower(name)]; ok { + return x, nil + }{{- end}} + return {{.enum.Name}}(""), fmt.Errorf("%s is %w", name, ErrInvalid{{.enum.Name}}) +} + +{{ if .mustparse }} +// MustParse{{.enum.Name}} converts a string to a {{.enum.Name}}, and panics if is not valid. +func MustParse{{.enum.Name}}(name string) {{.enum.Name}} { + val, err := Parse{{.enum.Name}}(name) + if err != nil { + panic(err) + } + return val +} +{{end}} + +{{ if .ptr }} +func (x {{.enum.Name}}) Ptr() *{{.enum.Name}} { + return &x +} +{{end}} + +{{ if .marshal }} +// MarshalText implements the text marshaller method. +func (x {{.enum.Name}}) MarshalText() ([]byte, error) { + return []byte(string(x)), nil +} + +// UnmarshalText implements the text unmarshaller method. +func (x *{{.enum.Name}}) UnmarshalText(text []byte) error { + tmp, err := Parse{{.enum.Name}}(string(text)) + if err != nil { + return err + } + *x = tmp + return nil +} +{{end}} + +{{ if .anySQLEnabled }} +var err{{.enum.Name}}NilPtr = errors.New("value pointer is nil") // one per type for package clashes +{{ end }} + +{{/* SQL stored as a string value */}} +{{ if .sql }} +{{ if eq .enum.Type "string" }} +// Scan implements the Scanner interface. +func (x *{{.enum.Name}}) Scan(value interface{}) (err error) { + if value == nil { + *x = {{.enum.Name}}("") + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case string: + *x, err = Parse{{.enum.Name}}(v) + case []byte: + *x, err = Parse{{.enum.Name}}(string(v)) + case {{.enum.Name}}: + *x = v + case *{{.enum.Name}}: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = *v + case *string: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = Parse{{.enum.Name}}(*v) + default: + return errors.New("invalid type for {{.enum.Name}}") + } + + return +} + +// Value implements the driver Valuer interface. +func (x {{.enum.Name}}) Value() (driver.Value, error) { + return x.String(), nil +} + +{{ else if or .sqlint .sqlnullint }} +{{/* SQL stored as an integer value */}} +var sqlInt{{.enum.Name}}Map = map[int64]{{.enum.Name}}{ {{ range $rIndex, $value := .enum.Values }}{{ if ne $value.Name "_"}} +{{ $value.ValueInt }}: {{ $value.PrefixedName }},{{end}} +{{- end}} +} + +var sqlInt{{.enum.Name}}Value = map[{{.enum.Name}}]int64{ {{ range $rIndex, $value := .enum.Values }}{{ if ne $value.Name "_"}} + {{ $value.PrefixedName }}: {{ $value.ValueInt }},{{end}} +{{- end}} +} + +func lookupSqlInt{{.enum.Name}}(val int64) ({{.enum.Name}}, error){ + x, ok := sqlInt{{.enum.Name}}Map[val] + if !ok{ + return x, fmt.Errorf("%v is not %w", val, ErrInvalid{{.enum.Name}}) + } + return x, nil +} + +// Scan implements the Scanner interface. +func (x *{{.enum.Name}}) Scan(value interface{}) (err error) { + if value == nil { + *x = {{.enum.Name}}("") + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case int64: + *x, err = lookupSqlInt{{.enum.Name}}(v) + case string: + *x, err = Parse{{.enum.Name}}(v) + case []byte: + if val, verr := strconv.ParseInt(string(v), 10, 64); verr == nil { + *x, err = lookupSqlInt{{.enum.Name}}(val) + } else { + // try parsing the value as a string + *x, err = Parse{{.enum.Name}}(string(v)) + } + case {{.enum.Name}}: + *x = v + case int: + *x, err = lookupSqlInt{{.enum.Name}}(int64(v)) + case *{{.enum.Name}}: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x = *v + case uint: + *x, err = lookupSqlInt{{.enum.Name}}(int64(v)) + case uint64: + *x, err = lookupSqlInt{{.enum.Name}}(int64(v)) + case *int: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = lookupSqlInt{{.enum.Name}}(int64(*v)) + case *int64: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = lookupSqlInt{{.enum.Name}}(int64(*v)) + case float64: // json marshals everything as a float64 if it's a number + *x, err = lookupSqlInt{{.enum.Name}}(int64(v)) + case *float64: // json marshals everything as a float64 if it's a number + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = lookupSqlInt{{.enum.Name}}(int64(*v)) + case *uint: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = lookupSqlInt{{.enum.Name}}(int64(*v)) + case *uint64: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = lookupSqlInt{{.enum.Name}}(int64(*v)) + case *string: + if v == nil{ + return err{{.enum.Name}}NilPtr + } + *x, err = Parse{{.enum.Name}}(*v) + default: + return errors.New("invalid type for {{.enum.Name}}") + } + + return +} + +// Value implements the driver Valuer interface. +func (x {{.enum.Name}}) Value() (driver.Value, error) { + val, ok := sqlInt{{.enum.Name}}Value[x] + if !ok{ + return nil, ErrInvalid{{.enum.Name}} + } + return int64(val), nil +} + +{{end}} +{{end}} + + +{{ if .flag }} +// Set implements the Golang flag.Value interface func. +func (x *{{.enum.Name}}) Set(val string) error { + v, err := Parse{{.enum.Name}}(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *{{.enum.Name}}) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *{{.enum.Name}}) Type() string { + return "{{.enum.Name}}" +} +{{end}} + +{{ if or .sqlnullint .sqlnullstr }} +type Null{{.enum.Name}} struct{ + {{.enum.Name}} {{.enum.Name}} + Valid bool{{/* Add some info as to whether this value was set during unmarshalling or not */}}{{if .marshal }} + Set bool{{ end }} +} + +func NewNull{{.enum.Name}}(val interface{}) (x Null{{.enum.Name}}) { + err := x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + _ = err // make any errcheck linters happy + return +} + +// Scan implements the Scanner interface. +func (x *Null{{.enum.Name}}) Scan(value interface{}) (err error) { + if value == nil { + x.{{.enum.Name}}, x.Valid = {{.enum.Name}}(""), false + return + } + + err = x.{{.enum.Name}}.Scan(value) + x.Valid = (err == nil) + return +} + +{{ if .sqlnullint }} +// Value implements the driver Valuer interface. +func (x Null{{.enum.Name}}) Value() (driver.Value, error) { + if !x.Valid{ + return nil, nil + } + // driver.Value accepts int64 for int values. + return string(x.{{.enum.Name}}), nil +} +{{ else }} +// Value implements the driver Valuer interface. +func (x Null{{.enum.Name}}) Value() (driver.Value, error) { + if !x.Valid{ + return nil, nil + } + return x.{{.enum.Name}}.String(), nil +} +{{ end }} + +{{ if .marshal }} +// MarshalJSON correctly serializes a Null{{.enum.Name}} to JSON. +func (n Null{{.enum.Name}}) MarshalJSON() ([]byte, error) { + const nullStr = "null" + if n.Valid { + return json.Marshal(n.{{.enum.Name}}) + } + return []byte(nullStr), nil +} + +// UnmarshalJSON correctly deserializes a Null{{.enum.Name}} from JSON. +func (n *Null{{.enum.Name}}) UnmarshalJSON(b []byte) error { + n.Set = true + var x interface{} + err := json.Unmarshal(b, &x) + if err != nil{ + return err + } + err = n.Scan(x) + return err +} +{{ end }} + +{{ end }} + +{{ if and .sqlnullint .sqlnullstr }} +type Null{{.enum.Name}}Str struct { + Null{{.enum.Name}} +} + +func NewNull{{.enum.Name}}Str(val interface{}) (x Null{{.enum.Name}}Str) { + x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + return +} + +// Value implements the driver Valuer interface. +func (x Null{{.enum.Name}}Str) Value() (driver.Value, error) { + if !x.Valid{ + return nil, nil + } + return x.{{.enum.Name}}.String(), nil +} +{{ if .marshal }} +// MarshalJSON correctly serializes a Null{{.enum.Name}} to JSON. +func (n Null{{.enum.Name}}Str) MarshalJSON() ([]byte, error) { + const nullStr = "null" + if n.Valid { + return json.Marshal(n.{{.enum.Name}}) + } + return []byte(nullStr), nil +} + +// UnmarshalJSON correctly deserializes a Null{{.enum.Name}} from JSON. +func (n *Null{{.enum.Name}}Str) UnmarshalJSON(b []byte) error { + n.Set = true + var x interface{} + err := json.Unmarshal(b, &x) + if err != nil{ + return err + } + err = n.Scan(x) + return err +} +{{ end }} +{{ end }} + +{{end}} diff --git a/pkg/utils/generator/example_1.18_test.go b/pkg/utils/generator/example_1.18_test.go new file mode 100644 index 0000000..dbf6658 --- /dev/null +++ b/pkg/utils/generator/example_1.18_test.go @@ -0,0 +1,22 @@ +//go:build go1.18 +// +build go1.18 + +package generator + +// SumIntsOrFloats sums the values of map m. It supports both int64 and float64 +// as types for map values. +func SumIntsOrFloats[K comparable, V int64 | float64](m map[K]V) V { + var s V + for _, v := range m { + s += v + } + return s +} + +// ChangeType is a type of change detected. +/* ENUM( + Create + Update + Delete +) */ +type ChangeType int diff --git a/pkg/utils/generator/example_test.go b/pkg/utils/generator/example_test.go new file mode 100644 index 0000000..f59ced8 --- /dev/null +++ b/pkg/utils/generator/example_test.go @@ -0,0 +1,151 @@ +package generator + +// X is doc'ed +type X struct{} + +// Color is an enumeration of colors that are allowed. +// ENUM( +// Black, White, Red +// Green +// Blue=33 +// grey= +// yellow +// ). +type Color int + +// Animal x ENUM( +// Cat, +// Dog, +// Fish +// ) Some other line of info +type Animal int32 + +// Model x ENUM(Toyota,_,Chevy,_,Ford). +type Model int32 + +/* + ENUM( + Coke + Pepsi + MtnDew + +). +*/ +type Soda int64 + +/* + ENUM( + test_lower + Test_capital + anotherLowerCaseStart + +) +*/ +type Cases int64 + +/* + ENUM( + test-Hyphen + -hyphenStart + _underscoreFirst + 0numberFirst + 123456789a + 123123-asdf + ending-hyphen- + +) +*/ +type Sanitizing int64 + +/* + ENUM( + startWithNum=23 + nextNum + +) +*/ +type StartNotZero int64 + +// ENUM( +// Black, White, Red +// Green +// Blue=33 // Blue starts with 33. +// grey= +// yellow +// ) +type ColorWithComment int + +/* +ENUM( +Black, White, Red +Green +Blue=33 // Blue starts with 33 +grey= +yellow +) +*/ +type ColorWithComment2 int + +/* ENUM( +Black, White, Red +Green = 33 // Green starts with 33 +*/ +// Blue +// grey= +// yellow +// blue-green // blue-green comment +// red-orange +// red-orange-blue +// ) +type ColorWithComment3 int + +/* ENUM( + _, // Placeholder +Black, White, Red +Green = 33 // Green starts with 33 +*/ +// Blue +// grey= +// yellow // Where did all the (somewhat) bad fish go? (something else that goes in parentheses at the end of the line) +// blue-green // blue-green comment +// red-orange // has a , in it!?! +// ) +type ColorWithComment4 int + +/* + ENUM( + +Unknown= 0 +E2P15 = 32768 +E2P16 = 65536 +E2P17 = 131072 +E2P18 = 262144 +E2P19 = 524288 +E2P20 = 1048576 +E2P21 = 2097152 +E2P22 = 33554432 +E2P23 = 67108864 +E2P28 = 536870912 +E2P30 = 1073741824 +E2P31 = 2147483648 +E2P32 = 4294967296 +E2P33 = 8454967296 +) +*/ +type Enum64bit uint64 + +// NonASCII +// ENUM( +// Продам = 1114 +// 車庫 = 300 +// էժան = 1 +// ) +type NonASCII int + +// StringEnum. +// ENUM( +// random = 1114 +// values = 300 +// here = 1 +// ) +type StringEnum string diff --git a/pkg/utils/generator/generator.go b/pkg/utils/generator/generator.go new file mode 100644 index 0000000..8466875 --- /dev/null +++ b/pkg/utils/generator/generator.go @@ -0,0 +1,745 @@ +package generator + +import ( + "bytes" + "errors" + "fmt" + "go/ast" + "go/parser" + "go/token" + "net/url" + "sort" + "strconv" + "strings" + "text/template" + "unicode" + + "github.com/Masterminds/sprig/v3" + "golang.org/x/text/cases" + "golang.org/x/text/language" + "golang.org/x/tools/imports" +) + +const ( + skipHolder = `_` + parseCommentPrefix = `//` +) + +var replacementNames = map[string]string{} + +// Generator is responsible for generating validation files for the given in a go source file. +type Generator struct { + Version string + Revision string + BuildDate string + BuiltBy string + t *template.Template + knownTemplates map[string]*template.Template + userTemplateNames []string + fileSet *token.FileSet + noPrefix bool + lowercaseLookup bool + caseInsensitive bool + marshal bool + sql bool + sqlint bool + flag bool + names bool + values bool + leaveSnakeCase bool + prefix string + sqlNullInt bool + sqlNullStr bool + ptr bool + mustParse bool + forceLower bool + noComments bool + buildTags []string +} + +// Enum holds data for a discovered enum in the parsed source +type Enum struct { + Name string + Prefix string + Type string + Values []EnumValue +} + +// EnumValue holds the individual data for each enum value within the found enum. +type EnumValue struct { + RawName string + Name string + PrefixedName string + ValueStr string + ValueInt interface{} + Comment string +} + +// NewGenerator is a constructor method for creating a new Generator with default +// templates loaded. +func NewGenerator() *Generator { + g := &Generator{ + Version: "-", + Revision: "-", + BuildDate: "-", + BuiltBy: "-", + knownTemplates: make(map[string]*template.Template), + userTemplateNames: make([]string, 0), + t: template.New("generator"), + fileSet: token.NewFileSet(), + noPrefix: false, + } + + funcs := sprig.TxtFuncMap() + + funcs["stringify"] = Stringify + funcs["mapify"] = Mapify + funcs["unmapify"] = Unmapify + funcs["namify"] = Namify + funcs["offset"] = Offset + + g.t.Funcs(funcs) + + g.addEmbeddedTemplates() + + g.updateTemplates() + + return g +} + +// WithNoPrefix is used to change the enum const values generated to not have the enum on them. +func (g *Generator) WithNoPrefix() *Generator { + g.noPrefix = true + return g +} + +// WithLowercaseVariant is used to change the enum const values generated to not have the enum on them. +func (g *Generator) WithLowercaseVariant() *Generator { + g.lowercaseLookup = true + return g +} + +// WithLowercaseVariant is used to change the enum const values generated to not have the enum on them. +func (g *Generator) WithCaseInsensitiveParse() *Generator { + g.lowercaseLookup = true + g.caseInsensitive = true + return g +} + +// WithMarshal is used to add marshalling to the enum +func (g *Generator) WithMarshal() *Generator { + g.marshal = true + return g +} + +// WithSQLDriver is used to add marshalling to the enum +func (g *Generator) WithSQLDriver() *Generator { + g.sql = true + return g +} + +// WithSQLInt is used to signal a string to be stored as an int. +func (g *Generator) WithSQLInt() *Generator { + g.sqlint = true + return g +} + +// WithFlag is used to add flag methods to the enum +func (g *Generator) WithFlag() *Generator { + g.flag = true + return g +} + +// WithNames is used to add Names methods to the enum +func (g *Generator) WithNames() *Generator { + g.names = true + return g +} + +// WithValues is used to add Values methods to the enum +func (g *Generator) WithValues() *Generator { + g.values = true + return g +} + +// WithoutSnakeToCamel is used to add flag methods to the enum +func (g *Generator) WithoutSnakeToCamel() *Generator { + g.leaveSnakeCase = true + return g +} + +// WithPrefix is used to add a custom prefix to the enum constants +func (g *Generator) WithPrefix(prefix string) *Generator { + g.prefix = prefix + return g +} + +// WithPtr adds a way to get a pointer value straight from the const value. +func (g *Generator) WithPtr() *Generator { + g.ptr = true + return g +} + +// WithSQLNullInt is used to add a null int option for SQL interactions. +func (g *Generator) WithSQLNullInt() *Generator { + g.sqlNullInt = true + return g +} + +// WithSQLNullStr is used to add a null string option for SQL interactions. +func (g *Generator) WithSQLNullStr() *Generator { + g.sqlNullStr = true + return g +} + +// WithMustParse is used to add a method `MustParse` that will panic on failure. +func (g *Generator) WithMustParse() *Generator { + g.mustParse = true + return g +} + +// WithForceLower is used to force enums names to lower case while keeping variable names the same. +func (g *Generator) WithForceLower() *Generator { + g.forceLower = true + return g +} + +// WithNoComments is used to remove auto generated comments from the enum. +func (g *Generator) WithNoComments() *Generator { + g.noComments = true + return g +} + +// WithBuildTags will add build tags to the generated file. +func (g *Generator) WithBuildTags(tags ...string) *Generator { + g.buildTags = append(g.buildTags, tags...) + return g +} + +func (g *Generator) anySQLEnabled() bool { + return g.sql || g.sqlNullStr || g.sqlint || g.sqlNullInt +} + +// ParseAliases is used to add aliases to replace during name sanitization. +func ParseAliases(aliases []string) error { + aliasMap := map[string]string{} + + for _, str := range aliases { + kvps := strings.Split(str, ",") + for _, kvp := range kvps { + parts := strings.Split(kvp, ":") + if len(parts) != 2 { + return fmt.Errorf("invalid formatted alias entry %q, must be in the format \"key:value\"", kvp) + } + aliasMap[parts[0]] = parts[1] + } + } + + for k, v := range aliasMap { + replacementNames[k] = v + } + + return nil +} + +// WithTemplates is used to provide the filenames of additional templates. +func (g *Generator) WithTemplates(filenames ...string) *Generator { + for _, ut := range template.Must(g.t.ParseFiles(filenames...)).Templates() { + if _, ok := g.knownTemplates[ut.Name()]; !ok { + g.userTemplateNames = append(g.userTemplateNames, ut.Name()) + } + } + g.updateTemplates() + sort.Strings(g.userTemplateNames) + return g +} + +// GenerateFromFile is responsible for orchestrating the Code generation. It results in a byte array +// that can be written to any file desired. It has already had goimports run on the code before being returned. +func (g *Generator) GenerateFromFile(inputFile string) ([]byte, error) { + f, err := g.parseFile(inputFile) + if err != nil { + return nil, fmt.Errorf("generate: error parsing input file '%s': %s", inputFile, err) + } + return g.Generate(f) +} + +// Generate does the heavy lifting for the code generation starting from the parsed AST file. +func (g *Generator) Generate(f *ast.File) ([]byte, error) { + enums := g.inspect(f) + if len(enums) <= 0 { + return nil, nil + } + + pkg := f.Name.Name + + vBuff := bytes.NewBuffer([]byte{}) + err := g.t.ExecuteTemplate(vBuff, "header", map[string]interface{}{ + "package": pkg, + "version": g.Version, + "revision": g.Revision, + "buildDate": g.BuildDate, + "builtBy": g.BuiltBy, + "buildTags": g.buildTags, + }) + if err != nil { + return nil, fmt.Errorf("failed writing header: %w", err) + } + + // Make the output more consistent by iterating over sorted keys of map + var keys []string + for key := range enums { + keys = append(keys, key) + } + sort.Strings(keys) + + var created int + for _, name := range keys { + ts := enums[name] + + // Parse the enum doc statement + enum, pErr := g.parseEnum(ts) + if pErr != nil { + continue + } + + created++ + data := map[string]interface{}{ + "enum": enum, + "name": name, + "lowercase": g.lowercaseLookup, + "nocase": g.caseInsensitive, + "nocomments": g.noComments, + "marshal": g.marshal, + "sql": g.sql, + "sqlint": g.sqlint, + "flag": g.flag, + "names": g.names, + "ptr": g.ptr, + "values": g.values, + "anySQLEnabled": g.anySQLEnabled(), + "sqlnullint": g.sqlNullInt, + "sqlnullstr": g.sqlNullStr, + "mustparse": g.mustParse, + "forcelower": g.forceLower, + } + + templateName := "enum" + if enum.Type == "string" { + templateName = "enum_string" + } + + err = g.t.ExecuteTemplate(vBuff, templateName, data) + if err != nil { + return vBuff.Bytes(), fmt.Errorf("failed writing enum data for enum: %q: %w", name, err) + } + + for _, userTemplateName := range g.userTemplateNames { + err = g.t.ExecuteTemplate(vBuff, userTemplateName, data) + if err != nil { + return vBuff.Bytes(), fmt.Errorf("failed writing enum data for enum: %q, template: %v: %w", name, userTemplateName, err) + } + } + } + + if created < 1 { + // Don't save anything if we didn't actually generate any successful enums. + return nil, nil + } + + formatted, err := imports.Process(pkg, vBuff.Bytes(), nil) + if err != nil { + err = fmt.Errorf("generate: error formatting code %s\n\n%s", err, vBuff.String()) + } + return formatted, err +} + +// updateTemplates will update the lookup map for validation checks that are +// allowed within the template engine. +func (g *Generator) updateTemplates() { + for _, template := range g.t.Templates() { + g.knownTemplates[template.Name()] = template + } +} + +// parseFile simply calls the go/parser ParseFile function with an empty token.FileSet +func (g *Generator) parseFile(fileName string) (*ast.File, error) { + // Parse the file given in arguments + return parser.ParseFile(g.fileSet, fileName, nil, parser.ParseComments) +} + +// parseEnum looks for the ENUM(x,y,z) formatted documentation from the type definition +func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { + if ts.Doc == nil { + return nil, errors.New("no doc on enum") + } + + enum := &Enum{} + + enum.Name = ts.Name.Name + enum.Type = fmt.Sprintf("%s", ts.Type) + if !g.noPrefix { + enum.Prefix = ts.Name.Name + } + if g.prefix != "" { + enum.Prefix = g.prefix + enum.Prefix + } + + enumDecl := getEnumDeclFromComments(ts.Doc.List) + if enumDecl == "" { + return nil, errors.New("failed parsing enum") + } + + values := strings.Split(strings.TrimSuffix(strings.TrimPrefix(enumDecl, `ENUM(`), `)`), `,`) + var ( + data interface{} + unsigned bool + ) + if strings.HasPrefix(enum.Type, "u") { + data = uint64(0) + unsigned = true + } else { + data = int64(0) + } + for _, value := range values { + var comment string + + // Trim and store comments + if strings.Contains(value, parseCommentPrefix) { + commentStartIndex := strings.Index(value, parseCommentPrefix) + comment = value[commentStartIndex+len(parseCommentPrefix):] + comment = strings.TrimSpace(unescapeComment(comment)) + // value without comment + value = value[:commentStartIndex] + } + + // Make sure to leave out any empty parts + if value != "" { + rawName := value + valueStr := value + + if strings.Contains(value, `=`) { + // Get the value specified and set the data to that value. + equalIndex := strings.Index(value, `=`) + dataVal := strings.TrimSpace(value[equalIndex+1:]) + if dataVal != "" { + valueStr = dataVal + rawName = value[:equalIndex] + if enum.Type == "string" { + if parsed, err := strconv.ParseInt(dataVal, 10, 64); err == nil { + data = parsed + valueStr = rawName + } + if isQuoted(dataVal) { + valueStr = trimQuotes(dataVal) + } + } else if unsigned { + newData, err := strconv.ParseUint(dataVal, 10, 64) + if err != nil { + err = fmt.Errorf("failed parsing the data part of enum value '%s': %w", value, err) + fmt.Println(err) + return nil, err + } + data = newData + } else { + newData, err := strconv.ParseInt(dataVal, 10, 64) + if err != nil { + err = fmt.Errorf("failed parsing the data part of enum value '%s': %w", value, err) + fmt.Println(err) + return nil, err + } + data = newData + } + } else { + rawName = strings.TrimSuffix(rawName, `=`) + fmt.Printf("Ignoring enum with '=' but no value after: %s\n", rawName) + } + } + rawName = strings.TrimSpace(rawName) + valueStr = strings.TrimSpace(valueStr) + name := cases.Title(language.Und, cases.NoLower).String(rawName) + prefixedName := name + if name != skipHolder { + prefixedName = enum.Prefix + name + prefixedName = sanitizeValue(prefixedName) + if !g.leaveSnakeCase { + prefixedName = snakeToCamelCase(prefixedName) + } + } + + ev := EnumValue{Name: name, RawName: rawName, PrefixedName: prefixedName, ValueStr: valueStr, ValueInt: data, Comment: comment} + enum.Values = append(enum.Values, ev) + data = increment(data) + } + } + + // fmt.Printf("###\nENUM: %+v\n###\n", enum) + + return enum, nil +} + +func isQuoted(s string) bool { + s = strings.TrimSpace(s) + return (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) || (strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) +} + +func trimQuotes(s string) string { + s = strings.TrimSpace(s) + for _, quote := range []string{`"`, `'`} { + s = strings.TrimPrefix(s, quote) + s = strings.TrimSuffix(s, quote) + } + return s +} + +func increment(d interface{}) interface{} { + switch v := d.(type) { + case uint64: + return v + 1 + case int64: + return v + 1 + } + return d +} + +func unescapeComment(comment string) string { + val, err := url.QueryUnescape(comment) + if err != nil { + return comment + } + return val +} + +// sanitizeValue will ensure the value name generated adheres to golang's +// identifier syntax as described here: https://golang.org/ref/spec#Identifiers +// identifier = letter { letter | unicode_digit } +// where letter can be unicode_letter or '_' +func sanitizeValue(value string) string { + // Keep skip value holders + if value == skipHolder { + return skipHolder + } + + replacedValue := value + for k, v := range replacementNames { + replacedValue = strings.ReplaceAll(replacedValue, k, v) + } + + nameBuilder := strings.Builder{} + nameBuilder.Grow(len(replacedValue)) + + for i, r := range replacedValue { + // If the start character is not a unicode letter (this check includes the case of '_') + // then we need to add an exported prefix, so tack on a 'X' at the beginning + if i == 0 && !unicode.IsLetter(r) { + nameBuilder.WriteRune('X') + } + + if unicode.IsLetter(r) || unicode.IsNumber(r) || r == '_' { + nameBuilder.WriteRune(r) + } + } + + return nameBuilder.String() +} + +func snakeToCamelCase(value string) string { + parts := strings.Split(value, "_") + title := cases.Title(language.Und, cases.NoLower) + + for i, part := range parts { + parts[i] = title.String(part) + } + value = strings.Join(parts, "") + + return value +} + +// getEnumDeclFromComments parses the array of comment strings and creates a single Enum Declaration statement +// that is easier to deal with for the remainder of parsing. It turns multi line declarations and makes a single +// string declaration. +func getEnumDeclFromComments(comments []*ast.Comment) string { + const EnumPrefix = "ENUM(" + var ( + parts []string + lines []string + store bool + enumParamLevel int + filteredLines []string + ) + + for _, comment := range comments { + lines = append(lines, breakCommentIntoLines(comment)...) + } + + filteredLines = make([]string, 0, len(lines)) + for idx := range lines { + line := lines[idx] + // If we're not in the enum, and this line doesn't contain the + // start string, then move along + if !store && !strings.Contains(line, EnumPrefix) { + continue + } + if !store { + // We must have had the start value in here + store = true + enumParamLevel = 1 + start := strings.Index(line, EnumPrefix) + line = line[start+len(EnumPrefix):] + } + lineParamLevel := strings.Count(line, "(") + lineParamLevel = lineParamLevel - strings.Count(line, ")") + + if enumParamLevel+lineParamLevel < 1 { + // We've ended, either with more than we need, or with just enough. Now we need to find the end. + for lineIdx, ch := range line { + if ch == '(' { + enumParamLevel = enumParamLevel + 1 + continue + } + if ch == ')' { + enumParamLevel = enumParamLevel - 1 + if enumParamLevel == 0 { + // We've found the end of the ENUM() definition, + // Cut off the suffix and break out of the loop + line = line[:lineIdx] + store = false + break + } + } + } + } + + filteredLines = append(filteredLines, line) + } + + if enumParamLevel > 0 { + fmt.Println("ENUM Parse error, there is a dangling '(' in your comment.") + return "" + } + + // Go over all the lines in this comment block + for _, line := range filteredLines { + _, trimmed := parseLinePart(line) + if trimmed != "" { + parts = append(parts, trimmed) + } + } + + joined := fmt.Sprintf("ENUM(%s)", strings.Join(parts, `,`)) + return joined +} + +func parseLinePart(line string) (paramLevel int, trimmed string) { + trimmed = line + comment := "" + if idx := strings.Index(line, parseCommentPrefix); idx >= 0 { + trimmed = line[:idx] + comment = "//" + url.QueryEscape(strings.TrimSpace(line[idx+2:])) + } + trimmed = trimAllTheThings(trimmed) + trimmed += comment + opens := strings.Count(line, `(`) + closes := strings.Count(line, `)`) + if opens > 0 { + paramLevel += opens + } + if closes > 0 { + paramLevel -= closes + } + return +} + +// breakCommentIntoLines takes the comment and since single line comments are already broken into lines +// we break multiline comments into separate lines for processing. +func breakCommentIntoLines(comment *ast.Comment) []string { + lines := []string{} + text := comment.Text + if strings.HasPrefix(text, `/*`) { + // deal with multi line comment + multiline := strings.TrimSuffix(strings.TrimPrefix(text, `/*`), `*/`) + lines = append(lines, strings.Split(multiline, "\n")...) + } else { + lines = append(lines, strings.TrimPrefix(text, `//`)) + } + return lines +} + +// trimAllTheThings takes off all the cruft of a line that we don't need. +// These lines should be pre-filtered so that we don't have to worry about +// the `ENUM(` prefix and the `)` suffix... those should already be removed. +func trimAllTheThings(thing string) string { + preTrimmed := strings.TrimSuffix(strings.TrimSpace(thing), `,`) + return strings.TrimSpace(preTrimmed) +} + +// inspect will walk the ast and fill a map of names and their struct information +// for use in the generation template. +func (g *Generator) inspect(f ast.Node) map[string]*ast.TypeSpec { + enums := make(map[string]*ast.TypeSpec) + // Inspect the AST and find all structs. + ast.Inspect(f, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.GenDecl: + copyGenDeclCommentsToSpecs(x) + case *ast.Ident: + if x.Obj != nil { + // fmt.Printf("Node: %#v\n", x.Obj) + // Make sure it's a Type Identifier + if x.Obj.Kind == ast.Typ { + // Make sure it's a spec (Type Identifiers can be throughout the code) + if ts, ok := x.Obj.Decl.(*ast.TypeSpec); ok { + // fmt.Printf("Type: %+v\n", ts) + isEnum := isTypeSpecEnum(ts) + // Only store documented enums + if isEnum { + // fmt.Printf("EnumType: %T\n", ts.Type) + enums[x.Name] = ts + } + } + } + } + } + // Return true to continue through the tree + return true + }) + + return enums +} + +// copyDocsToSpecs will take the GenDecl level documents and copy them +// to the children Type and Value specs. I think this is actually working +// around a bug in the AST, but it works for now. +func copyGenDeclCommentsToSpecs(x *ast.GenDecl) { + // Copy the doc spec to the type or value spec + // cause they missed this... whoops + if x.Doc != nil { + for _, spec := range x.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + if s.Doc == nil { + s.Doc = x.Doc + } + case *ast.ValueSpec: + if s.Doc == nil { + s.Doc = x.Doc + } + } + } + } +} + +// isTypeSpecEnum checks the comments on the type spec to determine if there is an enum +// declaration for the type. +func isTypeSpecEnum(ts *ast.TypeSpec) bool { + isEnum := false + if ts.Doc != nil { + for _, comment := range ts.Doc.List { + if strings.Contains(comment.Text, `ENUM(`) { + isEnum = true + } + } + } + + return isEnum +} diff --git a/pkg/utils/generator/generator_1.18_test.go b/pkg/utils/generator/generator_1.18_test.go new file mode 100644 index 0000000..fd536b3 --- /dev/null +++ b/pkg/utils/generator/generator_1.18_test.go @@ -0,0 +1,321 @@ +//go:build go1.18 +// +build go1.18 + +package generator + +import ( + "fmt" + "go/parser" + "strings" + "testing" + + "github.com/bradleyjkemp/cupaloy/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testExampleFiles = map[string]string{ + "og": `example_test.go`, + "1.18": `example_1.18_test.go`, +} + +// TestNoStructInputFile +func Test118NoStructFile(t *testing.T) { + input := `package test + // Behavior + type SomeInterface interface{ + + } + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestNoFile +func Test118NoFile(t *testing.T) { + g := NewGenerator(). + WithoutSnakeToCamel() + // Parse the file given in arguments + _, err := g.GenerateFromFile("") + assert.NotNil(t, err, "Error generating formatted code") +} + +// TestExampleFile +func Test118ExampleFile(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithSQLDriver(). + WithCaseInsensitiveParse(). + WithNames(). + WithoutSnakeToCamel() + + for name, testExample := range testExampleFiles { + t.Run(name, func(t *testing.T) { + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } + }) + } +} + +// TestExampleFileMoreOptions +func Test118ExampleFileMoreOptions(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithSQLDriver(). + WithCaseInsensitiveParse(). + WithNames(). + WithoutSnakeToCamel(). + WithMustParse(). + WithForceLower(). + WithTemplates(`../example/user_template.tmpl`) + for name, testExample := range testExampleFiles { + t.Run(name, func(t *testing.T) { + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } + }) + } +} + +// TestExampleFile +func Test118NoPrefixExampleFile(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithLowercaseVariant(). + WithNoPrefix(). + WithFlag(). + WithoutSnakeToCamel() + for name, testExample := range testExampleFiles { + t.Run(name, func(t *testing.T) { + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } + }) + } +} + +// TestExampleFile +func Test118NoPrefixExampleFileWithSnakeToCamel(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithLowercaseVariant(). + WithNoPrefix(). + WithFlag() + + for name, testExample := range testExampleFiles { + t.Run(name, func(t *testing.T) { + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } + }) + } +} + +// TestCustomPrefixExampleFile +func Test118CustomPrefixExampleFile(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithLowercaseVariant(). + WithNoPrefix(). + WithFlag(). + WithoutSnakeToCamel(). + WithPtr(). + WithSQLNullInt(). + WithSQLNullStr(). + WithPrefix("Custom_prefix_") + for name, testExample := range testExampleFiles { + t.Run(name, func(t *testing.T) { + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } + }) + } +} + +func Test118AliasParsing(t *testing.T) { + tests := map[string]struct { + input []string + resultingMap map[string]string + err error + }{ + "no aliases": { + resultingMap: map[string]string{}, + }, + "multiple arrays": { + input: []string{ + `!:Bang,a:a`, + `@:AT`, + `&:AND,|:OR`, + }, + resultingMap: map[string]string{ + "a": "a", + "!": "Bang", + "@": "AT", + "&": "AND", + "|": "OR", + }, + }, + "more types": { + input: []string{ + `*:star,+:PLUS`, + `-:less`, + `#:HASH,!:Bang`, + }, + resultingMap: map[string]string{ + "*": "star", + "+": "PLUS", + "-": "less", + "#": "HASH", + "!": "Bang", + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + defer func() { + replacementNames = map[string]string{} + }() + err := ParseAliases(tc.input) + if tc.err != nil { + require.Error(t, err) + require.EqualError(t, err, tc.err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tc.resultingMap, replacementNames) + } + }) + } +} + +// TestEnumParseFailure +func Test118EnumParseFailure(t *testing.T) { + input := `package test + // Behavior + type SomeInterface interface{ + + } + + // ENUM( + // a, + //} + type Animal int + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Empty(t, string(output)) + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestUintInvalidParsing +func Test118UintInvalidParsing(t *testing.T) { + input := `package test + // ENUM( + // a=-1, + //) + type Animal uint + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Empty(t, string(output)) + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestIntInvalidParsing +func Test118IntInvalidParsing(t *testing.T) { + input := `package test + // ENUM( + // a=c, + //) + type Animal int + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Empty(t, string(output)) + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestAliasing +func Test118Aliasing(t *testing.T) { + input := `package test + // ENUM(a,b,CDEF) with some extra text + type Animal int + ` + g := NewGenerator(). + WithoutSnakeToCamel() + _ = ParseAliases([]string{"CDEF:C"}) + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Contains(t, string(output), "// AnimalC is a Animal of type CDEF.") + if false { // Debugging statement + fmt.Println(string(output)) + } +} diff --git a/pkg/utils/generator/generator_test.go b/pkg/utils/generator/generator_test.go new file mode 100644 index 0000000..5255559 --- /dev/null +++ b/pkg/utils/generator/generator_test.go @@ -0,0 +1,368 @@ +package generator + +import ( + "errors" + "fmt" + "go/parser" + "strings" + "testing" + + "github.com/bradleyjkemp/cupaloy/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testExample = `example_test.go` +) + +// TestNoStructInputFile +func TestNoStructFile(t *testing.T) { + input := `package test + // Behavior + type SomeInterface interface{ + + } + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestNoFile +func TestNoFile(t *testing.T) { + g := NewGenerator(). + WithoutSnakeToCamel() + // Parse the file given in arguments + _, err := g.GenerateFromFile("") + assert.NotNil(t, err, "Error generating formatted code") +} + +// TestExampleFile +func TestExampleFile(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithSQLDriver(). + WithCaseInsensitiveParse(). + WithNames(). + WithoutSnakeToCamel() + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } +} + +// TestExampleFileMoreOptions +func TestExampleFileMoreOptions(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithSQLDriver(). + WithCaseInsensitiveParse(). + WithNames(). + WithoutSnakeToCamel(). + WithMustParse(). + WithForceLower(). + WithTemplates(`../example/user_template.tmpl`) + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } +} + +// TestExampleFile +func TestNoPrefixExampleFile(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithLowercaseVariant(). + WithNoPrefix(). + WithFlag(). + WithoutSnakeToCamel() + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } +} + +// TestExampleFile +func TestReplacePrefixExampleFile(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithLowercaseVariant(). + WithNoPrefix(). + WithPrefix("MyPrefix_"). + WithFlag(). + WithoutSnakeToCamel() + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } +} + +// TestExampleFile +func TestNoPrefixExampleFileWithSnakeToCamel(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithLowercaseVariant(). + WithNoPrefix(). + WithFlag() + + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } +} + +// TestCustomPrefixExampleFile +func TestCustomPrefixExampleFile(t *testing.T) { + g := NewGenerator(). + WithMarshal(). + WithLowercaseVariant(). + WithNoPrefix(). + WithFlag(). + WithoutSnakeToCamel(). + WithPtr(). + WithSQLNullInt(). + WithSQLNullStr(). + WithPrefix("Custom_prefix_") + // Parse the file given in arguments + imported, err := g.GenerateFromFile(testExample) + require.Nil(t, err, "Error generating formatted code") + + outputLines := strings.Split(string(imported), "\n") + cupaloy.SnapshotT(t, outputLines) + + if false { + fmt.Println(string(imported)) + } +} + +func TestAliasParsing(t *testing.T) { + tests := map[string]struct { + input []string + resultingMap map[string]string + err error + }{ + "no aliases": { + resultingMap: map[string]string{}, + }, + "bad input": { + input: []string{"a:b,c"}, + err: errors.New(`invalid formatted alias entry "c", must be in the format "key:value"`), + }, + "multiple arrays": { + input: []string{ + `!:Bang,a:a`, + `@:AT`, + `&:AND,|:OR`, + }, + resultingMap: map[string]string{ + "a": "a", + "!": "Bang", + "@": "AT", + "&": "AND", + "|": "OR", + }, + }, + "more types": { + input: []string{ + `*:star,+:PLUS`, + `-:less`, + `#:HASH,!:Bang`, + }, + resultingMap: map[string]string{ + "*": "star", + "+": "PLUS", + "-": "less", + "#": "HASH", + "!": "Bang", + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + replacementNames = map[string]string{} + err := ParseAliases(tc.input) + if tc.err != nil { + require.Error(t, err) + require.EqualError(t, err, tc.err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tc.resultingMap, replacementNames) + } + }) + } +} + +// TestEnumParseFailure +func TestEnumParseFailure(t *testing.T) { + input := `package test + // Behavior + type SomeInterface interface{ + + } + + // ENUM( + // a, + //} + type Animal int + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Empty(t, string(output)) + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestUintInvalidParsing +func TestUintInvalidParsing(t *testing.T) { + input := `package test + // ENUM( + // a=-1, + //) + type Animal uint + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Empty(t, string(output)) + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestIntInvalidParsing +func TestIntInvalidParsing(t *testing.T) { + input := `package test + // ENUM( + // a=c, + //) + type Animal int + ` + g := NewGenerator(). + WithoutSnakeToCamel() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Empty(t, string(output)) + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestAliasing +func TestAliasing(t *testing.T) { + input := `package test + // ENUM(a,b,CDEF) with some extra text + type Animal int + ` + g := NewGenerator(). + WithoutSnakeToCamel() + _ = ParseAliases([]string{"CDEF:C"}) + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Contains(t, string(output), "// AnimalC is a Animal of type CDEF.") + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestParanthesesParsing +func TestParenthesesParsing(t *testing.T) { + input := `package test + // This is a pre-enum comment that needs (to be handled properly) + // ENUM( + // abc (x), + //). This is an extra string comment (With parentheses of it's own) + // And (another line) with Parentheses + type Animal string + ` + g := NewGenerator() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Contains(t, string(output), "// AnimalAbcX is a Animal of type abc (x).") + assert.NotContains(t, string(output), "// AnimalAnd") + if false { // Debugging statement + fmt.Println(string(output)) + } +} + +// TestQuotedStrings +func TestQuotedStrings(t *testing.T) { + input := `package test + // This is a pre-enum comment that needs (to be handled properly) + // ENUM( + // abc (x), + // ghi = "20", + //). This is an extra string comment (With parentheses of it's own) + // And (another line) with Parentheses + type Animal string + ` + g := NewGenerator() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Contains(t, string(output), "// AnimalAbcX is a Animal of type abc (x).") + assert.Contains(t, string(output), "AnimalGhi Animal = \"20\"") + assert.NotContains(t, string(output), "// AnimalAnd") + if false { // Debugging statement + fmt.Println(string(output)) + } +} diff --git a/pkg/utils/generator/template_funcs.go b/pkg/utils/generator/template_funcs.go new file mode 100644 index 0000000..ebacd2b --- /dev/null +++ b/pkg/utils/generator/template_funcs.go @@ -0,0 +1,126 @@ +package generator + +import ( + "fmt" + "strconv" + "strings" +) + +// Stringify returns a string that is all of the enum value names concatenated without a separator +func Stringify(e Enum, forceLower bool) (ret string, err error) { + for _, val := range e.Values { + if val.Name != skipHolder { + next := val.RawName + if forceLower { + next = strings.ToLower(next) + } + ret = ret + next + } + } + return +} + +// Mapify returns a map that is all of the indexes for a string value lookup +func Mapify(e Enum) (ret string, err error) { + strName := fmt.Sprintf(`_%sName`, e.Name) + ret = fmt.Sprintf("map[%s]string{\n", e.Name) + index := 0 + for _, val := range e.Values { + if val.Name != skipHolder { + nextIndex := index + len(val.Name) + ret = fmt.Sprintf("%s%s: %s[%d:%d],\n", ret, val.PrefixedName, strName, index, nextIndex) + index = nextIndex + } + } + ret = ret + `}` + return +} + +// Unmapify returns a map that is all of the indexes for a string value lookup +func Unmapify(e Enum, lowercase bool) (ret string, err error) { + if e.Type == "string" { + return UnmapifyStringEnum(e, lowercase) + } + strName := fmt.Sprintf(`_%sName`, e.Name) + ret = fmt.Sprintf("map[string]%s{\n", e.Name) + index := 0 + for _, val := range e.Values { + if val.Name != skipHolder { + nextIndex := index + len(val.Name) + ret = fmt.Sprintf("%s%s[%d:%d]: %s,\n", ret, strName, index, nextIndex, val.PrefixedName) + if lowercase { + ret = fmt.Sprintf("%sstrings.ToLower(%s[%d:%d]): %s,\n", ret, strName, index, nextIndex, val.PrefixedName) + } + index = nextIndex + } + } + ret = ret + `}` + return +} + +// Unmapify returns a map that is all of the indexes for a string value lookup +func UnmapifyStringEnum(e Enum, lowercase bool) (ret string, err error) { + var builder strings.Builder + _, err = builder.WriteString("map[string]" + e.Name + "{\n") + if err != nil { + return + } + for _, val := range e.Values { + if val.Name != skipHolder { + _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", val.ValueStr, val.PrefixedName)) + if err != nil { + return + } + if lowercase && strings.ToLower(val.ValueStr) != val.ValueStr { + _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", strings.ToLower(val.ValueStr), val.PrefixedName)) + if err != nil { + return + } + } + } + } + builder.WriteByte('}') + ret = builder.String() + return +} + +// Namify returns a slice that is all of the possible names for an enum in a slice +func Namify(e Enum) (ret string, err error) { + if e.Type == "string" { + return namifyStringEnum(e) + } + strName := fmt.Sprintf(`_%sName`, e.Name) + ret = "[]string{\n" + index := 0 + for _, val := range e.Values { + if val.Name != skipHolder { + nextIndex := index + len(val.Name) + ret = fmt.Sprintf("%s%s[%d:%d],\n", ret, strName, index, nextIndex) + index = nextIndex + } + } + ret = ret + "}" + return +} + +// Namify returns a slice that is all of the possible names for an enum in a slice +func namifyStringEnum(e Enum) (ret string, err error) { + ret = "[]string{\n" + for _, val := range e.Values { + if val.Name != skipHolder { + ret = fmt.Sprintf("%sstring(%s),\n", ret, val.PrefixedName) + } + } + ret = ret + "}" + return +} + +func Offset(index int, enumType string, val EnumValue) (strResult string) { + if strings.HasPrefix(enumType, "u") { + // Unsigned + return strconv.FormatUint(val.ValueInt.(uint64)-uint64(index), 10) + } else { + // Signed + return strconv.FormatInt(val.ValueInt.(int64)-int64(index), 10) + } +}