修改API接口返回值——添加经纬度

This commit is contained in:
wzy-warehouse
2026-06-14 16:41:31 +08:00
parent b2b5c00f33
commit 75046c99c8
3 changed files with 56 additions and 16 deletions
+27 -7
View File
@@ -29,12 +29,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
# key 格式: {source_id}_{source_type}value 为百分比概率
key = f"{source_id}_{source_type}"
result_map[key] = round(probs[max_hazard] * 100, 2)
return result_map
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
result_map = {}
for r in results:
probs = r.get("disaster_probabilities", {})
if not probs:
continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
key = f"{source_id}_{source_type}"
result_map[key] = {
"probability": round(probs[max_hazard] * 100, 2),
"lon": r.get("lon"),
"lat": r.get("lat")
}
return result_map
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
"""获取点位列表"""
if point_ids:
@@ -49,11 +68,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
同步执行地震预测(在线程池中运行)
Returns:
(结果map,)
(存储用result_map, 返回用result_map_with_location)
"""
points = _fetch_points(point_ids, region_code)
if not points:
return {}
return {}, {}
model = get_earthquake_model()
raw_results = model.predict_multiple_points(
@@ -63,9 +82,10 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat,
)
result_map = _build_prediction_map(raw_results)
result_map = _build_prediction_map(raw_results) # 用于存储
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
return result_map
return result_map, result_map_with_location
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
@@ -88,7 +108,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
async with semaphore:
loop = asyncio.get_event_loop()
try:
result_map = await loop.run_in_executor(
result_map, result_map_with_location = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code,
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
)
@@ -125,4 +145,4 @@ async def predict_earthquake(req: EarthquakePredictRequest):
except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))