diff --git a/app/api/earthquake.py b/app/api/earthquake.py index eb5f5d0..36e1725 100644 --- a/app/api/earthquake.py +++ b/app/api/earthquake.py @@ -47,21 +47,40 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], magnitude: float, depth: float, - epicenter_lon: float, epicenter_lat: float) -> List[PredictionItem]: - """同步执行地震预测(在线程池中运行)""" + epicenter_lon: float, epicenter_lat: float) -> tuple: + """ + 同步执行地震预测(在线程池中运行) + + Returns: + (预测结果列表, 原始结果) + """ points = _fetch_points(point_ids, region_code) if not points: - return [] + return [], [] model = get_earthquake_model() - results = model.predict_multiple_points( + raw_results = model.predict_multiple_points( points, magnitude=magnitude, depth=depth, epicenter_lon=epicenter_lon, epicenter_lat=epicenter_lat, ) - return _build_prediction_items(results) + items = _build_prediction_items(raw_results) + + save_results = [ + { + "point_id": r.get("point_id"), + "source_type": r.get("source_type"), + "lon": r.get("lon"), + "lat": r.get("lat"), + "disaster_probabilities": r.get("disaster_probabilities", {}), + "disaster_levels": r.get("disaster_levels", {}) + } + for r in raw_results + ] + + return items, save_results @router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测") @@ -75,13 +94,15 @@ async def predict_earthquake(req: EarthquakePredictRequest): - **depth**: 震源深度(km),默认10km - **epicenter_lon**: 震中经度 - **epicenter_lat**: 震中纬度 + - **occurred_time**: 地震发生时间 + - **operation_type**: 操作类型(如 '实时监测', '情景模拟', '应急评估') """ semaphore = get_prediction_semaphore() async with semaphore: loop = asyncio.get_event_loop() try: - items = await loop.run_in_executor( + items, save_results = await loop.run_in_executor( None, _predict_sync, req.point_ids, req.region_code, req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat ) @@ -89,4 +110,27 @@ async def predict_earthquake(req: EarthquakePredictRequest): logger.error(f"地震预测失败: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"预测失败: {e}") - return PredictResponse(code=200, message="success", data=items) + # 保存推理结果 + record_id = None + if save_results: + try: + condition = { + "point_ids": req.point_ids, + "region_code": req.region_code, + "magnitude": req.magnitude, + "depth": req.depth, + "epicenter_lon": req.epicenter_lon, + "epicenter_lat": req.epicenter_lat + } + record_id = dbn_repository.save_inference_result( + event_type="earthquake", + occurred_time=req.occurred_time, + operation_type=req.operation_type, + condition=condition, + result=save_results + ) + logger.info(f"推理结果已保存,record_id={record_id}") + except Exception as e: + logger.error(f"保存推理结果失败: {e}", exc_info=True) + + return PredictResponse(code=200, message="success", data=items, record_id=record_id) diff --git a/app/api/rainfall.py b/app/api/rainfall.py index 3ace185..e2bf4ea 100644 --- a/app/api/rainfall.py +++ b/app/api/rainfall.py @@ -2,6 +2,7 @@ 暴雨灾害链预测接口 """ import asyncio +from datetime import datetime from typing import List, Dict, Any, Optional from fastapi import APIRouter, HTTPException @@ -46,15 +47,43 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], - rainfall: float, duration: float) -> List[PredictionItem]: - """同步执行暴雨预测(在线程池中运行)""" + rainfall: Optional[float], duration: Optional[float], + operation_type: str) -> tuple: + """ + 同步执行暴雨预测(在线程池中运行) + + Returns: + (预测结果列表, 原始结果, 输入条件, 当前时间) + """ points = _fetch_points(point_ids, region_code) if not points: - return [] + return [], [], {}, datetime.now() model = get_rainfall_model() - results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration) - return _build_prediction_items(results) + raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration) + items = _build_prediction_items(raw_results) + + # 构建条件和结果用于保存 + now = datetime.now() + condition = { + "point_ids": point_ids, + "region_code": region_code, + "rainfall": rainfall, + "duration": duration + } + save_results = [ + { + "point_id": r.get("point_id"), + "source_type": r.get("source_type"), + "lon": r.get("lon"), + "lat": r.get("lat"), + "disaster_probabilities": r.get("disaster_probabilities", {}), + "disaster_levels": r.get("disaster_levels", {}) + } + for r in raw_results + ] + + return items, save_results, condition, now @router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测") @@ -64,20 +93,36 @@ async def predict_rainfall(req: RainfallPredictRequest): - **point_ids**: 点位ID列表(可选,不传则查询所有点) - **region_code**: 行政区划代码(可选,不传则不限区域) - - **rainfall**: 累计降雨量(mm) - - **duration**: 降雨持续时间(h) + - **rainfall**: 累计降雨量(mm),不传则从气象表自动获取 + - **duration**: 降雨持续时间(h),不传则从气象表自动获取 + - **operation_type**: 操作类型(如 '实时监测', '情景模拟', '应急评估') """ semaphore = get_prediction_semaphore() async with semaphore: loop = asyncio.get_event_loop() try: - items = await loop.run_in_executor( + items, save_results, condition, now = await loop.run_in_executor( None, _predict_sync, req.point_ids, req.region_code, - req.rainfall, req.duration + req.rainfall, req.duration, req.operation_type ) except Exception as e: logger.error(f"暴雨预测失败: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"预测失败: {e}") - return PredictResponse(code=200, message="success", data=items) + # 保存推理结果 + record_id = None + if save_results: + try: + record_id = dbn_repository.save_inference_result( + event_type="rainfall", + occurred_time=now, + operation_type=req.operation_type, + condition=condition, + result=save_results + ) + logger.info(f"推理结果已保存,record_id={record_id}") + except Exception as e: + logger.error(f"保存推理结果失败: {e}", exc_info=True) + + return PredictResponse(code=200, message="success", data=items, record_id=record_id) diff --git a/app/repositories/dbn_repository.py b/app/repositories/dbn_repository.py index a1fa0ea..2f447bb 100644 --- a/app/repositories/dbn_repository.py +++ b/app/repositories/dbn_repository.py @@ -684,6 +684,37 @@ class DbnRepository: return float(result['aspect']) return None + @staticmethod + def save_inference_result(event_type: str, occurred_time, operation_type: str, + condition: dict, result: list) -> int: + """ + 保存推理结果到 inference_result 表 + + Args: + event_type: 事件类型('rainfall' 或 'earthquake') + occurred_time: 事件发生时间 + operation_type: 操作类型 + condition: 输入条件(JSON) + result: 预测结果列表(JSON) + + Returns: + 新插入记录的 ID + """ + import json + sql = """ + INSERT INTO inference_result (event_type, occurred_time, operation_type, condition, result) + VALUES (%s, %s, %s, %s::jsonb, %s::jsonb) + RETURNING id + """ + row = db_helper.execute_query_one(sql, ( + event_type, + occurred_time, + operation_type, + json.dumps(condition, ensure_ascii=False), + json.dumps(result, ensure_ascii=False) + )) + return row['id'] if row else 0 + # 创建全局实例 dbn_repository = DbnRepository() diff --git a/app/schemas/api_schemas.py b/app/schemas/api_schemas.py index 618b4e6..77f108c 100644 --- a/app/schemas/api_schemas.py +++ b/app/schemas/api_schemas.py @@ -1,6 +1,7 @@ """ API 请求/响应数据模型 """ +from datetime import datetime from typing import List, Optional from pydantic import BaseModel, Field @@ -18,6 +19,8 @@ class RainfallPredictRequest(BaseModel): description="累计降雨量(mm),不传则从气象表自动获取") duration: Optional[float] = Field(None, ge=0, description="降雨持续时间(h),不传则从气象表自动获取") + operation_type: str = Field("模拟", min_length=1, max_length=50, + description="操作类型(如 '模拟', '实时监测', '应急评估')") # ============================================================ @@ -33,6 +36,9 @@ class EarthquakePredictRequest(BaseModel): depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km") epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度") epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度") + occurred_time: datetime = Field(..., description="地震发生时间") + operation_type: str = Field("模拟", min_length=1, max_length=50, + description="操作类型(如 '模拟', '实时监测', '应急评估')") # ============================================================ @@ -52,6 +58,7 @@ class PredictResponse(BaseModel): code: int = Field(200, description="状态码") message: str = Field("success", description="提示信息") data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表") + record_id: Optional[int] = Field(None, description="推理结果记录ID") class HealthResponse(BaseModel):