package database

import (
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"regexp"
	"strings"
	"time"

	"github.com/rogeecn/database_render/internal/config"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// ErrRecordNotFound is returned by GetTableDataByID when no row matches the
// requested ID. Compare with errors.Is.
var ErrRecordNotFound = errors.New("record not found")

// identifierPattern accepts table/column names made only of letters, digits,
// '_' and '$', optionally dot-qualified (e.g. "schema.table"). Caller-supplied
// identifiers are checked against it before being spliced into SQL: quotes,
// whitespace, semicolons and comment markers cannot get through, while the
// generated SQL for ordinary names stays exactly as it was.
var identifierPattern = regexp.MustCompile(`^[\p{L}\p{N}_$]+(\.[\p{L}\p{N}_$]+)*$`)

// validateIdentifier returns an error if name is not safe to interpolate into
// SQL as an unquoted identifier. kind is used only in the error message.
func validateIdentifier(kind, name string) error {
	if !identifierPattern.MatchString(name) {
		return fmt.Errorf("invalid %s name %q", kind, name)
	}
	return nil
}

// ConnectionManager owns a GORM connection and its underlying *sql.DB pool,
// and provides schema introspection and generic table reads on top of it.
type ConnectionManager struct {
	db     *gorm.DB
	sqlDB  *sql.DB
	config *config.DatabaseConfig
	logger *slog.Logger
}

// NewConnectionManager opens the database described by config.Database,
// applies the connection-pool settings and verifies connectivity with a ping.
//
// If configuring or pinging fails, the already-opened pool is closed before
// the error is returned, so no connections are leaked.
func NewConnectionManager(config *config.Config) (*ConnectionManager, error) {
	log := slog.With("component", "database")
	cm := &ConnectionManager{
		config: &config.Database,
		logger: log,
	}

	if err := cm.connect(); err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	if err := cm.configure(); err != nil {
		// Best effort: the configure error is the one worth reporting; a
		// secondary Close failure would only obscure it.
		_ = cm.Close()
		return nil, fmt.Errorf("failed to configure database: %w", err)
	}

	log.Info("database connection established",
		"type", config.Database.Type,
		"host", config.Database.Host,
		"database", config.Database.DBName,
	)

	return cm, nil
}

// connect builds a GORM dialector for cm.config.Type ("sqlite", "mysql" or
// "postgres"), opens the connection and stores both the *gorm.DB and its
// underlying *sql.DB on cm.
func (cm *ConnectionManager) connect() error {
	var dialector gorm.Dialector

	switch cm.config.Type {
	case "sqlite":
		dialector = sqlite.Open(cm.config.Path)
	case "mysql":
		// NOTE(review): credentials are formatted into the DSN unescaped;
		// confirm passwords never contain DSN delimiters such as '/'.
		dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
			cm.config.User,
			cm.config.Password,
			cm.config.Host,
			cm.config.Port,
			cm.config.DBName)
		dialector = mysql.Open(dsn)
	case "postgres":
		// An explicit DSN takes precedence over the individual fields.
		if cm.config.DSN != "" {
			dialector = postgres.Open(cm.config.DSN)
		} else {
			dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
				cm.config.Host,
				cm.config.Port,
				cm.config.User,
				cm.config.Password,
				cm.config.DBName)
			dialector = postgres.Open(dsn)
		}
	default:
		return fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}

	gormConfig := &gorm.Config{
		Logger: logger.Default.LogMode(logger.Info),
		NowFunc: func() time.Time {
			return time.Now().Local()
		},
	}

	db, err := gorm.Open(dialector, gormConfig)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}
	cm.db = db

	// The underlying sql.DB is needed for pool tuning, Ping and Close.
	sqlDB, err := db.DB()
	if err != nil {
		return fmt.Errorf("failed to get sql.DB: %w", err)
	}
	cm.sqlDB = sqlDB

	return nil
}

// configure applies fixed connection-pool limits and pings the database to
// verify that the connection actually works.
func (cm *ConnectionManager) configure() error {
	if cm.sqlDB == nil {
		return fmt.Errorf("sql.DB is nil")
	}

	cm.sqlDB.SetMaxIdleConns(10)
	cm.sqlDB.SetMaxOpenConns(100)
	cm.sqlDB.SetConnMaxLifetime(time.Hour)

	if err := cm.sqlDB.Ping(); err != nil {
		return fmt.Errorf("failed to ping database: %w", err)
	}

	return nil
}

// GetDB returns the GORM database instance.
func (cm *ConnectionManager) GetDB() *gorm.DB {
	return cm.db
}

// GetSQLDB returns the underlying *sql.DB pool.
func (cm *ConnectionManager) GetSQLDB() *sql.DB {
	return cm.sqlDB
}

// Close closes the underlying connection pool. It is a no-op if the pool was
// never opened.
func (cm *ConnectionManager) Close() error {
	if cm.sqlDB != nil {
		return cm.sqlDB.Close()
	}
	return nil
}

// Health pings the database and returns an error if it is unreachable or the
// manager was never initialized.
func (cm *ConnectionManager) Health() error {
	if cm.sqlDB == nil {
		return fmt.Errorf("database not initialized")
	}
	return cm.sqlDB.Ping()
}

// quoteIdent quotes an identifier obtained from the database's own catalog
// (for example a column name returned by GetTableColumns) using the dialect's
// quoting rules, doubling any embedded quote character. Because such names are
// exact, quoting keeps mixed-case (PostgreSQL) and reserved-word names usable.
//
// NOTE(review): GORM treats '?' in raw SQL as a placeholder, so a column name
// containing '?' would still confuse the query builder.
func (cm *ConnectionManager) quoteIdent(name string) string {
	if cm.config.Type == "mysql" {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// queryStrings runs a query whose result set has a single string column and
// returns every value in order.
func (cm *ConnectionManager) queryStrings(query string, args ...any) ([]string, error) {
	rows, err := cm.db.Raw(query, args...).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var values []string
	for rows.Next() {
		var v string
		if err := rows.Scan(&v); err != nil {
			return nil, err
		}
		values = append(values, v)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return values, nil
}

// GetTableNames returns the user tables of the connected database: non-internal
// tables for SQLite, the result of SHOW TABLES for MySQL, and tables in the
// "public" schema for PostgreSQL.
func (cm *ConnectionManager) GetTableNames() ([]string, error) {
	var query string
	switch cm.config.Type {
	case "sqlite":
		query = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
	case "mysql":
		query = "SHOW TABLES"
	case "postgres":
		query = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
	default:
		return nil, fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}
	return cm.queryStrings(query)
}

// GetTableColumns returns column metadata for tableName, in the order the
// database reports it. For PostgreSQL only the "public" schema is consulted,
// matching GetTableNames.
func (cm *ConnectionManager) GetTableColumns(tableName string) ([]ColumnInfo, error) {
	var (
		query string
		args  []any
	)

	switch cm.config.Type {
	case "sqlite":
		// PRAGMA does not accept bound parameters; validate before splicing.
		if err := validateIdentifier("table", tableName); err != nil {
			return nil, err
		}
		query = fmt.Sprintf("PRAGMA table_info(%s)", tableName)
	case "mysql":
		// DESCRIBE does not accept bound parameters; validate before splicing.
		if err := validateIdentifier("table", tableName); err != nil {
			return nil, err
		}
		query = fmt.Sprintf("DESCRIBE %s", tableName)
	case "postgres":
		query = `
			SELECT column_name, data_type, is_nullable, column_default
			FROM information_schema.columns
			WHERE table_schema = 'public' AND table_name = ?
			ORDER BY ordinal_position
		`
		args = []any{tableName}
	default:
		return nil, fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}

	rows, err := cm.db.Raw(query, args...).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var columns []ColumnInfo
	for rows.Next() {
		col, err := cm.scanColumn(rows)
		if err != nil {
			return nil, err
		}
		columns = append(columns, col)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return columns, nil
}

// scanColumn decodes the current row of the dialect-specific column listing
// issued by GetTableColumns.
func (cm *ConnectionManager) scanColumn(rows *sql.Rows) (ColumnInfo, error) {
	col := ColumnInfo{DatabaseType: cm.config.Type}

	var err error
	switch cm.config.Type {
	case "sqlite":
		// PRAGMA table_info yields: cid, name, type, notnull, dflt_value, pk.
		var pk any
		err = rows.Scan(&col.Position, &col.Name, &col.Type, &col.NotNull, &col.DefaultValue, &pk)
	case "mysql":
		// DESCRIBE yields: Field, Type, Null, Key, Default, Extra.
		var nullable, key, extra string
		err = rows.Scan(&col.Name, &col.Type, &nullable, &key, &col.DefaultValue, &extra)
		col.NotNull = nullable == "NO"
	case "postgres":
		var nullable string
		err = rows.Scan(&col.Name, &col.Type, &nullable, &col.DefaultValue)
		col.NotNull = nullable == "NO"
	}
	if err != nil {
		return ColumnInfo{}, err
	}

	// Drivers may return text defaults as []byte; normalize to string so the
	// value encodes as text (not base64) in JSON.
	if b, ok := col.DefaultValue.([]byte); ok {
		col.DefaultValue = string(b)
	}
	return col, nil
}

// ColumnInfo describes one column of a table as reported by the database's
// catalog.
type ColumnInfo struct {
	Name string
	// Type is the column type name exactly as the database reports it.
	Type    string
	NotNull bool
	// DefaultValue is the column default; nil when the catalog reports none.
	DefaultValue any
	// Position is only populated for SQLite (the 0-based cid); 0 otherwise.
	Position int
	// DatabaseType is the dialect that produced this entry.
	DatabaseType string
}

// GetTableData returns one page of rows from tableName together with the
// total number of rows matching the search.
//
// page is 1-based; values below 1 are treated as 1. When search is non-empty
// it is matched (as a bound LIKE '%search%' parameter, so '%' and '_' inside
// it still act as wildcards) against every text-like column, OR-ed together;
// if the table has no such column the search is ignored. When sortField is
// set the page is ordered by it, descending if sortOrder is "desc"
// (case-insensitive) and ascending otherwise. tableName and sortField must be
// plain identifiers; anything else is rejected to prevent SQL injection.
func (cm *ConnectionManager) GetTableData(
	tableName string,
	page, pageSize int,
	search string,
	sortField string,
	sortOrder string,
) ([]map[string]any, int64, error) {
	if err := validateIdentifier("table", tableName); err != nil {
		return nil, 0, err
	}
	if sortField != "" {
		if err := validateIdentifier("sort field", sortField); err != nil {
			return nil, 0, err
		}
	}

	// The same filter (and bound arguments) is shared by the count and data
	// queries, so column metadata is fetched only once.
	var (
		where string
		args  []any
	)
	if search != "" {
		columns, err := cm.GetTableColumns(tableName)
		if err != nil {
			return nil, 0, err
		}
		where, args = cm.searchCondition(columns, search)
	}

	var total int64
	countQuery := "SELECT COUNT(*) FROM " + tableName + where
	if err := cm.db.Raw(countQuery, args...).Scan(&total).Error; err != nil {
		return nil, 0, err
	}

	dataQuery := "SELECT * FROM " + tableName + where
	if sortField != "" {
		order := "ASC"
		if strings.EqualFold(sortOrder, "desc") {
			order = "DESC"
		}
		dataQuery += " ORDER BY " + sortField + " " + order
	}

	// A page below 1 would yield a negative OFFSET, which MySQL and
	// PostgreSQL reject.
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize
	dataQuery += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)

	rows, err := cm.db.Raw(dataQuery, args...).Rows()
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	data, err := scanRowMaps(rows)
	if err != nil {
		return nil, 0, err
	}
	return data, total, nil
}

// searchCondition builds a " WHERE ..." clause that LIKE-matches search
// against every searchable column, returning the clause and its bound
// arguments. It returns "" and nil when search is empty or no column is
// searchable. The search term is always bound, never spliced into the SQL.
func (cm *ConnectionManager) searchCondition(columns []ColumnInfo, search string) (string, []any) {
	if search == "" {
		return "", nil
	}

	pattern := "%" + search + "%"
	var (
		conditions []string
		args       []any
	)
	for _, col := range columns {
		if cm.isSearchableColumn(col.Type) {
			conditions = append(conditions, cm.quoteIdent(col.Name)+" LIKE ?")
			args = append(args, pattern)
		}
	}
	if len(conditions) == 0 {
		return "", nil
	}
	return " WHERE " + strings.Join(conditions, " OR "), args
}

// scanRowMaps reads all remaining rows into column-name → value maps. It
// returns a nil slice when there are no rows.
func scanRowMaps(rows *sql.Rows) ([]map[string]any, error) {
	columnNames, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	var result []map[string]any
	for rows.Next() {
		row, err := scanRowMap(rows, columnNames)
		if err != nil {
			return nil, err
		}
		result = append(result, row)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}

// scanRowMap scans the current row into a map keyed by columnNames. NULLs
// become nil and []byte values are converted to string for JSON compatibility.
func scanRowMap(rows *sql.Rows, columnNames []string) (map[string]any, error) {
	values := make([]any, len(columnNames))
	ptrs := make([]any, len(columnNames))
	for i := range values {
		ptrs[i] = &values[i]
	}
	if err := rows.Scan(ptrs...); err != nil {
		return nil, err
	}

	row := make(map[string]any, len(columnNames))
	for i, name := range columnNames {
		if b, ok := values[i].([]byte); ok {
			row[name] = string(b)
		} else {
			row[name] = values[i]
		}
	}
	return row, nil
}

// isSearchableColumn reports whether a column type looks textual: its name
// contains CHAR (which also covers VARCHAR), TEXT or STRING, case-insensitively.
func (cm *ConnectionManager) isSearchableColumn(columnType string) bool {
	upper := strings.ToUpper(columnType)
	for _, marker := range []string{"CHAR", "TEXT", "STRING"} {
		if strings.Contains(upper, marker) {
			return true
		}
	}
	return false
}

// GetTableRowCount returns the total number of rows in tableName, which must
// be a plain identifier.
func (cm *ConnectionManager) GetTableRowCount(tableName string) (int64, error) {
	if err := validateIdentifier("table", tableName); err != nil {
		return 0, err
	}

	var count int64
	if err := cm.db.Raw("SELECT COUNT(*) FROM " + tableName).Scan(&count).Error; err != nil {
		return 0, err
	}
	return count, nil
}

// GetTableDataByID returns the first row of tableName whose key column equals
// id. The key column is "id" or "ID" if present, otherwise the table's first
// column. It returns ErrRecordNotFound when no row matches, and an error if
// the table reports no columns.
func (cm *ConnectionManager) GetTableDataByID(tableName string, id any) (map[string]any, error) {
	if err := validateIdentifier("table", tableName); err != nil {
		return nil, err
	}

	columns, err := cm.GetTableColumns(tableName)
	if err != nil {
		return nil, err
	}
	if len(columns) == 0 {
		return nil, fmt.Errorf("table %q has no columns or does not exist", tableName)
	}

	// Heuristic: prefer a column literally named id/ID, else the first column.
	primaryKey := columns[0].Name
	for _, col := range columns {
		if col.Name == "id" || col.Name == "ID" {
			primaryKey = col.Name
			break
		}
	}

	query := fmt.Sprintf("SELECT * FROM %s WHERE %s = ?", tableName, cm.quoteIdent(primaryKey))
	rows, err := cm.db.Raw(query, id).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	columnNames, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrRecordNotFound
	}
	return scanRowMap(rows, columnNames)
}