diff --git a/app/api/earthquake.py b/app/api/earthquake.py index 19b6408..03ee1f6 100644 --- a/app/api/earthquake.py +++ b/app/api/earthquake.py @@ -18,24 +18,21 @@ SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"} LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"} -def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]: - """将模型原始结果转换为接口返回格式""" - items = [] +def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]: + """将模型原始结果转换为存储格式: {id_type: 概率百分比}""" + result_map = {} for r in results: probs = r.get("disaster_probabilities", {}) - levels = r.get("disaster_levels", {}) - if not probs: continue + source_id = r["source_id"] + source_type = r.get("source_type") max_hazard = max(probs, key=probs.get) - items.append(PredictionItem( - id=r["source_id"], # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id - type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"), - probability=round(probs[max_hazard], 4), - level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"), - )) - return items + # key 格式: {source_id}_{source_type},value 为百分比概率 + key = f"{source_id}_{source_type}" + result_map[key] = round(probs[max_hazard] * 100, 2) + return result_map def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: @@ -52,11 +49,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], 同步执行地震预测(在线程池中运行) Returns: - (预测结果列表, 原始结果) + (结果map,) """ points = _fetch_points(point_ids, region_code) if not points: - return [], [] + return {} model = get_earthquake_model() raw_results = model.predict_multiple_points( @@ -66,27 +63,15 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], epicenter_lon=epicenter_lon, epicenter_lat=epicenter_lat, ) - items = _build_prediction_items(raw_results) + result_map = _build_prediction_map(raw_results) - save_results = [ - { - "point_id": r.get("source_id"), # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id - "source_type": r.get("source_type"), - "lon": r.get("lon"), - "lat": r.get("lat"), - "disaster_probabilities": r.get("disaster_probabilities", {}), - "disaster_levels": r.get("disaster_levels", {}) - } - for r in raw_results - ] - - return items, save_results + return result_map @router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测") async def predict_earthquake(req: EarthquakePredictRequest): """ - 根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级。 + 根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率。 - **disaster_name**: 灾害名称 - **point_ids**: 点位ID列表(可选,不传则查询所有点) @@ -103,7 +88,7 @@ async def predict_earthquake(req: EarthquakePredictRequest): async with semaphore: loop = asyncio.get_event_loop() try: - items, save_results = await loop.run_in_executor( + result_map = await loop.run_in_executor( None, _predict_sync, req.point_ids, req.region_code, req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat ) @@ -113,7 +98,7 @@ async def predict_earthquake(req: EarthquakePredictRequest): # 保存推理结果 record_id = None - if save_results: + if result_map: try: condition = { "point_ids": req.point_ids, @@ -129,10 +114,10 @@ async def predict_earthquake(req: EarthquakePredictRequest): occurred_time=req.occurred_time, operation_type=req.operation_type, condition=condition, - result=save_results + result=result_map ) logger.info(f"推理结果已保存,record_id={record_id}") except Exception as e: logger.error(f"保存推理结果失败: {e}", exc_info=True) - return PredictResponse(code=200, message="success", data=items, record_id=record_id) + return PredictResponse(code=200, message="success", data=record_id) diff --git a/app/api/rainfall.py b/app/api/rainfall.py index 11b93f4..6cb691e 100644 --- a/app/api/rainfall.py +++ b/app/api/rainfall.py @@ -21,24 +21,21 @@ SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"} LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"} -def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]: - """将模型原始结果转换为接口返回格式""" - items = [] +def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]: + """将模型原始结果转换为存储格式: {id_type: 概率百分比}""" + result_map = {} for r in results: probs = r.get("disaster_probabilities", {}) - levels = r.get("disaster_levels", {}) - if not probs: continue + source_id = r["source_id"] + source_type = r.get("source_type") max_hazard = max(probs, key=probs.get) - items.append(PredictionItem( - id=r["source_id"], # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id - type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"), - probability=round(probs[max_hazard], 4), - level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"), - )) - return items + # key 格式: {source_id}_{source_type},value 为百分比概率 + key = f"{source_id}_{source_type}" + result_map[key] = round(probs[max_hazard] * 100, 2) + return result_map def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: @@ -55,15 +52,15 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], 同步执行暴雨预测(在线程池中运行) Returns: - (预测结果列表, 原始结果, 输入条件, 当前时间) + (结果map, 输入条件, 当前时间) """ points = _fetch_points(point_ids, region_code) if not points: - return [], [], {}, datetime.now() + return {}, {}, datetime.now() model = get_rainfall_model() raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration) - items = _build_prediction_items(raw_results) + result_map = _build_prediction_map(raw_results) # 构建条件和结果用于保存 now = datetime.now() @@ -73,19 +70,8 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], "rainfall": rainfall, "duration": duration } - save_results = [ - { - "point_id": r.get("source_id"), # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id - "source_type": r.get("source_type"), - "lon": r.get("lon"), - "lat": r.get("lat"), - "disaster_probabilities": r.get("disaster_probabilities", {}), - "disaster_levels": r.get("disaster_levels", {}) - } - for r in raw_results - ] - return items, save_results, condition, now + return result_map, condition, now @router.post("/update-monitoring-time", summary="更新降雨监测查询时间") @@ -119,7 +105,7 @@ async def update_monitoring_time(req: UpdateMonitoringTimeRequest): @router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测") async def predict_rainfall(req: RainfallPredictRequest): """ - 根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率和等级。 + 根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率。 - **disaster_name**: 灾害名称 - **point_ids**: 点位ID列表(可选,不传则查询所有点) @@ -133,7 +119,7 @@ async def predict_rainfall(req: RainfallPredictRequest): async with semaphore: loop = asyncio.get_event_loop() try: - items, save_results, condition, now = await loop.run_in_executor( + result_map, condition, now = await loop.run_in_executor( None, _predict_sync, req.point_ids, req.region_code, req.rainfall, req.duration, req.operation_type ) @@ -143,7 +129,7 @@ async def predict_rainfall(req: RainfallPredictRequest): # 保存推理结果 record_id = None - if save_results: + if result_map: try: record_id = dbn_repository.save_inference_result( disaster_name=req.disaster_name, @@ -151,10 +137,10 @@ async def predict_rainfall(req: RainfallPredictRequest): occurred_time=now, operation_type=req.operation_type, condition=condition, - result=save_results + result=result_map ) logger.info(f"推理结果已保存,record_id={record_id}") except Exception as e: logger.error(f"保存推理结果失败: {e}", exc_info=True) - return PredictResponse(code=200, message="success", data=items, record_id=record_id) + return PredictResponse(code=200, message="success", data=record_id) diff --git a/app/schemas/api_schemas.py b/app/schemas/api_schemas.py index c0b4e53..ad9feac 100644 --- a/app/schemas/api_schemas.py +++ b/app/schemas/api_schemas.py @@ -59,8 +59,7 @@ class PredictResponse(BaseModel): """预测响应""" code: int = Field(200, description="状态码") message: str = Field("success", description="提示信息") - data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表") - record_id: Optional[int] = Field(None, description="推理结果记录ID") + data: Optional[int] = Field(None, description="推理结果记录ID") class UpdateMonitoringTimeRequest(BaseModel):