From 75046c99c89c24750bf0266bfba0653ae64dac47 Mon Sep 17 00:00:00 2001 From: wzy-warehouse <18135009705@163.com> Date: Sun, 14 Jun 2026 16:41:31 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9API=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E5=80=BC=E2=80=94=E2=80=94=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E7=BB=8F=E7=BA=AC=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/earthquake.py | 34 +++++++++++++++++++++++++++------- app/api/rainfall.py | 34 +++++++++++++++++++++++++++------- app/schemas/api_schemas.py | 4 ++-- 3 files changed, 56 insertions(+), 16 deletions(-) diff --git a/app/api/earthquake.py b/app/api/earthquake.py index 03801be..4f49070 100644 --- a/app/api/earthquake.py +++ b/app/api/earthquake.py @@ -29,12 +29,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]: source_id = r["source_id"] source_type = r.get("source_type") max_hazard = max(probs, key=probs.get) - # key 格式: {source_id}_{source_type},value 为百分比概率 key = f"{source_id}_{source_type}" result_map[key] = round(probs[max_hazard] * 100, 2) return result_map +def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + """将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}""" + result_map = {} + for r in results: + probs = r.get("disaster_probabilities", {}) + if not probs: + continue + + source_id = r["source_id"] + source_type = r.get("source_type") + max_hazard = max(probs, key=probs.get) + key = f"{source_id}_{source_type}" + result_map[key] = { + "probability": round(probs[max_hazard] * 100, 2), + "lon": r.get("lon"), + "lat": r.get("lat") + } + return result_map + + def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: """获取点位列表""" if point_ids: @@ -49,11 +68,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], 同步执行地震预测(在线程池中运行) Returns: - (结果map,) + (存储用result_map, 返回用result_map_with_location) """ points = _fetch_points(point_ids, region_code) if not points: - return {} + return {}, {} model = get_earthquake_model() raw_results = model.predict_multiple_points( @@ -63,9 +82,10 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], epicenter_lon=epicenter_lon, epicenter_lat=epicenter_lat, ) - result_map = _build_prediction_map(raw_results) + result_map = _build_prediction_map(raw_results) # 用于存储 + result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回 - return result_map + return result_map, result_map_with_location @router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测") @@ -88,7 +108,7 @@ async def predict_earthquake(req: EarthquakePredictRequest): async with semaphore: loop = asyncio.get_event_loop() try: - result_map = await loop.run_in_executor( + result_map, result_map_with_location = await loop.run_in_executor( None, _predict_sync, req.point_ids, req.region_code, req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat ) @@ -125,4 +145,4 @@ async def predict_earthquake(req: EarthquakePredictRequest): except Exception as e: logger.error(f"保存推理结果失败: {e}", exc_info=True) - return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map)) + return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location)) diff --git a/app/api/rainfall.py b/app/api/rainfall.py index 6d8a1ed..0f00df2 100644 --- a/app/api/rainfall.py +++ b/app/api/rainfall.py @@ -32,12 +32,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]: source_id = r["source_id"] source_type = r.get("source_type") max_hazard = max(probs, key=probs.get) - # key 格式: {source_id}_{source_type},value 为百分比概率 key = f"{source_id}_{source_type}" result_map[key] = round(probs[max_hazard] * 100, 2) return result_map +def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + """将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}""" + result_map = {} + for r in results: + probs = r.get("disaster_probabilities", {}) + if not probs: + continue + + source_id = r["source_id"] + source_type = r.get("source_type") + max_hazard = max(probs, key=probs.get) + key = f"{source_id}_{source_type}" + result_map[key] = { + "probability": round(probs[max_hazard] * 100, 2), + "lon": r.get("lon"), + "lat": r.get("lat") + } + return result_map + + def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: """获取点位列表""" if point_ids: @@ -55,18 +74,19 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], occurred_time: 事件发生时间,用于查询降雨数据和DBN推理 Returns: - (结果map, 传入的条件, 实际使用的事件时间) + (存储用result_map, 返回用result_map_with_location, 传入的条件, 实际使用的事件时间) """ points = _fetch_points(point_ids, region_code) if not points: - return {}, {}, occurred_time or datetime.now() + return {}, {}, {}, occurred_time or datetime.now() # 使用传入的时间或当前时间作为查询时间 query_time = occurred_time or datetime.now() model = get_rainfall_model() raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time) - result_map = _build_prediction_map(raw_results) + result_map = _build_prediction_map(raw_results) # 用于存储 + result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回 # 存储传入的原始条件(降雨量和持续时间可能每个点不同,所以存储传入值) condition = { @@ -77,7 +97,7 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], "occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time) } - return result_map, condition, query_time + return result_map, result_map_with_location, condition, query_time @router.post("/update-monitoring-time", summary="更新降雨监测查询时间") @@ -125,7 +145,7 @@ async def predict_rainfall(req: RainfallPredictRequest): async with semaphore: loop = asyncio.get_event_loop() try: - result_map, condition, occurred_time = await loop.run_in_executor( + result_map, result_map_with_location, condition, occurred_time = await loop.run_in_executor( None, _predict_sync, req.point_ids, req.region_code, req.rainfall, req.duration, req.operation_type, req.occurred_time ) @@ -149,4 +169,4 @@ async def predict_rainfall(req: RainfallPredictRequest): except Exception as e: logger.error(f"保存推理结果失败: {e}", exc_info=True) - return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map)) + return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location)) diff --git a/app/schemas/api_schemas.py b/app/schemas/api_schemas.py index 08e8566..c4a8042 100644 --- a/app/schemas/api_schemas.py +++ b/app/schemas/api_schemas.py @@ -2,7 +2,7 @@ API 请求/响应数据模型 """ from datetime import datetime -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Any from pydantic import BaseModel, Field @@ -51,7 +51,7 @@ class EarthquakePredictRequest(BaseModel): class PredictData(BaseModel): """预测数据""" record_id: Optional[int] = Field(None, description="推理结果记录ID") - list: Dict[str, float] = Field(default_factory=dict, description="预测结果列表") + list: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="预测结果列表,包含概率和经纬度") class PredictResponse(BaseModel):