diff --git a/cmd/gen_model.go b/cmd/gen_model.go index e308d28..cbe03bb 100644 --- a/cmd/gen_model.go +++ b/cmd/gen_model.go @@ -93,7 +93,7 @@ func commandGenModelE(cmd *cobra.Command, args []string) error { return defaultTableModelField } - splits := strings.Split(toType, ".") + splits := strings.SplitN(toType, ".", 2) typeName := splits[len(splits)-1] pkg := splits[0] diff --git a/templates/project/app/errorx/error.go.tpl b/templates/project/app/errorx/error.go.tpl index e7489cc..4388fe5 100644 --- a/templates/project/app/errorx/error.go.tpl +++ b/templates/project/app/errorx/error.go.tpl @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "runtime" + "strings" "github.com/go-jet/jet/v2/qrm" "github.com/gofiber/fiber/v3" @@ -31,6 +32,7 @@ type Response struct { StatusCode int `json:"-" xml:"-"` Code int `json:"code" xml:"code"` Message string `json:"message" xml:"message"` + Data any `json:"data,omitempty" xml:"data"` } func New(code, statusCode int, message string) *Response { @@ -47,6 +49,13 @@ func (r *Response) Sql(sql string) *Response { return r } +func (r *Response) from(err *Response) *Response { + r.Code = err.Code + r.Message = err.Message + r.StatusCode = err.StatusCode + return r +} + func (r *Response) Params(params ...any) *Response { r.params = params if _, file, line, ok := runtime.Caller(1); ok { @@ -62,12 +71,15 @@ func Wrap(err error) *Response { return &Response{err: err} } +func (r *Response) Wrap(err error) *Response { + r.err = err + return r +} + func (r *Response) format() { r.isFormat = true if errors.Is(r.err, qrm.ErrNoRows) { - r.Code = RecordNotExists.Code - r.Message = RecordNotExists.Message - r.StatusCode = RecordNotExists.StatusCode + r.from(RecordNotExists) return } @@ -77,6 +89,19 @@ func (r *Response) format() { r.StatusCode = e.Code return } + + if r.err != nil { + msg := r.err.Error() + if strings.Contains(msg, "duplicate key value") || strings.Contains(msg, "unique constraint") { + r.from(RecordDuplicated) + return + } + + r.Code = http.StatusInternalServerError + r.StatusCode = http.StatusInternalServerError + r.Message = msg + } + return } func (r *Response) Error() string { @@ -95,7 +120,12 @@ func (r *Response) Response(ctx fiber.Ctx) error { contentType := utils.ToLower(utils.UnsafeString(ctx.Context().Request.Header.ContentType())) contentType = binder.FilterFlags(utils.ParseVendorSpecificContentType(contentType)) - log.WithError(r.err).WithField("file", r.file).WithField("params", r.params).Errorf("response error: %+v", r) + log. + WithError(r.err). + WithField("file", r.file). + WithField("sql", r.sql). + WithField("params", r.params). + Errorf("response error: %+v", r) // Parse body accordingly switch contentType { @@ -104,13 +134,14 @@ func (r *Response) Response(ctx fiber.Ctx) error { case fiber.MIMETextHTML, fiber.MIMETextPlain: return ctx.Status(r.StatusCode).SendString(r.Message) default: - return ctx.Status(r.StatusCode).JSON(r.Message) + return ctx.Status(r.StatusCode).JSON(r) } } var ( - RecordNotExists = New(http.StatusNotFound, http.StatusNotFound, "记录不存在") - BadRequest = New(http.StatusBadRequest, http.StatusBadRequest, "请求错误") - Unauthorized = New(http.StatusUnauthorized, http.StatusUnauthorized, "未授权") - InternalErr = New(http.StatusInternalServerError, http.StatusInternalServerError, "内部错误") + RecordDuplicated = New(1001, http.StatusBadRequest, "记录重复") + RecordNotExists = New(http.StatusNotFound, http.StatusNotFound, "记录不存在") + BadRequest = New(http.StatusBadRequest, http.StatusBadRequest, "请求错误") + Unauthorized = New(http.StatusUnauthorized, http.StatusUnauthorized, "未授权") + InternalErr = New(http.StatusInternalServerError, http.StatusInternalServerError, "内部错误") ) diff --git a/templates/project/database/database.go.tpl b/templates/project/database/database.go.tpl index b42098e..39c3f69 100644 --- a/templates/project/database/database.go.tpl +++ b/templates/project/database/database.go.tpl @@ -21,7 +21,7 @@ func FromContext(ctx context.Context, db *sql.DB) qrm.DB { return db } -func TruncateAllTables(ctx context.Context, db *sql.DB, tableName ...string) error { +func Truncate(ctx context.Context, db *sql.DB, tableName ...string) error { for _, name := range tableName { sql := fmt.Sprintf("TRUNCATE TABLE %s RESTART IDENTITY", name) if _, err := db.ExecContext(ctx, sql); err != nil { diff --git a/templates/project/database/fields/common.go.tpl b/templates/project/database/fields/common.go.tpl index e790807..a078b0f 100644 --- a/templates/project/database/fields/common.go.tpl +++ b/templates/project/database/fields/common.go.tpl @@ -4,14 +4,18 @@ import ( "database/sql/driver" "encoding/json" "errors" - - "github.com/samber/lo" ) // implement sql.Scanner interface -type field struct{} +type Json[T any] struct { + Data T `json:",inline"` +} -func (x *field) Scan(value interface{}) (err error) { +func ToJson[T any](data T) Json[T] { + return Json[T]{Data: data} +} + +func (x *Json[T]) Scan(value interface{}) (err error) { switch v := value.(type) { case string: return json.Unmarshal([]byte(v), &x) @@ -23,10 +27,19 @@ func (x *field) Scan(value interface{}) (err error) { return errors.New("Unknown type for ") } -func (x field) Value() (driver.Value, error) { - return json.Marshal(x) +func (x Json[T]) Value() (driver.Value, error) { + return json.Marshal(x.Data) } -func (x field) MustValue() driver.Value { - return lo.Must(json.Marshal(x)) +func (x Json[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Data) +} + +func (x *Json[T]) UnmarshalJSON(data []byte) error { + var value T + if err := json.Unmarshal(data, &value); err != nil { + return err + } + x.Data = value + return nil }