Files
atom/fen/bind.go

491 lines
15 KiB
Go

package fen
import (
	"fmt"
	"mime/multipart"
	"reflect"
	"strconv"
	"strings"
	"sync"

	"github.com/gofiber/fiber/v3"
	"github.com/pkg/errors"
)
// File returns a binder that yields the uploaded multipart file sent under the
// form field key. T does not influence the lookup; it only appears at the call
// site.
func File[T any](key string) func(fiber.Ctx) (*multipart.FileHeader, error) {
	return func(c fiber.Ctx) (*multipart.FileHeader, error) {
		header, err := c.FormFile(key)
		return header, err
	}
}
// Local returns a binder that reads the request-scoped local stored under key,
// typed as T via fiber.Locals. The returned error is always nil.
func Local[T any](key any) func(fiber.Ctx) (T, error) {
	return func(c fiber.Ctx) (T, error) {
		return fiber.Locals[T](c, key), nil
	}
}
// Path returns a binder that reads the route parameter key, converted to T by
// fiber.Params. The returned error is always nil.
func Path[T fiber.GenericType](key string) func(fiber.Ctx) (T, error) {
	return func(c fiber.Ctx) (T, error) {
		return fiber.Params[T](c, key), nil
	}
}
// PathParam returns a binder that reads the route parameter name, converted
// to T by fiber.Params. The returned error is always nil.
func PathParam[T fiber.GenericType](name string) func(fiber.Ctx) (T, error) {
	binder := func(c fiber.Ctx) (T, error) {
		param := fiber.Params[T](c, name)
		return param, nil
	}
	return binder
}
// Body returns a binder that decodes the request body into a fresh T.
// name is only used to label the wrapped error when binding fails.
func Body[T any](name string) func(fiber.Ctx) (*T, error) {
	return func(c fiber.Ctx) (*T, error) {
		var out T
		if err := c.Bind().Body(&out); err != nil {
			return nil, errors.Wrapf(err, "body: %s", name)
		}
		return &out, nil
	}
}
// QueryParam returns a binder that reads the query-string value for key,
// converted to T by fiber.Query. The returned error is always nil.
func QueryParam[T fiber.GenericType](key string) func(fiber.Ctx) (T, error) {
	return func(c fiber.Ctx) (T, error) {
		return fiber.Query[T](c, key), nil
	}
}
// Query returns a binder that populates a fresh T from the URL query string.
// name is only used to label the wrapped error when binding fails.
func Query[T any](name string) func(fiber.Ctx) (*T, error) {
	return func(c fiber.Ctx) (*T, error) {
		dst := new(T)
		err := c.Bind().Query(dst)
		if err == nil {
			return dst, nil
		}
		return nil, errors.Wrapf(err, "query: %s", name)
	}
}
// Header returns a binder that populates a fresh T from the request headers.
// name is only used to label the wrapped error when binding fails.
func Header[T any](name string) func(fiber.Ctx) (*T, error) {
	return func(c fiber.Ctx) (*T, error) {
		var hdr T
		if err := c.Bind().Header(&hdr); err != nil {
			return nil, errors.Wrapf(err, "header: %s", name)
		}
		return &hdr, nil
	}
}
// Cookie returns a binder that populates a fresh T from the request cookies.
// name is only used to label the wrapped error when binding fails.
func Cookie[T any](name string) func(fiber.Ctx) (*T, error) {
	bind := func(c fiber.Ctx) (*T, error) {
		target := new(T)
		if bindErr := c.Bind().Cookie(target); bindErr != nil {
			return nil, errors.Wrapf(bindErr, "cookie: %s", name)
		}
		return target, nil
	}
	return bind
}
// CookieParam returns a binder that yields the raw value of the cookie called
// name, as returned by Ctx.Cookies. The returned error is always nil.
func CookieParam(name string) func(fiber.Ctx) (string, error) {
	read := func(c fiber.Ctx) (string, error) {
		value := c.Cookies(name)
		return value, nil
	}
	return read
}
// ModelByPath creates a model parameter binder that looks up a model instance
// using a path parameter. It provides a unified way to bind model parameters
// similar to the other parameter binders.
//
// queryFunc is called with the request context. It must return a query object
// (or a pointer to one) that has:
//   - a Where method;
//   - a field named field, or its capitalized form, or failing that an exported
//     field whose name matches when case and underscores are ignored (so "id"
//     finds "ID" and "user_id" finds "UserID"). That field's value must have an
//     Eq method taking exactly one argument;
//   - on the first value Where returns, a First method returning (*T, error).
//
// The path parameter pathKey is read as K, formatted with %v, and parsed into
// the exact parameter type that Eq declares. Supported types are string, bool,
// signed/unsigned integers and floats. Integers must be base-10 and must fit
// the target width (e.g. int32). Input such as "12abc" or an out-of-range id is
// therefore reported as an error rather than partially parsed or truncated.
// Errors returned by First are passed through unchanged.
func ModelByPath[T any, K fiber.GenericType](queryFunc func(context interface{}) interface{}, field, pathKey string) func(fiber.Ctx) (*T, error) {
	return func(ctx fiber.Ctx) (*T, error) {
		v := fiber.Params[K](ctx, pathKey)

		queryValue := reflect.ValueOf(queryFunc(ctx))
		if !queryValue.IsValid() {
			return nil, fmt.Errorf("query function returned nil")
		}
		whereMethod := queryValue.MethodByName("Where")
		if !whereMethod.IsValid() {
			return nil, fmt.Errorf("query object does not have Where method")
		}

		// Field objects are declared on the struct itself, so look through a
		// pointer. A nil pointer or non-struct would make FieldByName panic.
		queryStruct := reflect.Indirect(queryValue)
		if queryStruct.Kind() != reflect.Struct {
			return nil, fmt.Errorf("query object must be a struct or pointer to struct, got %s", queryValue.Type())
		}

		capitalizedField := ""
		if len(field) > 0 {
			capitalizedField = strings.ToUpper(field[:1]) + field[1:]
		}
		fieldValue := queryStruct.FieldByName(field)
		if !fieldValue.IsValid() {
			fieldValue = queryStruct.FieldByName(capitalizedField)
		}
		if want := strings.ReplaceAll(field, "_", ""); !fieldValue.IsValid() && want != "" {
			// Go naming uses initialisms ("ID", "UserID"), which the simple
			// capitalization above can never produce from "id" or "user_id".
			match := queryStruct.FieldByNameFunc(func(name string) bool {
				return strings.EqualFold(strings.ReplaceAll(name, "_", ""), want)
			})
			// Methods on values reached through unexported fields cannot be
			// called via reflection, so only accept exported matches.
			if match.IsValid() && match.CanInterface() {
				fieldValue = match
			}
		}
		if !fieldValue.IsValid() {
			return nil, fmt.Errorf("query object does not have field '%s' or '%s'", field, capitalizedField)
		}

		eqMethod := fieldValue.MethodByName("Eq")
		if !eqMethod.IsValid() {
			return nil, fmt.Errorf("field '%s' does not have Eq method", field)
		}

		// reflect.Value.Call panics unless the argument has exactly the type Eq
		// declares (e.g. int32 for an Int32 column), so build it from that type.
		arg, err := pathEqArgument(eqMethod, fmt.Sprintf("%v", v))
		if err != nil {
			return nil, err
		}

		conditionResult := eqMethod.Call([]reflect.Value{arg})
		if len(conditionResult) == 0 {
			return nil, fmt.Errorf("Eq method returned no result")
		}
		whereResult := whereMethod.Call([]reflect.Value{conditionResult[0]})
		if len(whereResult) == 0 {
			return nil, fmt.Errorf("Where method returned no result")
		}

		firstMethod := whereResult[0].MethodByName("First")
		if !firstMethod.IsValid() {
			return nil, fmt.Errorf("query object does not have First method")
		}
		firstResult := firstMethod.Call(nil)
		if len(firstResult) < 2 {
			return nil, fmt.Errorf("First method should return (model, error)")
		}
		if queryErr, ok := firstResult[1].Interface().(error); ok && queryErr != nil {
			return nil, queryErr
		}
		if model, ok := firstResult[0].Interface().(*T); ok {
			return model, nil
		}
		return nil, fmt.Errorf("failed to query model")
	}
}

// pathEqArgument parses raw into a value of the single parameter type taken by
// eq (a bound Eq method), ready to be passed to eq.Call.
func pathEqArgument(eq reflect.Value, raw string) (reflect.Value, error) {
	eqType := eq.Type()
	if eqType.NumIn() != 1 || eqType.IsVariadic() {
		return reflect.Value{}, fmt.Errorf("Eq method must take exactly one argument, got %s", eqType)
	}
	argType := eqType.In(0)
	arg := reflect.New(argType).Elem()

	var err error
	switch argType.Kind() {
	case reflect.String:
		arg.SetString(raw)
	case reflect.Bool:
		var b bool
		if b, err = strconv.ParseBool(raw); err == nil {
			arg.SetBool(b)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		var n int64
		if n, err = strconv.ParseInt(raw, 10, argType.Bits()); err == nil {
			arg.SetInt(n)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		var n uint64
		if n, err = strconv.ParseUint(raw, 10, argType.Bits()); err == nil {
			arg.SetUint(n)
		}
	case reflect.Float32, reflect.Float64:
		var f float64
		if f, err = strconv.ParseFloat(raw, argType.Bits()); err == nil {
			arg.SetFloat(f)
		}
	default:
		return reflect.Value{}, fmt.Errorf("unsupported Eq parameter type: %s (kind: %v)", argType, argType.Kind())
	}
	if err != nil {
		return reflect.Value{}, fmt.Errorf("failed to convert '%s' to %s: %w", raw, argType, err)
	}
	return arg, nil
}
// ModelLookup provides a unified interface for model parameter binding,
// following the same pattern as PathParam, QueryParam, etc.
//
// Usage: ModelLookup[models.User, int]("id", "id")
//
// This is a placeholder. The returned binder never queries anything and always
// fails with an error naming T and field. K and pathKey are unused here.
func ModelLookup[T any, K fiber.GenericType](field, pathKey string) func(fiber.Ctx) (*T, error) {
	return func(fiber.Ctx) (*T, error) {
		// The route generator is expected to detect ModelLookup calls and
		// replace them with an inline function such as:
		//
		//	func(ctx fiber.Ctx) (*models.User, error) {
		//		v := fiber.Params[int](ctx, "id")
		//		return models.UserQuery.WithContext(ctx).Where(field.NewUnsafeFieldRaw("id = ?", v)).First()
		//	}
		typeName := reflect.TypeOf((*T)(nil)).Elem().Name()
		return nil, fmt.Errorf("ModelLookup[%s] should be replaced by generated code for field '%s'", typeName, field)
	}
}
// ModelLegacy provides the original model binding interface for backward
// compatibility. It forwards field and pathKey to ModelLookup. modelName is
// accepted for signature compatibility and is not used.
func ModelLegacy[T any, K fiber.GenericType](modelName, field, pathKey string) func(fiber.Ctx) (*T, error) {
	binder := ModelLookup[T, K](field, pathKey)
	return binder
}
// ModelQuery provides a more direct approach by accepting a query function.
//
// Without a path parameter name there is no way to know which parameter to
// extract, so queryFunc is never called. The returned binder always fails.
// ModelQueryWithKey accepts the path key explicitly.
func ModelQuery[T any, K fiber.GenericType](queryFunc func(ctx fiber.Ctx, v K) (*T, error)) func(fiber.Ctx) (*T, error) {
	return func(fiber.Ctx) (*T, error) {
		return nil, fmt.Errorf("ModelQuery requires path parameter name to be specified")
	}
}
// ModelQueryWithKey returns a binder that reads the path parameter pathKey as
// K (via fiber.Params) and delegates the lookup to queryFunc. queryFunc's
// result and error are returned unchanged.
func ModelQueryWithKey[T any, K fiber.GenericType](queryFunc func(ctx fiber.Ctx, v K) (*T, error), pathKey string) func(fiber.Ctx) (*T, error) {
	return func(c fiber.Ctx) (*T, error) {
		return queryFunc(c, fiber.Params[K](c, pathKey))
	}
}
// ModelRegistry maintains model metadata for runtime reflection.
//
// It maps a model's type name to the ModelInfo recorded by RegisterModel. The
// key is reflect.Type.String() of the model type with any pointer stripped
// (e.g. "models.User").
// NOTE(review): that name omits the full import path, so two packages that
// share a package name and declare the same type name would collide. Confirm
// this cannot happen in practice.
type ModelRegistry struct {
// mu guards models: RegisterModel writes under Lock and getModelInfo reads
// under RLock.
mu sync.RWMutex
// models holds one ModelInfo per registered model type name.
models map[string]ModelInfo
}
// ModelInfo holds metadata about a model type, as recorded by RegisterModel.
type ModelInfo struct {
// Type is the model's reflect.Type with one level of pointer indirection
// removed.
Type reflect.Type
// QueryObject is driven through reflection by executeModelQuery. It must have:
//   - a WithContext method;
//   - field objects (looked up by name) that have an Eq method;
//   - a Where(...).First() chain on the value WithContext returns.
QueryObject interface{}
// DefaultField is stored by RegisterModel.
// NOTE(review): Model and executeModelQuery do not read it (Model defaults to
// "id"). Confirm whether it is meant to be used.
DefaultField string
}
// registry is the process-wide model table. RegisterModel populates it and
// getModelInfo reads it.
var registry = &ModelRegistry{models: map[string]ModelInfo{}}
// RegisterModel records queryObject and defaultField for model type T so that
// Model can resolve it at request time. Call it during application
// initialization.
//
// T may be the model struct or a pointer to it. Both register under the same
// key: the struct type's reflect.Type.String(), e.g. "models.User".
// Registering the same type again replaces the earlier entry.
func RegisterModel[T any](queryObject interface{}, defaultField string) {
	var zero T
	t := reflect.TypeOf(zero)
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	entry := ModelInfo{
		Type:         t,
		QueryObject:  queryObject,
		DefaultField: defaultField,
	}

	registry.mu.Lock()
	defer registry.mu.Unlock()
	registry.models[t.String()] = entry
}
// Model provides a simplified model parameter binding interface without
// closures. T must have been registered with RegisterModel; otherwise the
// binder returns an error at request time.
//
// Arguments:
//   - none:     field "id", read from path parameter "id";
//   - one:      Model[models.User]("role"), where field and path key are both "role";
//   - two:      Model[models.User]("user_id", "id"), where field is "user_id" and the path key is "id";
//   - more:     a caller bug; the binder fails every request with an error that
//     reports the argument count, instead of silently binding ("id", "id").
func Model[T any](fieldAndPath ...string) func(fiber.Ctx) (*T, error) {
	var zero T
	modelType := reflect.TypeOf(zero)
	if modelType.Kind() == reflect.Ptr {
		modelType = modelType.Elem()
	}
	// Must match the key RegisterModel stores the model under.
	typeName := modelType.String()

	field, pathKey := "id", "id"
	switch len(fieldAndPath) {
	case 0:
		// Keep the ("id", "id") defaults.
	case 1:
		field, pathKey = fieldAndPath[0], fieldAndPath[0]
	case 2:
		field, pathKey = fieldAndPath[0], fieldAndPath[1]
	default:
		argErr := fmt.Errorf("Model[%s] accepts at most 2 arguments (field, pathKey), got %d", typeName, len(fieldAndPath))
		return func(fiber.Ctx) (*T, error) {
			return nil, argErr
		}
	}

	return func(ctx fiber.Ctx) (*T, error) {
		info, err := getModelInfo(typeName)
		if err != nil {
			return nil, err
		}
		// The raw string is parsed into the field's type by executeModelQuery.
		paramValue := fiber.Params[string](ctx, pathKey)
		return executeModelQuery[T](info, field, paramValue, ctx)
	}
}
// ModelById binds a model by its "id" field, reading the value from the path
// parameter named pathKey. It is shorthand for Model[T]("id", pathKey).
func ModelById[T any](pathKey string) func(fiber.Ctx) (*T, error) {
	const idField = "id"
	return Model[T](idField, pathKey)
}
// getModelInfo returns a copy of the registry entry for typeName. If no entry
// exists, it returns an error explaining that RegisterModel must be called.
func getModelInfo(typeName string) (*ModelInfo, error) {
	registry.mu.RLock()
	info, found := registry.models[typeName]
	registry.mu.RUnlock()

	if found {
		return &info, nil
	}
	return nil, fmt.Errorf("model %s not registered. Call RegisterModel[%s]() during initialization", typeName, typeName)
}
// executeModelQuery performs the database lookup behind Model using reflection.
//
// It proceeds in these steps:
//   - call info.QueryObject.WithContext(ctx);
//   - resolve the field object named field on the original query object,
//     trying field, then its capitalized form, then an exported field whose
//     name matches when case and underscores are ignored (so "id" finds "ID"
//     and "user_id" finds "UserID");
//   - build the condition with that field's Eq method;
//   - run Where(cond).First() on the value WithContext returned.
//
// paramValue is parsed into the exact parameter type Eq declares. Supported
// types are string, bool, signed/unsigned integers (base-10, range-checked
// against the target width such as int32) and floats. Input such as "12abc" is
// rejected rather than partially parsed. Narrow integer parameters receive a
// value of their own type, as reflect.Value.Call requires.
//
// It returns the *T produced by First, or First's error unchanged. It returns
// a descriptive error when a required method or field is missing or when
// First's result is not a *T.
func executeModelQuery[T any](info *ModelInfo, field, paramValue string, ctx fiber.Ctx) (*T, error) {
	queryValue := reflect.ValueOf(info.QueryObject)
	if !queryValue.IsValid() {
		return nil, fmt.Errorf("query object is nil")
	}

	withContextMethod := queryValue.MethodByName("WithContext")
	if !withContextMethod.IsValid() {
		return nil, fmt.Errorf("query object does not have WithContext method")
	}
	contextResult := withContextMethod.Call([]reflect.Value{reflect.ValueOf(ctx)})
	if len(contextResult) == 0 {
		return nil, fmt.Errorf("WithContext method returned no result")
	}
	whereMethod := contextResult[0].MethodByName("Where")
	if !whereMethod.IsValid() {
		return nil, fmt.Errorf("query object does not have Where method")
	}

	// Field objects live on the query struct itself, not on the value
	// WithContext returns, so resolve them on the original object. A nil
	// pointer or non-struct would make FieldByName panic.
	queryStruct := reflect.Indirect(queryValue)
	if queryStruct.Kind() != reflect.Struct {
		return nil, fmt.Errorf("query object must be a struct or pointer to struct, got %s", queryValue.Type())
	}

	capitalizedField := ""
	if len(field) > 0 {
		capitalizedField = strings.ToUpper(field[:1]) + field[1:]
	}
	fieldValue := queryStruct.FieldByName(field)
	if !fieldValue.IsValid() {
		fieldValue = queryStruct.FieldByName(capitalizedField)
	}
	if want := strings.ReplaceAll(field, "_", ""); !fieldValue.IsValid() && want != "" {
		// Go naming uses initialisms ("ID", "UserID"), which the simple
		// capitalization above can never produce from "id" or "user_id".
		match := queryStruct.FieldByNameFunc(func(name string) bool {
			return strings.EqualFold(strings.ReplaceAll(name, "_", ""), want)
		})
		// Methods on values reached through unexported fields cannot be
		// called via reflection, so only accept exported matches.
		if match.IsValid() && match.CanInterface() {
			fieldValue = match
		}
	}
	if !fieldValue.IsValid() {
		return nil, fmt.Errorf("query object does not have field '%s' or '%s'", field, capitalizedField)
	}

	eqMethod := fieldValue.MethodByName("Eq")
	if !eqMethod.IsValid() {
		return nil, fmt.Errorf("field '%s' does not have Eq method", field)
	}
	arg, err := eqArgFromString(eqMethod, paramValue)
	if err != nil {
		return nil, err
	}

	conditionResult := eqMethod.Call([]reflect.Value{arg})
	if len(conditionResult) == 0 {
		return nil, fmt.Errorf("Eq method returned no result")
	}
	whereResult := whereMethod.Call([]reflect.Value{conditionResult[0]})
	if len(whereResult) == 0 {
		return nil, fmt.Errorf("Where method returned no result")
	}

	firstMethod := whereResult[0].MethodByName("First")
	if !firstMethod.IsValid() {
		return nil, fmt.Errorf("query object does not have First method")
	}
	firstResult := firstMethod.Call(nil)
	if len(firstResult) < 2 {
		return nil, fmt.Errorf("First method should return (model, error)")
	}
	if queryErr, ok := firstResult[1].Interface().(error); ok && queryErr != nil {
		return nil, queryErr
	}
	if model, ok := firstResult[0].Interface().(*T); ok {
		return model, nil
	}
	return nil, fmt.Errorf("failed to query model")
}

// eqArgFromString parses raw into a value of the single parameter type taken
// by eq (a bound Eq method), ready to be passed to eq.Call.
func eqArgFromString(eq reflect.Value, raw string) (reflect.Value, error) {
	eqType := eq.Type()
	if eqType.NumIn() != 1 || eqType.IsVariadic() {
		return reflect.Value{}, fmt.Errorf("Eq method must take exactly one argument, got %s", eqType)
	}
	argType := eqType.In(0)
	arg := reflect.New(argType).Elem()

	var err error
	switch argType.Kind() {
	case reflect.String:
		arg.SetString(raw)
	case reflect.Bool:
		var b bool
		if b, err = strconv.ParseBool(raw); err == nil {
			arg.SetBool(b)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		var n int64
		if n, err = strconv.ParseInt(raw, 10, argType.Bits()); err == nil {
			arg.SetInt(n)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		var n uint64
		if n, err = strconv.ParseUint(raw, 10, argType.Bits()); err == nil {
			arg.SetUint(n)
		}
	case reflect.Float32, reflect.Float64:
		var f float64
		if f, err = strconv.ParseFloat(raw, argType.Bits()); err == nil {
			arg.SetFloat(f)
		}
	default:
		return reflect.Value{}, fmt.Errorf("unsupported Eq parameter type: %s (kind: %v)", argType, argType.Kind())
	}
	if err != nil {
		return reflect.Value{}, fmt.Errorf("failed to convert '%s' to %s: %w", raw, argType, err)
	}
	return arg, nil
}