Files
database_render/internal/database/connection.go
2025-08-05 17:26:59 +08:00

483 lines
11 KiB
Go

package database
import (
"database/sql"
"fmt"
"log/slog"
"strings"
"time"
"github.com/rogeecn/database_render/internal/config"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// ConnectionManager owns one database connection, exposing both the GORM
// handle and the *sql.DB pool that backs it.
type ConnectionManager struct {
	db     *gorm.DB               // GORM handle, assigned by connect
	sqlDB  *sql.DB                // pool behind db; tuned by configure, pinged by Health, released by Close
	config *config.DatabaseConfig // connection settings (type, host, credentials, ...)
	logger *slog.Logger           // logger tagged with component=database
}
// NewConnectionManager opens the database described by cfg.Database, applies
// the connection-pool settings and verifies the connection with a ping.
//
// If anything fails after the connection has been opened, the pool is closed
// before the error is returned, so a failed constructor never leaks
// connections. A nil cfg is rejected with an error.
func NewConnectionManager(cfg *config.Config) (*ConnectionManager, error) {
	if cfg == nil {
		return nil, fmt.Errorf("failed to connect to database: nil config")
	}

	logger := slog.With("component", "database")
	cm := &ConnectionManager{
		config: &cfg.Database,
		logger: logger,
	}

	if err := cm.connect(); err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
	if err := cm.configure(); err != nil {
		// connect succeeded, so a pool is open; release it rather than leak it.
		if cerr := cm.Close(); cerr != nil {
			logger.Warn("closing database after configure failure", "error", cerr)
		}
		return nil, fmt.Errorf("failed to configure database: %w", err)
	}

	logger.Info("database connection established",
		"type", cfg.Database.Type,
		"host", cfg.Database.Host,
		"database", cfg.Database.DBName,
	)
	return cm, nil
}
// connect opens the GORM connection for the configured database type and
// stores both the *gorm.DB and its underlying *sql.DB on cm.
//
// Supported types:
//   - "sqlite":   opens cm.config.Path
//   - "mysql":    uses cm.config.DSN when set; otherwise builds a DSN from
//     User/Password/Host/Port/DBName (utf8mb4, parseTime, loc=Local)
//   - "postgres": uses cm.config.DSN when set; otherwise builds a keyword/value
//     DSN from Host/Port/User/Password/DBName with sslmode=disable
//
// Any other type returns an error. GORM's default logger is used at Info level,
// and timestamps GORM generates use local time.
func (cm *ConnectionManager) connect() error {
	var dialector gorm.Dialector
	switch cm.config.Type {
	case "sqlite":
		dialector = sqlite.Open(cm.config.Path)
	case "mysql":
		dsn := cm.config.DSN
		if dsn == "" {
			dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
				cm.config.User, cm.config.Password, cm.config.Host, cm.config.Port, cm.config.DBName)
		}
		dialector = mysql.Open(dsn)
	case "postgres":
		dsn := cm.config.DSN
		if dsn == "" {
			// Quote every string value. Unquoted, an empty password would make the
			// parser read the following "dbname=..." pair as the password, and a
			// value containing spaces would be split.
			dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
				pgQuoteValue(cm.config.Host), cm.config.Port, pgQuoteValue(cm.config.User),
				pgQuoteValue(cm.config.Password), pgQuoteValue(cm.config.DBName))
		}
		dialector = postgres.Open(dsn)
	default:
		return fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}

	gormConfig := &gorm.Config{
		Logger: logger.Default.LogMode(logger.Info),
		NowFunc: func() time.Time {
			return time.Now().Local()
		},
	}

	db, err := gorm.Open(dialector, gormConfig)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}
	cm.db = db

	// Keep the underlying sql.DB so configure can tune the pool and Close can release it.
	sqlDB, err := db.DB()
	if err != nil {
		return fmt.Errorf("failed to get sql.DB: %w", err)
	}
	cm.sqlDB = sqlDB
	return nil
}

// pgQuoteValue renders v as a single-quoted libpq keyword/value connection
// string value, escaping backslashes and single quotes. Empty values and
// values containing whitespace or quotes then survive parsing intact.
func pgQuoteValue(v string) string {
	// Escape backslashes first so the quote escapes added next are not doubled.
	v = strings.ReplaceAll(v, `\`, `\\`)
	v = strings.ReplaceAll(v, `'`, `\'`)
	return "'" + v + "'"
}
// configure applies fixed connection-pool limits to the underlying *sql.DB
// and pings the database to confirm it is reachable.
//
// Limits: at most 100 open connections, up to 10 kept idle, and every
// connection is recycled after one hour. It returns an error when sqlDB has
// not been populated by connect, or when the ping fails.
func (cm *ConnectionManager) configure() error {
	pool := cm.sqlDB
	if pool == nil {
		return fmt.Errorf("sql.DB is nil")
	}

	const (
		maxIdle     = 10
		maxOpen     = 100
		maxLifetime = time.Hour
	)
	pool.SetMaxIdleConns(maxIdle)
	pool.SetMaxOpenConns(maxOpen)
	pool.SetConnMaxLifetime(maxLifetime)

	// A round trip proves the credentials and address actually work.
	err := pool.Ping()
	if err != nil {
		return fmt.Errorf("failed to ping database: %w", err)
	}
	return nil
}
// GetDB exposes the GORM handle that connect opened; it is nil if no
// connection was ever established.
func (cm *ConnectionManager) GetDB() *gorm.DB {
	handle := cm.db
	return handle
}
// GetSQLDB exposes the *sql.DB pool backing the GORM handle, for callers that
// need database/sql directly; it is nil if no connection was established.
func (cm *ConnectionManager) GetSQLDB() *sql.DB {
	pool := cm.sqlDB
	return pool
}
// Close releases the underlying connection pool. When no pool was ever
// opened it does nothing and returns nil.
func (cm *ConnectionManager) Close() error {
	if cm.sqlDB == nil {
		return nil
	}
	return cm.sqlDB.Close()
}
// Health pings the database and returns the ping error, if any. It returns an
// error without touching the network when no pool has been opened.
func (cm *ConnectionManager) Health() error {
	if pool := cm.sqlDB; pool != nil {
		return pool.Ping()
	}
	return fmt.Errorf("database not initialized")
}
// GetTableNames returns the names of the user tables in the database:
//   - sqlite:   tables in sqlite_master, excluding SQLite's internal "sqlite_*" tables
//   - mysql:    the output of SHOW TABLES for the connected database
//   - postgres: tables in the "public" schema
//
// An unsupported database type returns an error. Errors raised while
// iterating the result set are reported rather than silently truncating the list.
func (cm *ConnectionManager) GetTableNames() ([]string, error) {
	var query string
	switch cm.config.Type {
	case "sqlite":
		// Escape "_" so only names literally starting with "sqlite_" are excluded;
		// an unescaped "_" would be a wildcard and also hide e.g. "sqlites".
		query = `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\'`
	case "mysql":
		query = "SHOW TABLES"
	case "postgres":
		query = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
	default:
		return nil, fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}

	rows, err := cm.db.Raw(query).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tableNames []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		tableNames = append(tableNames, name)
	}
	// rows.Next returns false both at the end and on error; tell them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return tableNames, nil
}
// GetTableColumns returns column information for tableName in table order.
//
// The metadata source depends on the configured database type:
//   - sqlite:   PRAGMA table_info
//   - mysql:    DESCRIBE
//   - postgres: information_schema.columns, restricted to the "public" schema
//     (the schema GetTableNames lists), so same-named tables in other schemas
//     are not mixed in.
//
// The table name is quoted as a single identifier instead of being spliced
// into the SQL verbatim. []byte default values are converted to string.
// An unsupported database type returns an error. For sqlite and postgres an
// unknown table yields an empty result rather than an error.
func (cm *ConnectionManager) GetTableColumns(tableName string) ([]ColumnInfo, error) {
	var (
		rows *sql.Rows
		err  error
	)
	switch cm.config.Type {
	case "sqlite":
		rows, err = cm.db.Raw(fmt.Sprintf("PRAGMA table_info(%s)", cm.quoteIdent(tableName))).Rows()
	case "mysql":
		rows, err = cm.db.Raw(fmt.Sprintf("DESCRIBE %s", cm.quoteIdent(tableName))).Rows()
	case "postgres":
		query := `
		SELECT
			column_name,
			data_type,
			is_nullable,
			column_default
		FROM information_schema.columns
		WHERE table_schema = 'public' AND table_name = $1
		ORDER BY ordinal_position
		`
		rows, err = cm.db.Raw(query, tableName).Rows()
	default:
		return nil, fmt.Errorf("unsupported database type: %s", cm.config.Type)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var columns []ColumnInfo
	for rows.Next() {
		var col ColumnInfo
		var nullable string
		switch cm.config.Type {
		case "sqlite":
			// PRAGMA table_info yields: cid, name, type, notnull, dflt_value, pk.
			// cid is already the zero-based column position.
			var pk interface{}
			if err := rows.Scan(&col.Position, &col.Name, &col.Type, &col.NotNull, &col.DefaultValue, &pk); err != nil {
				return nil, err
			}
		case "mysql":
			// DESCRIBE yields: Field, Type, Null, Key, Default, Extra.
			var key, extra string
			if err := rows.Scan(&col.Name, &col.Type, &nullable, &key, &col.DefaultValue, &extra); err != nil {
				return nil, err
			}
			col.NotNull = nullable == "NO"
			col.Position = len(columns)
		default: // postgres; other types were rejected above
			if err := rows.Scan(&col.Name, &col.Type, &nullable, &col.DefaultValue); err != nil {
				return nil, err
			}
			col.NotNull = nullable == "NO"
			col.Position = len(columns)
		}
		// Some drivers return text as []byte; expose it as a string.
		if b, ok := col.DefaultValue.([]byte); ok {
			col.DefaultValue = string(b)
		}
		col.DatabaseType = cm.config.Type
		columns = append(columns, col)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return columns, nil
}

// ColumnInfo describes one column of a table as reported by GetTableColumns.
type ColumnInfo struct {
	Name         string      // column name
	Type         string      // column type as reported by the database
	NotNull      bool        // true when the column disallows NULL
	DefaultValue interface{} // default as reported by the database; nil when it reports NULL
	Position     int         // zero-based position of the column within the table
	DatabaseType string      // configured database type ("sqlite", "mysql" or "postgres")
}

// GetTableData returns one page of rows from tableName plus the total number
// of rows matching the search.
//
//   - page is 1-based; values below 1 are treated as 1.
//   - pageSize is the page length; a negative value is rejected.
//   - search, when non-empty, keeps rows where any text-like column (see
//     isSearchableColumn) contains it. If the table has no such column the
//     search is not applied. The term is bound as a query parameter, but LIKE
//     wildcards ("%", "_") inside it still act as wildcards.
//   - sortField, when non-empty, must name a column of the table (an exact
//     match is preferred, otherwise a case-insensitive one). sortOrder "desc"
//     (any case) sorts descending; anything else sorts ascending.
//
// Table and column names are quoted as identifiers, so none of the string
// arguments can inject SQL. Note that in postgres a quoted identifier is
// case-sensitive. Row values are returned as the driver produced them,
// except that []byte becomes string and NULL becomes nil.
func (cm *ConnectionManager) GetTableData(
	tableName string,
	page, pageSize int,
	search string,
	sortField string,
	sortOrder string,
) ([]map[string]interface{}, int64, error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 0 {
		return nil, 0, fmt.Errorf("invalid page size: %d", pageSize)
	}

	// Column metadata is only needed for the search filter and to validate
	// sortField; fetch it once for both.
	var columns []ColumnInfo
	if search != "" || sortField != "" {
		cols, err := cm.GetTableColumns(tableName)
		if err != nil {
			return nil, 0, err
		}
		columns = cols
	}

	table := cm.quoteIdent(tableName)

	// One bound parameter per searchable column, so the term never becomes SQL.
	var (
		where string
		args  []interface{}
	)
	if search != "" {
		pattern := "%" + search + "%"
		var conditions []string
		for _, col := range columns {
			if cm.isSearchableColumn(col.Type) {
				conditions = append(conditions, cm.quoteIdent(col.Name)+" LIKE ?")
				args = append(args, pattern)
			}
		}
		if len(conditions) > 0 {
			where = " WHERE " + strings.Join(conditions, " OR ")
		}
	}

	var total int64
	if err := cm.db.Raw("SELECT COUNT(*) FROM "+table+where, args...).Scan(&total).Error; err != nil {
		return nil, 0, err
	}

	dataQuery := "SELECT * FROM " + table + where
	if sortField != "" {
		orderColumn := ""
		for _, col := range columns {
			if col.Name == sortField {
				orderColumn = col.Name
				break
			}
			if orderColumn == "" && strings.EqualFold(col.Name, sortField) {
				orderColumn = col.Name
			}
		}
		if orderColumn == "" {
			return nil, 0, fmt.Errorf("invalid sort field: %s", sortField)
		}
		order := "ASC"
		if strings.EqualFold(sortOrder, "desc") {
			order = "DESC"
		}
		dataQuery += " ORDER BY " + cm.quoteIdent(orderColumn) + " " + order
	}
	dataQuery += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, (page-1)*pageSize)

	rows, err := cm.db.Raw(dataQuery, args...).Rows()
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	columnNames, err := rows.Columns()
	if err != nil {
		return nil, 0, err
	}

	var data []map[string]interface{}
	for rows.Next() {
		row, err := scanRowToMap(rows, columnNames)
		if err != nil {
			return nil, 0, err
		}
		data = append(data, row)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return data, total, nil
}

// isSearchableColumn reports whether columnType (compared case-insensitively)
// looks like a text type that LIKE searches should cover.
func (cm *ConnectionManager) isSearchableColumn(columnType string) bool {
	upper := strings.ToUpper(columnType)
	// "CHAR" also matches VARCHAR, CHARACTER VARYING and similar.
	for _, t := range []string{"CHAR", "TEXT", "STRING"} {
		if strings.Contains(upper, t) {
			return true
		}
	}
	return false
}

// GetTableDataByID returns the single row of tableName whose key column equals id.
//
// The key column is the first column named "id" (any case). Failing that, the
// first column of the table is used. It returns an error if the table reports
// no columns. It returns the error "record not found" when no row matches.
// NULL values become nil and []byte values become string.
func (cm *ConnectionManager) GetTableDataByID(tableName string, id interface{}) (map[string]interface{}, error) {
	columns, err := cm.GetTableColumns(tableName)
	if err != nil {
		return nil, err
	}
	if len(columns) == 0 {
		return nil, fmt.Errorf("table %s has no columns or does not exist", tableName)
	}

	primaryKey := columns[0].Name
	for _, col := range columns {
		if strings.EqualFold(col.Name, "id") {
			primaryKey = col.Name
			break
		}
	}

	query := fmt.Sprintf("SELECT * FROM %s WHERE %s = ? LIMIT 1",
		cm.quoteIdent(tableName), cm.quoteIdent(primaryKey))
	rows, err := cm.db.Raw(query, id).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	columnNames, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	if !rows.Next() {
		// Distinguish "no row" from an iteration failure.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("record not found")
	}
	return scanRowToMap(rows, columnNames)
}

// quoteIdent quotes name as one SQL identifier for the configured dialect.
// MySQL uses backticks; sqlite and postgres use double quotes. Embedded quote
// characters are doubled, so the result cannot break out of the identifier.
func (cm *ConnectionManager) quoteIdent(name string) string {
	if cm.config.Type == "mysql" {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// scanRowToMap scans the current row of rows into a map keyed by column name.
// NULL becomes nil and []byte becomes string, so the map encodes cleanly to JSON.
func scanRowToMap(rows *sql.Rows, columnNames []string) (map[string]interface{}, error) {
	values := make([]interface{}, len(columnNames))
	ptrs := make([]interface{}, len(columnNames))
	for i := range values {
		ptrs[i] = &values[i]
	}
	if err := rows.Scan(ptrs...); err != nil {
		return nil, err
	}

	row := make(map[string]interface{}, len(columnNames))
	for i, name := range columnNames {
		if b, ok := values[i].([]byte); ok {
			row[name] = string(b)
		} else {
			row[name] = values[i]
		}
	}
	return row, nil
}