修改API接口返回值——添加经纬度

This commit is contained in:
wzy-warehouse
2026-06-14 16:41:31 +08:00
parent b2b5c00f33
commit 75046c99c8
3 changed files with 56 additions and 16 deletions
+27 -7
View File
@@ -29,12 +29,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
# key 格式: {source_id}_{source_type}value 为百分比概率
key = f"{source_id}_{source_type}"
result_map[key] = round(probs[max_hazard] * 100, 2)
return result_map
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
result_map = {}
for r in results:
probs = r.get("disaster_probabilities", {})
if not probs:
continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
key = f"{source_id}_{source_type}"
result_map[key] = {
"probability": round(probs[max_hazard] * 100, 2),
"lon": r.get("lon"),
"lat": r.get("lat")
}
return result_map
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
"""获取点位列表"""
if point_ids:
@@ -49,11 +68,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
同步执行地震预测(在线程池中运行)
Returns:
(结果map,)
(存储用result_map, 返回用result_map_with_location)
"""
points = _fetch_points(point_ids, region_code)
if not points:
return {}
return {}, {}
model = get_earthquake_model()
raw_results = model.predict_multiple_points(
@@ -63,9 +82,10 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat,
)
result_map = _build_prediction_map(raw_results)
result_map = _build_prediction_map(raw_results) # 用于存储
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
return result_map
return result_map, result_map_with_location
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
@@ -88,7 +108,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
async with semaphore:
loop = asyncio.get_event_loop()
try:
result_map = await loop.run_in_executor(
result_map, result_map_with_location = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code,
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
)
@@ -125,4 +145,4 @@ async def predict_earthquake(req: EarthquakePredictRequest):
except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
+27 -7
View File
@@ -32,12 +32,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
# key 格式: {source_id}_{source_type}value 为百分比概率
key = f"{source_id}_{source_type}"
result_map[key] = round(probs[max_hazard] * 100, 2)
return result_map
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
result_map = {}
for r in results:
probs = r.get("disaster_probabilities", {})
if not probs:
continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
key = f"{source_id}_{source_type}"
result_map[key] = {
"probability": round(probs[max_hazard] * 100, 2),
"lon": r.get("lon"),
"lat": r.get("lat")
}
return result_map
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
"""获取点位列表"""
if point_ids:
@@ -55,18 +74,19 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
Returns:
(结果map, 传入的条件, 实际使用的事件时间)
(存储用result_map, 返回用result_map_with_location, 传入的条件, 实际使用的事件时间)
"""
points = _fetch_points(point_ids, region_code)
if not points:
return {}, {}, occurred_time or datetime.now()
return {}, {}, {}, occurred_time or datetime.now()
# 使用传入的时间或当前时间作为查询时间
query_time = occurred_time or datetime.now()
model = get_rainfall_model()
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
result_map = _build_prediction_map(raw_results)
result_map = _build_prediction_map(raw_results) # 用于存储
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
# 存储传入的原始条件(降雨量和持续时间可能每个点不同,所以存储传入值)
condition = {
@@ -77,7 +97,7 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
}
return result_map, condition, query_time
return result_map, result_map_with_location, condition, query_time
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
@@ -125,7 +145,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
async with semaphore:
loop = asyncio.get_event_loop()
try:
result_map, condition, occurred_time = await loop.run_in_executor(
result_map, result_map_with_location, condition, occurred_time = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code,
req.rainfall, req.duration, req.operation_type, req.occurred_time
)
@@ -149,4 +169,4 @@ async def predict_rainfall(req: RainfallPredictRequest):
except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))