483 lines
11 KiB
Go
483 lines
11 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/rogeecn/database_render/internal/config"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// ConnectionManager manages database connections
|
|
type ConnectionManager struct {
|
|
db *gorm.DB
|
|
sqlDB *sql.DB
|
|
config *config.DatabaseConfig
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewConnectionManager opens the database described by cfg.Database, applies
// the connection-pool settings and verifies connectivity with a ping.
//
// If configuration or the ping fails after the connection was opened, the
// pool is closed before the error is returned, so a failed call never leaks
// connections and the caller has nothing to clean up.
func NewConnectionManager(cfg *config.Config) (*ConnectionManager, error) {
	logger := slog.With("component", "database")

	cm := &ConnectionManager{
		config: &cfg.Database,
		logger: logger,
	}

	if err := cm.connect(); err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	if err := cm.configure(); err != nil {
		// connect succeeded, so a pool is open; release it rather than leak it.
		if closeErr := cm.Close(); closeErr != nil {
			logger.Warn("failed to close database after configuration error", "error", closeErr)
		}
		return nil, fmt.Errorf("failed to configure database: %w", err)
	}

	logger.Info("database connection established",
		"type", cfg.Database.Type,
		"host", cfg.Database.Host,
		"database", cfg.Database.DBName,
	)

	return cm, nil
}
|
|
|
|
// connect establishes the database connection
|
|
func (cm *ConnectionManager) connect() error {
|
|
var dialector gorm.Dialector
|
|
|
|
switch cm.config.Type {
|
|
case "sqlite":
|
|
dialector = sqlite.Open(cm.config.Path)
|
|
case "mysql":
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
cm.config.User, cm.config.Password, cm.config.Host, cm.config.Port, cm.config.DBName)
|
|
dialector = mysql.Open(dsn)
|
|
case "postgres":
|
|
if cm.config.DSN != "" {
|
|
dialector = postgres.Open(cm.config.DSN)
|
|
} else {
|
|
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
|
cm.config.Host, cm.config.Port, cm.config.User, cm.config.Password, cm.config.DBName)
|
|
dialector = postgres.Open(dsn)
|
|
}
|
|
default:
|
|
return fmt.Errorf("unsupported database type: %s", cm.config.Type)
|
|
}
|
|
|
|
gormConfig := &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
NowFunc: func() time.Time {
|
|
return time.Now().Local()
|
|
},
|
|
}
|
|
|
|
db, err := gorm.Open(dialector, gormConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
cm.db = db
|
|
|
|
// Get underlying sql.DB for connection pooling
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get sql.DB: %w", err)
|
|
}
|
|
cm.sqlDB = sqlDB
|
|
|
|
return nil
|
|
}
|
|
|
|
// configure sets up connection pool settings
|
|
func (cm *ConnectionManager) configure() error {
|
|
if cm.sqlDB == nil {
|
|
return fmt.Errorf("sql.DB is nil")
|
|
}
|
|
|
|
// Connection pool settings
|
|
cm.sqlDB.SetMaxIdleConns(10)
|
|
cm.sqlDB.SetMaxOpenConns(100)
|
|
cm.sqlDB.SetConnMaxLifetime(time.Hour)
|
|
|
|
// Ping to verify connection
|
|
if err := cm.sqlDB.Ping(); err != nil {
|
|
return fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetDB returns the GORM database instance
|
|
func (cm *ConnectionManager) GetDB() *gorm.DB {
|
|
return cm.db
|
|
}
|
|
|
|
// GetSQLDB exposes the database/sql connection pool that backs the GORM
// handle, for callers that need plain database/sql access.
func (cm *ConnectionManager) GetSQLDB() *sql.DB {
	pool := cm.sqlDB
	return pool
}
|
|
|
|
// Close closes the database connection
|
|
func (cm *ConnectionManager) Close() error {
|
|
if cm.sqlDB != nil {
|
|
return cm.sqlDB.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Health reports whether the database is reachable by pinging the pool. If no
// pool has been opened, it fails immediately without attempting a ping.
func (cm *ConnectionManager) Health() error {
	if pool := cm.sqlDB; pool != nil {
		return pool.Ping()
	}
	return fmt.Errorf("database not initialized")
}
|
|
|
|
// GetTableNames lists the user tables of the connected database.
//
// The catalog that is queried depends on the backend:
//   - SQLite: sqlite_master, excluding the internal sqlite_* tables.
//   - MySQL: SHOW TABLES.
//   - PostgreSQL: pg_tables, restricted to the "public" schema.
//
// An error is returned for any other backend type, and also when reading the
// result set fails part-way.
func (cm *ConnectionManager) GetTableNames() ([]string, error) {
	var query string
	switch cm.config.Type {
	case "sqlite":
		query = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
	case "mysql":
		query = "SHOW TABLES"
	case "postgres":
		query = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
	default:
		return nil, fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}

	rows, err := cm.db.Raw(query).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tableNames []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		tableNames = append(tableNames, name)
	}
	// rows.Next returns false both when the rows are exhausted and when
	// iteration fails; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return tableNames, nil
}
|
|
|
|
// GetTableColumns describes the columns of tableName, in table order.
//
// The metadata source depends on the backend:
//   - SQLite: PRAGMA table_info.
//   - MySQL: DESCRIBE.
//   - PostgreSQL: information_schema.columns, limited to the "public" schema
//     (the same schema GetTableNames lists).
//
// Every returned ColumnInfo has Position set to its zero-based index and
// DatabaseType set to the configured backend. Depending on the backend, a
// table that does not exist may produce either an error or an empty result.
func (cm *ConnectionManager) GetTableColumns(tableName string) ([]ColumnInfo, error) {
	var (
		rows *sql.Rows
		err  error
		// scan reads one metadata row into col; the row layout is backend specific.
		scan func(r *sql.Rows, col *ColumnInfo) error
	)

	switch cm.config.Type {
	case "sqlite":
		// The name is embedded as a double-quoted identifier with embedded
		// quotes doubled, so it cannot break out and inject SQL.
		rows, err = cm.db.Raw(`PRAGMA table_info("` + strings.ReplaceAll(tableName, `"`, `""`) + `")`).Rows()
		scan = func(r *sql.Rows, col *ColumnInfo) error {
			// table_info columns: cid, name, type, notnull, dflt_value, pk.
			var pk interface{}
			return r.Scan(&col.Position, &col.Name, &col.Type, &col.NotNull, &col.DefaultValue, &pk)
		}

	case "mysql":
		// Backtick-quoted identifier with embedded backticks doubled.
		rows, err = cm.db.Raw("DESCRIBE `" + strings.ReplaceAll(tableName, "`", "``") + "`").Rows()
		scan = func(r *sql.Rows, col *ColumnInfo) error {
			// DESCRIBE columns: Field, Type, Null, Key, Default, Extra.
			var nullable, key, extra string
			if err := r.Scan(&col.Name, &col.Type, &nullable, &key, &col.DefaultValue, &extra); err != nil {
				return err
			}
			col.NotNull = nullable == "NO"
			return nil
		}

	case "postgres":
		const query = `
			SELECT column_name, data_type, is_nullable, column_default
			FROM information_schema.columns
			WHERE table_schema = 'public' AND table_name = ?
			ORDER BY ordinal_position`
		rows, err = cm.db.Raw(query, tableName).Rows()
		scan = func(r *sql.Rows, col *ColumnInfo) error {
			var nullable string
			if err := r.Scan(&col.Name, &col.Type, &nullable, &col.DefaultValue); err != nil {
				return err
			}
			col.NotNull = nullable == "NO"
			return nil
		}

	default:
		return nil, fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}

	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var columns []ColumnInfo
	for rows.Next() {
		// SQLite overwrites Position with cid, which is the same zero-based index.
		col := ColumnInfo{Position: len(columns), DatabaseType: cm.config.Type}
		if err := scan(rows, &col); err != nil {
			return nil, err
		}
		columns = append(columns, col)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return columns, nil
}
|
|
|
|
// ColumnInfo is a backend-neutral description of one table column, as
// produced by GetTableColumns.
type ColumnInfo struct {
	// Name is the column name as reported by the database.
	Name string
	// Type is the declared column type string as reported by the database.
	Type string
	// NotNull is true when the column does not accept NULL.
	NotNull bool
	// DefaultValue is the raw default as scanned from the catalog; nil when
	// the catalog reports no default. Its dynamic type depends on the driver.
	DefaultValue any
	// Position is the column's ordinal index within its table, where the
	// backend query provides one.
	Position int
	// DatabaseType names the backend ("sqlite", "mysql", "postgres") the
	// information came from.
	DatabaseType string
}
|
|
|
|
// GetTableData returns one page of rows from tableName. It also returns the
// total number of rows that match the optional search filter.
//
// Parameters:
//   - page is 1-based; values below 1 are treated as 1.
//   - pageSize must be at least 1, otherwise an error is returned.
//   - search, when non-empty, keeps rows in which any text-like column (see
//     isSearchableColumn) contains search, tested with LIKE. The pattern is
//     sent as a bound parameter, so it cannot inject SQL. '%' and '_' inside
//     search still act as LIKE wildcards.
//   - sortField, when non-empty, must name an existing column, otherwise an
//     error is returned.
//   - sortOrder "desc" (any case) sorts descending; anything else ascending.
//
// Table and column names are quoted as identifiers for the active backend.
// On PostgreSQL this makes name matching case-sensitive.
//
// Each returned row maps column name to value. []byte values are converted
// to string, and SQL NULL stays nil.
func (cm *ConnectionManager) GetTableData(
	tableName string,
	page, pageSize int,
	search string,
	sortField string,
	sortOrder string,
) ([]map[string]interface{}, int64, error) {
	if pageSize < 1 {
		return nil, 0, fmt.Errorf("invalid page size: %d", pageSize)
	}
	// A page below 1 would otherwise produce a negative OFFSET.
	page = max(page, 1)

	// Column metadata is needed to build the search filter and to validate
	// the sort field; fetch it once, and only when needed.
	var columns []ColumnInfo
	if search != "" || sortField != "" {
		cols, err := cm.GetTableColumns(tableName)
		if err != nil {
			return nil, 0, err
		}
		columns = cols
	}

	orderBy := ""
	if sortField != "" {
		// Whitelisting against real columns keeps a user-supplied sort field
		// from injecting SQL.
		if !columnExists(columns, sortField) {
			return nil, 0, fmt.Errorf("invalid sort field: %q", sortField)
		}
		direction := "ASC"
		if strings.EqualFold(sortOrder, "desc") {
			direction = "DESC"
		}
		orderBy = " ORDER BY " + cm.quoteIdentifier(sortField) + " " + direction
	}

	from := " FROM " + cm.quoteIdentifier(tableName)
	where, args := cm.buildSearchClause(columns, search)

	var total int64
	if err := cm.db.Raw("SELECT COUNT(*)"+from+where, args...).Scan(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	dataQuery := "SELECT *" + from + where + orderBy +
		fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)

	rows, err := cm.db.Raw(dataQuery, args...).Rows()
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	data, err := scanRowsToMaps(rows)
	if err != nil {
		return nil, 0, err
	}

	return data, total, nil
}

// buildSearchClause returns a " WHERE ..." fragment that ORs a LIKE test over
// every searchable column, together with the bound arguments it needs. When
// search is empty or no column is searchable, it returns an empty fragment
// and no arguments, so the query is left unfiltered.
func (cm *ConnectionManager) buildSearchClause(columns []ColumnInfo, search string) (string, []interface{}) {
	if search == "" {
		return "", nil
	}

	pattern := "%" + search + "%"
	var (
		conditions []string
		args       []interface{}
	)
	for _, col := range columns {
		if cm.isSearchableColumn(col.Type) {
			conditions = append(conditions, cm.quoteIdentifier(col.Name)+" LIKE ?")
			args = append(args, pattern)
		}
	}

	if len(conditions) == 0 {
		return "", nil
	}
	return " WHERE " + strings.Join(conditions, " OR "), args
}

// columnExists reports whether columns contains a column named exactly name.
func columnExists(columns []ColumnInfo, name string) bool {
	for _, col := range columns {
		if col.Name == name {
			return true
		}
	}
	return false
}

// quoteIdentifier quotes name as a SQL identifier for the active backend:
// backticks for MySQL and double quotes otherwise. Any embedded quote
// character is doubled, so the name cannot end the quoting early.
func (cm *ConnectionManager) quoteIdentifier(name string) string {
	if cm.config.Type == "mysql" {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// scanRowsToMaps drains rows into one map per row, keyed by column name.
// []byte values are converted to string for JSON compatibility; NULLs stay
// nil. Errors from rows.Err are returned rather than silently truncating
// the result.
func scanRowsToMaps(rows *sql.Rows) ([]map[string]interface{}, error) {
	columnNames, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// The scan buffers can be reused for every row: each value is copied into
	// the row map before the next Scan overwrites the buffer.
	values := make([]interface{}, len(columnNames))
	valuePtrs := make([]interface{}, len(columnNames))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	var result []map[string]interface{}
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}

		row := make(map[string]interface{}, len(columnNames))
		for i, name := range columnNames {
			if b, ok := values[i].([]byte); ok {
				row[name] = string(b)
			} else {
				row[name] = values[i] // includes nil for SQL NULL
			}
		}
		result = append(result, row)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return result, nil
}

// isSearchableColumn reports whether columnType names a textual type that
// the LIKE-based search should include. Matching is a case-insensitive
// substring test, so for example "varchar(255)", "character varying" and
// "TEXT" all qualify.
func (cm *ConnectionManager) isSearchableColumn(columnType string) bool {
	upper := strings.ToUpper(columnType)
	for _, t := range []string{"VARCHAR", "TEXT", "CHAR", "STRING"} {
		if strings.Contains(upper, t) {
			return true
		}
	}
	return false
}

// GetTableDataByID retrieves the single row of tableName whose key column
// equals id.
//
// The key column is the first column named "id" or "ID". If there is none,
// the first column is used. Only the first matching row is returned. The row
// uses the same value conversions as GetTableData. The error "record not
// found" is returned when no row matches. An error is also returned when no
// columns can be found for the table.
func (cm *ConnectionManager) GetTableDataByID(tableName string, id interface{}) (map[string]interface{}, error) {
	columns, err := cm.GetTableColumns(tableName)
	if err != nil {
		return nil, err
	}
	// Without this guard, columns[0] below would panic for an unknown table.
	if len(columns) == 0 {
		return nil, fmt.Errorf("table %q not found or has no columns", tableName)
	}

	// Assume 'id' is the primary key if it exists; otherwise fall back to the
	// first column.
	primaryKey := columns[0].Name
	for _, col := range columns {
		if col.Name == "id" || col.Name == "ID" {
			primaryKey = col.Name
			break
		}
	}

	query := "SELECT * FROM " + cm.quoteIdentifier(tableName) +
		" WHERE " + cm.quoteIdentifier(primaryKey) + " = ? LIMIT 1"

	rows, err := cm.db.Raw(query, id).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	records, err := scanRowsToMaps(rows)
	if err != nil {
		return nil, err
	}
	if len(records) == 0 {
		return nil, fmt.Errorf("record not found")
	}

	return records[0], nil
}
|