From eddbdaca1fc35387f0aefa9bb6224bd14bab68e6 Mon Sep 17 00:00:00 2001 From: wzy-warehouse <18135009705@163.com> Date: Sat, 6 Jun 2026 08:38:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9A=B4=E9=9B=A8=E5=9C=B0=E9=9C=87=E7=81=BE?= =?UTF-8?q?=E5=AE=B3=E9=93=BEHTTP=E8=AF=B7=E6=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/__init__.py | 13 +++ app/api/earthquake.py | 92 +++++++++++++++++++++ app/api/rainfall.py | 83 +++++++++++++++++++ app/core/launcher.py | 18 ++++ app/core/server.py | 55 ++++++++++++ app/models/dbn/earthquake/earthquake_dbn.py | 81 ++++++++++++++---- app/models/dbn/rainfall/rainfall_dbn.py | 49 +++++++++-- app/repositories/dbn_repository.py | 76 +++++++++++++---- app/schemas/api_schemas.py | 59 +++++++++++++ app/utils/api_deps.py | 59 +++++++++++++ requirements.txt | 4 +- 11 files changed, 551 insertions(+), 38 deletions(-) create mode 100644 app/api/__init__.py create mode 100644 app/api/earthquake.py create mode 100644 app/api/rainfall.py create mode 100644 app/core/server.py create mode 100644 app/schemas/api_schemas.py create mode 100644 app/utils/api_deps.py diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..2ab7859 --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1,13 @@ +""" +API 路由模块 +""" +from fastapi import FastAPI + + +def register_routers(application: FastAPI): + """注册所有路由""" + from app.api.rainfall import router as rainfall_router + from app.api.earthquake import router as earthquake_router + + application.include_router(rainfall_router) + application.include_router(earthquake_router) diff --git a/app/api/earthquake.py b/app/api/earthquake.py new file mode 100644 index 0000000..eb5f5d0 --- /dev/null +++ b/app/api/earthquake.py @@ -0,0 +1,92 @@ +""" +地震灾害链预测接口 +""" +import asyncio +from typing import List, Dict, Any, Optional + +from fastapi import APIRouter, HTTPException + +from app.schemas.api_schemas import EarthquakePredictRequest, PredictResponse, PredictionItem +from app.utils.api_deps import get_earthquake_model, get_prediction_semaphore +from app.repositories.dbn_repository import dbn_repository +from app.config.paths import get_logger + +router = APIRouter(prefix="/earthquake", tags=["地震灾害链"]) +logger = get_logger("api.earthquake") + +SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"} +LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"} + + +def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]: + """将模型原始结果转换为接口返回格式""" + items = [] + for r in results: + probs = r.get("disaster_probabilities", {}) + levels = r.get("disaster_levels", {}) + + if not probs: + continue + + max_hazard = max(probs, key=probs.get) + items.append(PredictionItem( + id=r["point_id"], + type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"), + probability=round(probs[max_hazard], 4), + level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"), + )) + return items + + +def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: + """获取点位列表""" + if point_ids: + return dbn_repository.get_points_by_ids(point_ids) + return dbn_repository.get_all_points(region_code) + + +def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], + magnitude: float, depth: float, + epicenter_lon: float, epicenter_lat: float) -> List[PredictionItem]: + """同步执行地震预测(在线程池中运行)""" + points = _fetch_points(point_ids, region_code) + if not points: + return [] + + model = get_earthquake_model() + results = model.predict_multiple_points( + points, + magnitude=magnitude, + depth=depth, + epicenter_lon=epicenter_lon, + epicenter_lat=epicenter_lat, + ) + return _build_prediction_items(results) + + +@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测") +async def predict_earthquake(req: EarthquakePredictRequest): + """ + 根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级。 + + - **point_ids**: 点位ID列表(可选,不传则查询所有点) + - **region_code**: 行政区划代码(可选,不传则不限区域) + - **magnitude**: 震级(Richter) + - **depth**: 震源深度(km),默认10km + - **epicenter_lon**: 震中经度 + - **epicenter_lat**: 震中纬度 + """ + semaphore = get_prediction_semaphore() + + async with semaphore: + loop = asyncio.get_event_loop() + try: + items = await loop.run_in_executor( + None, _predict_sync, req.point_ids, req.region_code, + req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat + ) + except Exception as e: + logger.error(f"地震预测失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"预测失败: {e}") + + return PredictResponse(code=200, message="success", data=items) diff --git a/app/api/rainfall.py b/app/api/rainfall.py new file mode 100644 index 0000000..3ace185 --- /dev/null +++ b/app/api/rainfall.py @@ -0,0 +1,83 @@ +""" +暴雨灾害链预测接口 +""" +import asyncio +from typing import List, Dict, Any, Optional + +from fastapi import APIRouter, HTTPException + +from app.schemas.api_schemas import RainfallPredictRequest, PredictResponse, PredictionItem +from app.utils.api_deps import get_rainfall_model, get_prediction_semaphore +from app.repositories.dbn_repository import dbn_repository +from app.config.paths import get_logger + +router = APIRouter(prefix="/rainfall", tags=["暴雨灾害链"]) +logger = get_logger("api.rainfall") + +SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"} +LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"} + + +def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]: + """将模型原始结果转换为接口返回格式""" + items = [] + for r in results: + probs = r.get("disaster_probabilities", {}) + levels = r.get("disaster_levels", {}) + + if not probs: + continue + + max_hazard = max(probs, key=probs.get) + items.append(PredictionItem( + id=r["point_id"], + type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"), + probability=round(probs[max_hazard], 4), + level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"), + )) + return items + + +def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: + """获取点位列表""" + if point_ids: + return dbn_repository.get_points_by_ids(point_ids) + return dbn_repository.get_all_points(region_code) + + +def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], + rainfall: float, duration: float) -> List[PredictionItem]: + """同步执行暴雨预测(在线程池中运行)""" + points = _fetch_points(point_ids, region_code) + if not points: + return [] + + model = get_rainfall_model() + results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration) + return _build_prediction_items(results) + + +@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测") +async def predict_rainfall(req: RainfallPredictRequest): + """ + 根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率和等级。 + + - **point_ids**: 点位ID列表(可选,不传则查询所有点) + - **region_code**: 行政区划代码(可选,不传则不限区域) + - **rainfall**: 累计降雨量(mm) + - **duration**: 降雨持续时间(h) + """ + semaphore = get_prediction_semaphore() + + async with semaphore: + loop = asyncio.get_event_loop() + try: + items = await loop.run_in_executor( + None, _predict_sync, req.point_ids, req.region_code, + req.rainfall, req.duration + ) + except Exception as e: + logger.error(f"暴雨预测失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"预测失败: {e}") + + return PredictResponse(code=200, message="success", data=items) diff --git a/app/core/launcher.py b/app/core/launcher.py index 34cff02..8784446 100644 --- a/app/core/launcher.py +++ b/app/core/launcher.py @@ -51,12 +51,30 @@ class AppLauncher: def start(): """启动应用服务""" + import threading + from config import settings from app.core.rainfall_manager import rainfall_manager from app.utils.logger import get_logger from app.utils.thread_pool_manager import block_main_thread, thread_pool_manager logger = get_logger() + # 启动 FastAPI 服务(守护线程) + def run_api_server(): + import uvicorn + from app.core.server import create_app + api_app = create_app() + uvicorn.run( + api_app, + host=getattr(settings, "API_HOST", "127.0.0.1"), + port=int(getattr(settings, "API_PORT", 8082)), + log_level="info", + ) + + api_thread = threading.Thread(target=run_api_server, daemon=True, name="api-server") + api_thread.start() + logger.info(f"FastAPI 服务已启动: http://{getattr(settings, 'API_HOST', '127.0.0.1')}:{getattr(settings, 'API_PORT', 8082)}") + # 启动降雨站点监测 logger.info("启动降雨站点监测服务...") rainfall_manager.monitoring_rainfall_station_id('2025-09-16 20:00:00') diff --git a/app/core/server.py b/app/core/server.py new file mode 100644 index 0000000..76ccf8c --- /dev/null +++ b/app/core/server.py @@ -0,0 +1,55 @@ +""" +FastAPI 服务创建与配置 +""" +import time +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request + +from app.utils.api_deps import get_rainfall_model, get_earthquake_model, is_model_loaded +from app.schemas.api_schemas import HealthResponse +from app.config.paths import get_logger + +logger = get_logger("api") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期:启动时预加载模型""" + logger.info("正在预加载DBN模型...") + get_rainfall_model() + get_earthquake_model() + logger.info("DBN模型预加载完成") + yield + logger.info("应用关闭") + + +def create_app() -> FastAPI: + """创建 FastAPI 应用实例""" + application = FastAPI( + title="西安灾害链预测服务", + description="基于动态贝叶斯网络的暴雨/地震灾害链预测API", + version="1.0.0", + lifespan=lifespan, + ) + + @application.middleware("http") + async def log_requests(request: Request, call_next): + """请求日志中间件""" + start = time.time() + response = await call_next(request) + elapsed = time.time() - start + logger.info(f"{request.method} {request.url.path} -> {response.status_code} ({elapsed:.3f}s)") + return response + + # 注册路由 + from app.api import register_routers + register_routers(application) + + @application.get("/health", response_model=HealthResponse, tags=["系统"]) + async def health_check(): + """健康检查""" + status = is_model_loaded() + return HealthResponse(status="ok", **status) + + return application diff --git a/app/models/dbn/earthquake/earthquake_dbn.py b/app/models/dbn/earthquake/earthquake_dbn.py index caedaaa..0e64d51 100644 --- a/app/models/dbn/earthquake/earthquake_dbn.py +++ b/app/models/dbn/earthquake/earthquake_dbn.py @@ -18,21 +18,20 @@ logger = get_logger("earthquake_dbn") class EarthquakeDBN: """地震灾害链DBN模型""" - # 灾害概率→离散等级的阈值映射 + # 灾害概率→等级的阈值映射 HAZARD_LEVEL_THRESHOLDS = [ - (0.6, 'very_high'), - (0.4, 'high'), - (0.2, 'medium'), - (0.05, 'low'), - (0.0, 'none'), + (0.7, '高'), + (0.5, '较高'), + (0.3, '中'), + (0.0, '低'), ] def _probability_to_level(self, prob: float) -> str: - """将连续概率映射到离散等级""" + """将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)""" for threshold, level in self.HAZARD_LEVEL_THRESHOLDS: if prob >= threshold: return level - return 'none' + return '低' def __init__(self, config_dir: Optional[str] = None): """ @@ -114,18 +113,20 @@ class EarthquakeDBN: } @staticmethod - def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float) -> float: + def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float, + depth_km: float = 10.0) -> float: """ - 根据震级和震中距估算地震烈度 - 使用中国地震烈度衰减关系 + 根据震级、震中距和震源深度估算地震烈度 - I = 0.923 + 1.621*M - 3.494*ln(R+10) + I = 0.923 + 1.621*M - 3.494*ln(R+10) - ln(H/10) 参考:GB 18306-2015 中国地震动参数区划图 + 深度修正:震源越深,地表烈度越低 Args: magnitude: 震级(Richter) epicenter_distance_km: 震中距(km) + depth_km: 震源深度(km),默认10km Returns: 估算的地震烈度(中国烈度表数值) @@ -135,6 +136,10 @@ class EarthquakeDBN: intensity = 0.923 + 1.621 * magnitude - 3.494 * math.log(epicenter_distance_km + 10) + # 震源深度修正:以10km为基准,深度越大烈度衰减越多 + depth_km = max(depth_km, 1.0) + intensity -= math.log(depth_km / 10.0) + # 限制在合理范围内 return max(1.0, min(12.0, intensity)) @@ -191,7 +196,8 @@ class EarthquakeDBN: epicenter_distance: Optional[float] = None, seismic_intensity: Optional[float] = None, epicenter_lon: Optional[float] = None, - epicenter_lat: Optional[float] = None) -> Dict[str, Any]: + epicenter_lat: Optional[float] = None, + depth: float = 10.0) -> Dict[str, Any]: """ 对单个点进行地震灾害预测 @@ -202,6 +208,7 @@ class EarthquakeDBN: seismic_intensity: 地震烈度(中国烈度表),若未提供则自动估算 epicenter_lon: 震中经度(可选,用于计算震中距) epicenter_lat: 震中纬度(可选,用于计算震中距) + depth: 震源深度(km),默认10.0 Returns: 预测结果 @@ -226,7 +233,7 @@ class EarthquakeDBN: # 估算地震烈度(如果未直接提供) if seismic_intensity is None: - seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance) + seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance, depth) logger.info(f"估算地震烈度: {seismic_intensity:.1f}") # 获取静态因子数据 @@ -379,6 +386,52 @@ class EarthquakeDBN: return results + def predict_multiple_points(self, points: List[Dict[str, Any]], + magnitude: float = 6.0, + epicenter_distance: Optional[float] = None, + seismic_intensity: Optional[float] = None, + epicenter_lon: Optional[float] = None, + epicenter_lat: Optional[float] = None, + depth: float = 10.0) -> List[Dict[str, Any]]: + """ + 对已获取的点列表进行地震灾害预测 + + Args: + points: 点信息列表(已从数据库获取) + magnitude: 地震震级 + epicenter_distance: 震中距(km,可选) + seismic_intensity: 地震烈度(可选) + epicenter_lon: 震中经度(可选) + epicenter_lat: 震中纬度(可选) + depth: 震源深度(km),默认10.0 + + Returns: + 预测结果列表 + """ + results = [] + for point in points: + try: + result = self.predict_single_point( + point, + magnitude=magnitude, + epicenter_distance=epicenter_distance, + seismic_intensity=seismic_intensity, + epicenter_lon=epicenter_lon, + epicenter_lat=epicenter_lat, + depth=depth + ) + results.append(result) + except Exception as e: + logger.error(f"预测点 {point.get('id')} 失败: {e}") + results.append({ + 'point_id': point.get('id'), + 'source_type': point.get('source_type'), + 'lon': point.get('lon'), + 'lat': point.get('lat'), + 'error': str(e) + }) + return results + def get_model_info(self) -> Dict[str, Any]: """获取模型信息""" return { diff --git a/app/models/dbn/rainfall/rainfall_dbn.py b/app/models/dbn/rainfall/rainfall_dbn.py index aaa048a..f367d33 100644 --- a/app/models/dbn/rainfall/rainfall_dbn.py +++ b/app/models/dbn/rainfall/rainfall_dbn.py @@ -17,21 +17,20 @@ logger = get_logger("dbn") class RainfallDBN: """暴雨灾害链DBN模型""" - # 灾害概率→离散等级的阈值映射 + # 灾害概率→等级的阈值映射 HAZARD_LEVEL_THRESHOLDS = [ - (0.6, 'very_high'), - (0.4, 'high'), - (0.2, 'medium'), - (0.05, 'low'), - (0.0, 'none'), + (0.7, '高'), + (0.5, '较高'), + (0.3, '中'), + (0.0, '低'), ] def _probability_to_level(self, prob: float) -> str: - """将连续概率映射到离散等级""" + """将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)""" for threshold, level in self.HAZARD_LEVEL_THRESHOLDS: if prob >= threshold: return level - return 'none' + return '低' def __init__(self, config_dir: Optional[str] = None): """ @@ -374,6 +373,40 @@ class RainfallDBN: return results + def predict_multiple_points(self, points: List[Dict[str, Any]], + rainfall: Optional[float] = None, + duration: Optional[float] = None, + query_time: Optional[datetime] = None) -> List[Dict[str, Any]]: + """ + 对已获取的点列表进行暴雨灾害预测 + + Args: + points: 点信息列表(已从数据库获取) + rainfall: 累计降雨量(可选) + duration: 持续时间(可选) + query_time: 查询时间(可选) + + Returns: + 预测结果列表 + """ + results = [] + for point in points: + try: + result = self.predict_single_point( + point, rainfall=rainfall, duration=duration, query_time=query_time + ) + results.append(result) + except Exception as e: + logger.error(f"预测点 {point.get('id')} 失败: {e}") + results.append({ + 'point_id': point.get('id'), + 'source_type': point.get('source_type'), + 'lon': point.get('lon'), + 'lat': point.get('lat'), + 'error': str(e) + }) + return results + def get_model_info(self) -> Dict[str, Any]: """ 获取模型信息 diff --git a/app/repositories/dbn_repository.py b/app/repositories/dbn_repository.py index cd7e02b..c25c64d 100644 --- a/app/repositories/dbn_repository.py +++ b/app/repositories/dbn_repository.py @@ -22,26 +22,35 @@ class DbnRepository: 获取所有隐患点和风险点(从 xian_risk_factors 表) Args: - region_code: 行政区划代码(区县名称),可选 + region_code: 行政区划代码(如 '610104'),可选,匹配隐患点.county_code 和风险点.unit_code Returns: 点列表,每个元素包含:id, source_id, source_type, lon, lat, static_factors """ - sql = """ - SELECT - id, - source_id, - source_type, - lon, - lat, - static_factors - FROM xian_risk_factors - WHERE is_delete = 0 - """ - params = (region_code,) if region_code else None - if region_code: - sql += " AND county = %s" + # 通过源表的行政区划代码筛选 + sql = """ + SELECT rf.id, rf.source_id, rf.source_type, rf.lon, rf.lat, rf.static_factors + FROM xian_risk_factors rf + WHERE rf.is_delete = 0 + AND ( + (rf.source_type = 1 AND rf.source_id IN ( + SELECT id FROM xian_hidden_danger_spots WHERE county_id = %s AND is_delete = 0 + )) + OR + (rf.source_type = 2 AND rf.source_id IN ( + SELECT id FROM xian_risk_spots WHERE unit_code = %s AND is_delete = 0 + )) + ) + """ + params = (region_code, region_code) + else: + sql = """ + SELECT id, source_id, source_type, lon, lat, static_factors + FROM xian_risk_factors + WHERE is_delete = 0 + """ + params = None results = db_helper.execute_query(sql, params) @@ -94,6 +103,43 @@ class DbnRepository: 'static_factors': result.get('static_factors') or {} } + @staticmethod + def get_points_by_ids(point_ids: List[int]) -> List[Dict[str, Any]]: + """ + 批量获取点信息 + + Args: + point_ids: 点ID列表 + + Returns: + 点信息列表 + """ + if not point_ids: + return [] + + placeholders = ','.join(['%s'] * len(point_ids)) + sql = f""" + SELECT + rf.id, + rf.source_id, + rf.source_type, + rf.lon, + rf.lat, + rf.static_factors + FROM xian_risk_factors rf + WHERE rf.id IN ({placeholders}) AND rf.is_delete = 0 + """ + results = db_helper.execute_query(sql, tuple(point_ids)) + + return [{ + 'id': row['id'], + 'source_id': row['source_id'], + 'source_type': row['source_type'], + 'lon': float(row['lon']) if row['lon'] else None, + 'lat': float(row['lat']) if row['lat'] else None, + 'static_factors': row.get('static_factors') or {} + } for row in results] + @staticmethod def get_static_factors(point_id: int) -> Dict[str, Any]: """ diff --git a/app/schemas/api_schemas.py b/app/schemas/api_schemas.py new file mode 100644 index 0000000..74d2ca6 --- /dev/null +++ b/app/schemas/api_schemas.py @@ -0,0 +1,59 @@ +""" +API 请求/响应数据模型 +""" +from typing import List, Optional +from pydantic import BaseModel, Field + + +# ============================================================ +# 暴雨预测 +# ============================================================ + +class RainfallPredictRequest(BaseModel): + """暴雨灾害链预测请求""" + point_ids: Optional[List[int]] = Field(None, max_length=500, + description="点位ID列表,不传则查询所有点") + region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域") + rainfall: float = Field(..., ge=0, description="累计降雨量(mm)") + duration: float = Field(..., ge=0, description="降雨持续时间(h)") + + +# ============================================================ +# 地震预测 +# ============================================================ + +class EarthquakePredictRequest(BaseModel): + """地震灾害链预测请求""" + point_ids: Optional[List[int]] = Field(None, max_length=500, + description="点位ID列表,不传则查询所有点") + region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域") + magnitude: float = Field(..., ge=0, le=10, description="震级(Richter)") + depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km") + epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度") + epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度") + + +# ============================================================ +# 通用响应 +# ============================================================ + +class PredictionItem(BaseModel): + """单个点位预测结果""" + id: int = Field(..., description="点位ID") + type: str = Field(..., description="类型: 隐患点 / 风险点") + probability: float = Field(..., description="最大灾害概率") + level: str = Field(..., description="灾害等级: 低/中/较高/高") + + +class PredictResponse(BaseModel): + """预测响应""" + code: int = Field(200, description="状态码") + message: str = Field("success", description="提示信息") + data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表") + + +class HealthResponse(BaseModel): + """健康检查响应""" + status: str = "ok" + rainfall_model_loaded: bool = False + earthquake_model_loaded: bool = False diff --git a/app/utils/api_deps.py b/app/utils/api_deps.py new file mode 100644 index 0000000..d57c0a2 --- /dev/null +++ b/app/utils/api_deps.py @@ -0,0 +1,59 @@ +""" +API 依赖注入 +模型懒加载 + 并发控制 +""" +import asyncio +from typing import Optional + +from app.config.paths import get_logger + +logger = get_logger("api") + +# ============================================================ +# 并发控制:限制同时进行的预测任务数,防止资源耗尽 +# ============================================================ +MAX_CONCURRENT_PREDICTIONS = 8 +_prediction_semaphore: Optional[asyncio.Semaphore] = None + + +def get_prediction_semaphore() -> asyncio.Semaphore: + """获取预测信号量(惰性初始化,兼容事件循环)""" + global _prediction_semaphore + if _prediction_semaphore is None: + _prediction_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PREDICTIONS) + return _prediction_semaphore + + +# ============================================================ +# 模型单例(启动时加载一次) +# ============================================================ +_rainfall_model = None +_earthquake_model = None + + +def get_rainfall_model(): + """获取暴雨DBN模型单例""" + global _rainfall_model + if _rainfall_model is None: + from app.models.dbn.rainfall.rainfall_dbn import RainfallDBN + _rainfall_model = RainfallDBN() + logger.info("暴雨DBN模型加载完成") + return _rainfall_model + + +def get_earthquake_model(): + """获取地震DBN模型单例""" + global _earthquake_model + if _earthquake_model is None: + from app.models.dbn.earthquake.earthquake_dbn import EarthquakeDBN + _earthquake_model = EarthquakeDBN() + logger.info("地震DBN模型加载完成") + return _earthquake_model + + +def is_model_loaded() -> dict: + """检查模型加载状态""" + return { + "rainfall_model_loaded": _rainfall_model is not None, + "earthquake_model_loaded": _earthquake_model is not None, + } diff --git a/requirements.txt b/requirements.txt index a0ff547..1e843cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,6 @@ numpy == 2.4.4 scipy == 1.17.1 matplotlib == 3.10.0 Pillow == 12.2.0 -pyyaml == 6.0.2 \ No newline at end of file +pyyaml == 6.0.2 +fastapi == 0.136.3 +uvicorn[standard] == 0.49.0 \ No newline at end of file