修改API接口返回值——添加经纬度
This commit is contained in:
+27
-7
@@ -29,12 +29,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
# key 格式: {source_id}_{source_type},value 为百分比概率
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = round(probs[max_hazard] * 100, 2)
|
||||
return result_map
|
||||
|
||||
|
||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||
result_map = {}
|
||||
for r in results:
|
||||
probs = r.get("disaster_probabilities", {})
|
||||
if not probs:
|
||||
continue
|
||||
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = {
|
||||
"probability": round(probs[max_hazard] * 100, 2),
|
||||
"lon": r.get("lon"),
|
||||
"lat": r.get("lat")
|
||||
}
|
||||
return result_map
|
||||
|
||||
|
||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||
"""获取点位列表"""
|
||||
if point_ids:
|
||||
@@ -49,11 +68,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
同步执行地震预测(在线程池中运行)
|
||||
|
||||
Returns:
|
||||
(结果map,)
|
||||
(存储用result_map, 返回用result_map_with_location)
|
||||
"""
|
||||
points = _fetch_points(point_ids, region_code)
|
||||
if not points:
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
model = get_earthquake_model()
|
||||
raw_results = model.predict_multiple_points(
|
||||
@@ -63,9 +82,10 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
epicenter_lon=epicenter_lon,
|
||||
epicenter_lat=epicenter_lat,
|
||||
)
|
||||
result_map = _build_prediction_map(raw_results)
|
||||
result_map = _build_prediction_map(raw_results) # 用于存储
|
||||
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
|
||||
|
||||
return result_map
|
||||
return result_map, result_map_with_location
|
||||
|
||||
|
||||
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
||||
@@ -88,7 +108,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
||||
async with semaphore:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
result_map = await loop.run_in_executor(
|
||||
result_map, result_map_with_location = await loop.run_in_executor(
|
||||
None, _predict_sync, req.point_ids, req.region_code,
|
||||
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
||||
)
|
||||
@@ -125,4 +145,4 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
||||
except Exception as e:
|
||||
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||
|
||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
|
||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
|
||||
|
||||
Reference in New Issue
Block a user