Files
xian_algorithm_new/app/utils/db_helper.py
T
2026-05-08 16:33:57 +08:00

151 lines
4.4 KiB
Python

"""
PostgreSQL 数据库工具类
提供增删改查方法
"""
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Dict, Any, Optional, Tuple
from contextlib import contextmanager
from config import settings
class PostgresSQLHelper:
"""PostgreSQL 数据库帮助类"""
def __init__(self):
"""初始化数据库连接配置"""
self.db_config = {
'host': settings.DB_HOST,
'port': settings.DB_PORT,
'user': settings.DB_USER,
'password': settings.DB_PASSWORD,
'database': settings.DB_NAME,
}
@contextmanager
def get_connection(self):
"""
获取数据库连接的上下文管理器
自动管理连接的开启和关闭
"""
conn = None
try:
conn = psycopg2.connect(**self.db_config)
yield conn
conn.commit()
except Exception as e:
if conn:
conn.rollback()
raise e
finally:
if conn:
conn.close()
@contextmanager
def get_cursor(self, dict_cursor=False):
"""
获取数据库游标的上下文管理器
Args:
dict_cursor: 是否使用字典游标(返回字典格式结果)
"""
with self.get_connection() as conn:
cursor = None
try:
if dict_cursor:
cursor = conn.cursor(cursor_factory=RealDictCursor)
else:
cursor = conn.cursor()
yield cursor
finally:
if cursor:
cursor.close()
def execute_query(self, sql: str, params: tuple = None) -> List[Dict[str, Any]]:
"""
执行查询语句,返回字典列表
Args:
sql: SQL 查询语句
params: 参数元组
Returns:
查询结果列表,每个元素为字典
"""
with self.get_cursor(dict_cursor=True) as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()
# 将 RealDictRow 转换为普通字典
return [dict(row) for row in results]
def execute_query_one(self, sql: str, params: tuple = None) -> Optional[Dict[str, Any]]:
"""
执行查询语句,返回单条记录
Args:
sql: SQL 查询语句
params: 参数元组
Returns:
单条记录的字典,如果没有结果则返回 None
"""
with self.get_cursor(dict_cursor=True) as cursor:
cursor.execute(sql, params)
result = cursor.fetchone()
return dict(result) if result else None
def execute_update(self, sql: str, params: tuple = None) -> int:
"""
执行更新/插入/删除语句
Args:
sql: SQL 语句
params: 参数元组
Returns:
影响的行数
"""
with self.get_cursor() as cursor:
cursor.execute(sql, params)
return cursor.rowcount
def execute_insert(self, sql: str, params: tuple = None, returning: str = None) -> Any:
"""
执行插入语句并返回新生成的 ID 或指定字段
Args:
sql: INSERT SQL 语句
params: 参数元组
returning: RETURNING 子句指定的字段名,默认为 'id'
Returns:
新生成的 ID 或指定字段的值
"""
if returning and not sql.upper().endswith(f'RETURNING {returning}'.upper()):
sql = f"{sql} RETURNING {returning}"
with self.get_cursor() as cursor:
cursor.execute(sql, params)
result = cursor.fetchone()
return result[0] if result else None
def execute_many(self, sql: str, params_list: List[tuple]) -> int:
"""
批量执行 SQL 语句
Args:
sql: SQL 语句
params_list: 参数列表,每个元素为参数元组
Returns:
影响的总行数
"""
with self.get_cursor() as cursor:
cursor.executemany(sql, params_list)
return cursor.rowcount
# 创建全局实例
db_helper = PostgresSQLHelper()