Files
xian_algorithm_new/app/core/database.py
T
2026-05-05 19:49:12 +08:00

256 lines
7.8 KiB
Python

"""
数据库连接管理 - 使用SQLAlchemy 2.0
"""
import logging
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple

from sqlalchemy import create_engine, text
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from sqlalchemy.pool import QueuePool

from app.config.settings import get_settings
from app.utils.logger import setup_logging
# Initialise application logging.
logger = setup_logging()
# Quiet SQLAlchemy's engine and pool loggers: only WARNING and above is kept.
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.pool').setLevel(logging.WARNING)
# Load the application settings.
settings = get_settings()
# Create the database engine. Pool sizing and timeouts all come from settings.
engine = create_engine(
    settings.DATABASE_URL,
    poolclass=QueuePool,
    pool_size=settings.DB_POOL_SIZE,
    max_overflow=settings.DB_MAX_OVERFLOW,
    pool_timeout=settings.DB_POOL_TIMEOUT,
    pool_recycle=settings.DB_POOL_RECYCLE,
    pool_pre_ping=True,
    echo=False, # do not print SQL statements
    # NOTE(review): the "options" / client_encoding argument looks
    # PostgreSQL-specific; confirm DATABASE_URL always targets PostgreSQL.
    connect_args={
        "options": "-c client_encoding=UTF8"
    }
)
# Create the session factory; sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class.
class Base(DeclarativeBase):
    """Declarative base that ORM models in this project subclass.

    SQLAlchemy 2.0's ``DeclarativeBase`` must be subclassed rather than
    instantiated: it is the subclass that receives the ``metadata`` and
    ``registry`` attributes. ``DatabaseManager.init_db`` relies on
    ``Base.metadata`` to create the tables.
    """
class DatabaseManager:
    """Database manager offering generic CRUD helpers built on raw SQL.

    Values are always passed as bound parameters. Identifiers, however
    (table names, column names and ``order_by``), are interpolated into the
    SQL text unchanged.

    WARNING: never pass identifiers that originate from untrusted input;
    they are not validated or quoted here.
    """

    def __init__(self):
        """Bind the manager to the module-level engine and session factory."""
        self.engine = engine
        self.SessionLocal = SessionLocal

    @contextmanager
    def get_session(self) -> "Iterator[Session]":
        """Yield a session wrapped in a commit-or-rollback scope.

        The session is committed when the ``with`` body finishes normally.
        If the body raises an ``Exception`` the session is rolled back, the
        error is logged and the exception is re-raised. The session is
        closed in every case.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"数据库操作失败: {e}")
            raise
        finally:
            session.close()

    def init_db(self):
        """Create all tables registered on ``Base.metadata``.

        Raises:
            Exception: any error from ``create_all`` is logged and re-raised.
        """
        try:
            Base.metadata.create_all(bind=self.engine)
            logger.info("数据库表创建成功")
        except Exception as e:
            logger.error(f"数据库表创建失败: {e}")
            raise

    def test_connection(self) -> bool:
        """Check connectivity by running ``SELECT 1``.

        Returns:
            True when the query succeeds, False when any exception occurs
            (the failure is logged, not raised).
        """
        try:
            with self.get_session() as session:
                session.execute(text("SELECT 1"))
            logger.info("数据库连接测试成功")
            return True
        except Exception as e:
            logger.error(f"数据库连接测试失败: {e}")
            return False

    @staticmethod
    def _build_where(
        conditions: Optional[Dict[str, Any]],
    ) -> Tuple[str, Dict[str, Any]]:
        """Build an AND-ed equality WHERE clause and its bind parameters.

        Bind names carry a ``cond_`` prefix so they cannot clash with other
        parameters of the same statement (SET values, ``limit``, ``offset``).
        A ``None`` value is rendered as ``IS NULL`` because ``col = NULL``
        never matches any row in SQL.

        Args:
            conditions: Mapping of column name to required value.

        Returns:
            ``(clause, params)``. ``clause`` starts with ``" WHERE "`` or is
            an empty string when there are no conditions.
        """
        if not conditions:
            return "", {}
        clauses = []
        params: Dict[str, Any] = {}
        for column, value in conditions.items():
            if value is None:
                clauses.append(f"{column} IS NULL")
            else:
                bind_name = f"cond_{column}"
                clauses.append(f"{column} = :{bind_name}")
                params[bind_name] = value
        return " WHERE " + " AND ".join(clauses), params

    def execute_raw_sql(
        self, sql: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Execute a raw SQL statement.

        Args:
            sql: SQL text; may contain ``:name`` bind parameters.
            params: Values for the bind parameters.

        Returns:
            One dict per row (column name -> value) when the statement
            returns rows, otherwise an empty list.
        """
        with self.get_session() as session:
            result = session.execute(text(sql), params or {})
            if not result.returns_rows:
                return []
            # Materialise the rows before the session is closed.
            columns = result.keys()
            return [dict(zip(columns, row)) for row in result.fetchall()]

    def insert(self, table_name: str, data: Dict[str, Any]) -> int:
        """Insert a single record.

        Args:
            table_name: Target table.
            data: Mapping of column name to value.

        Returns:
            The row count reported for the statement.

        Raises:
            ValueError: if ``data`` is empty.
        """
        if not data:
            raise ValueError("insert() requires at least one column in data")
        columns = ", ".join(data)
        placeholders = ", ".join(f":{key}" for key in data)
        sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        with self.get_session() as session:
            result = session.execute(text(sql), data)
            return result.rowcount

    def insert_many(self, table_name: str, data_list: List[Dict[str, Any]]) -> int:
        """Insert several records with one parameterised statement.

        The column list is taken from the first dict, so every dict is
        expected to provide the same keys.

        Args:
            table_name: Target table.
            data_list: One dict per row to insert.

        Returns:
            The row count reported for the statement; 0 when ``data_list``
            is empty (no statement is executed).
        """
        if not data_list:
            return 0
        first_row = data_list[0]
        columns = ", ".join(first_row)
        placeholders = ", ".join(f":{key}" for key in first_row)
        sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        with self.get_session() as session:
            result = session.execute(text(sql), data_list)
            return result.rowcount

    def select(
        self,
        table_name: str,
        conditions: Optional[Dict[str, Any]] = None,
        columns: Optional[List[str]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        order_by: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Query records.

        Args:
            table_name: Table to read from.
            conditions: Equality filters combined with AND; a ``None`` value
                matches rows where the column IS NULL.
            columns: Columns to fetch; ``None`` or empty selects ``*``.
            limit: Maximum number of rows. ``None`` means no limit, while
                ``0`` is passed through and returns no rows.
            offset: Number of rows to skip.
            order_by: Raw ORDER BY expression.

        Returns:
            One dict per matching row.
        """
        column_list = ", ".join(columns) if columns else "*"
        where_sql, params = self._build_where(conditions)
        sql = f"SELECT {column_list} FROM {table_name}{where_sql}"
        if order_by:
            sql += f" ORDER BY {order_by}"
        # Compare with None so that an explicit limit of 0 is honoured.
        if limit is not None:
            sql += " LIMIT :limit"
            params["limit"] = limit
        # An offset of 0 is equivalent to omitting the clause.
        if offset:
            sql += " OFFSET :offset"
            params["offset"] = offset
        return self.execute_raw_sql(sql, params)

    def update(self, table_name: str, data: Dict[str, Any],
               conditions: Dict[str, Any]) -> int:
        """Update records matching ``conditions``.

        Args:
            table_name: Target table.
            data: Mapping of column name to new value.
            conditions: Equality filters combined with AND; a ``None`` value
                matches rows where the column IS NULL.

        Returns:
            The row count reported for the statement.

        Raises:
            ValueError: if ``data`` or ``conditions`` is empty. An
                unconditional update must be written explicitly through
                ``execute_raw_sql``.
        """
        if not data:
            raise ValueError("update() requires at least one column in data")
        if not conditions:
            raise ValueError("update() requires at least one condition")
        set_sql = ", ".join(f"{key} = :{key}" for key in data)
        where_sql, where_params = self._build_where(conditions)
        sql = f"UPDATE {table_name} SET {set_sql}{where_sql}"
        # Condition binds are prefixed, so they do not overwrite SET values.
        params = {**data, **where_params}
        with self.get_session() as session:
            result = session.execute(text(sql), params)
            return result.rowcount

    def delete(self, table_name: str, conditions: Dict[str, Any]) -> int:
        """Delete records matching ``conditions``.

        Args:
            table_name: Target table.
            conditions: Equality filters combined with AND; a ``None`` value
                matches rows where the column IS NULL.

        Returns:
            The row count reported for the statement.

        Raises:
            ValueError: if ``conditions`` is empty. Deleting every row must
                be written explicitly through ``execute_raw_sql``.
        """
        if not conditions:
            raise ValueError("delete() requires at least one condition")
        where_sql, params = self._build_where(conditions)
        sql = f"DELETE FROM {table_name}{where_sql}"
        with self.get_session() as session:
            result = session.execute(text(sql), params)
            return result.rowcount

    def count(self, table_name: str,
              conditions: Optional[Dict[str, Any]] = None) -> int:
        """Count records.

        Args:
            table_name: Table to count rows in.
            conditions: Equality filters combined with AND; a ``None`` value
                matches rows where the column IS NULL.

        Returns:
            Number of matching rows; 0 when the query yields no result row.
        """
        where_sql, params = self._build_where(conditions)
        sql = f"SELECT COUNT(*) as count FROM {table_name}{where_sql}"
        rows = self.execute_raw_sql(sql, params)
        return rows[0]["count"] if rows else 0
# Create the global database manager instance (built when this module is imported).
db_manager = DatabaseManager()