修改API接口返回值——添加经纬度
This commit is contained in:
+27
-7
@@ -29,12 +29,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
# key 格式: {source_id}_{source_type},value 为百分比概率
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = round(probs[max_hazard] * 100, 2)
|
||||
return result_map
|
||||
|
||||
|
||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||
result_map = {}
|
||||
for r in results:
|
||||
probs = r.get("disaster_probabilities", {})
|
||||
if not probs:
|
||||
continue
|
||||
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = {
|
||||
"probability": round(probs[max_hazard] * 100, 2),
|
||||
"lon": r.get("lon"),
|
||||
"lat": r.get("lat")
|
||||
}
|
||||
return result_map
|
||||
|
||||
|
||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||
"""获取点位列表"""
|
||||
if point_ids:
|
||||
@@ -49,11 +68,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
同步执行地震预测(在线程池中运行)
|
||||
|
||||
Returns:
|
||||
(结果map,)
|
||||
(存储用result_map, 返回用result_map_with_location)
|
||||
"""
|
||||
points = _fetch_points(point_ids, region_code)
|
||||
if not points:
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
model = get_earthquake_model()
|
||||
raw_results = model.predict_multiple_points(
|
||||
@@ -63,9 +82,10 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
epicenter_lon=epicenter_lon,
|
||||
epicenter_lat=epicenter_lat,
|
||||
)
|
||||
result_map = _build_prediction_map(raw_results)
|
||||
result_map = _build_prediction_map(raw_results) # 用于存储
|
||||
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
|
||||
|
||||
return result_map
|
||||
return result_map, result_map_with_location
|
||||
|
||||
|
||||
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
||||
@@ -88,7 +108,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
||||
async with semaphore:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
result_map = await loop.run_in_executor(
|
||||
result_map, result_map_with_location = await loop.run_in_executor(
|
||||
None, _predict_sync, req.point_ids, req.region_code,
|
||||
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
||||
)
|
||||
@@ -125,4 +145,4 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
||||
except Exception as e:
|
||||
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||
|
||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
|
||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
|
||||
|
||||
+27
-7
@@ -32,12 +32,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
# key 格式: {source_id}_{source_type},value 为百分比概率
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = round(probs[max_hazard] * 100, 2)
|
||||
return result_map
|
||||
|
||||
|
||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||
result_map = {}
|
||||
for r in results:
|
||||
probs = r.get("disaster_probabilities", {})
|
||||
if not probs:
|
||||
continue
|
||||
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = {
|
||||
"probability": round(probs[max_hazard] * 100, 2),
|
||||
"lon": r.get("lon"),
|
||||
"lat": r.get("lat")
|
||||
}
|
||||
return result_map
|
||||
|
||||
|
||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||
"""获取点位列表"""
|
||||
if point_ids:
|
||||
@@ -55,18 +74,19 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
|
||||
|
||||
Returns:
|
||||
(结果map, 传入的条件, 实际使用的事件时间)
|
||||
(存储用result_map, 返回用result_map_with_location, 传入的条件, 实际使用的事件时间)
|
||||
"""
|
||||
points = _fetch_points(point_ids, region_code)
|
||||
if not points:
|
||||
return {}, {}, occurred_time or datetime.now()
|
||||
return {}, {}, {}, occurred_time or datetime.now()
|
||||
|
||||
# 使用传入的时间或当前时间作为查询时间
|
||||
query_time = occurred_time or datetime.now()
|
||||
|
||||
model = get_rainfall_model()
|
||||
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
|
||||
result_map = _build_prediction_map(raw_results)
|
||||
result_map = _build_prediction_map(raw_results) # 用于存储
|
||||
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
|
||||
|
||||
# 存储传入的原始条件(降雨量和持续时间可能每个点不同,所以存储传入值)
|
||||
condition = {
|
||||
@@ -77,7 +97,7 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
|
||||
}
|
||||
|
||||
return result_map, condition, query_time
|
||||
return result_map, result_map_with_location, condition, query_time
|
||||
|
||||
|
||||
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
||||
@@ -125,7 +145,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
||||
async with semaphore:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
result_map, condition, occurred_time = await loop.run_in_executor(
|
||||
result_map, result_map_with_location, condition, occurred_time = await loop.run_in_executor(
|
||||
None, _predict_sync, req.point_ids, req.region_code,
|
||||
req.rainfall, req.duration, req.operation_type, req.occurred_time
|
||||
)
|
||||
@@ -149,4 +169,4 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
||||
except Exception as e:
|
||||
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||
|
||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
|
||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
|
||||
|
||||
Reference in New Issue
Block a user