修改API接口返回值——添加经纬度

This commit is contained in:
wzy-warehouse
2026-06-14 16:41:31 +08:00
parent b2b5c00f33
commit 75046c99c8
3 changed files with 56 additions and 16 deletions
+27 -7
View File
@@ -29,12 +29,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
source_id = r["source_id"] source_id = r["source_id"]
source_type = r.get("source_type") source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get) max_hazard = max(probs, key=probs.get)
# key 格式: {source_id}_{source_type}value 为百分比概率
key = f"{source_id}_{source_type}" key = f"{source_id}_{source_type}"
result_map[key] = round(probs[max_hazard] * 100, 2) result_map[key] = round(probs[max_hazard] * 100, 2)
return result_map return result_map
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
result_map = {}
for r in results:
probs = r.get("disaster_probabilities", {})
if not probs:
continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
key = f"{source_id}_{source_type}"
result_map[key] = {
"probability": round(probs[max_hazard] * 100, 2),
"lon": r.get("lon"),
"lat": r.get("lat")
}
return result_map
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
"""获取点位列表""" """获取点位列表"""
if point_ids: if point_ids:
@@ -49,11 +68,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
同步执行地震预测(在线程池中运行) 同步执行地震预测(在线程池中运行)
Returns: Returns:
(结果map,) (存储用result_map, 返回用result_map_with_location)
""" """
points = _fetch_points(point_ids, region_code) points = _fetch_points(point_ids, region_code)
if not points: if not points:
return {} return {}, {}
model = get_earthquake_model() model = get_earthquake_model()
raw_results = model.predict_multiple_points( raw_results = model.predict_multiple_points(
@@ -63,9 +82,10 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
epicenter_lon=epicenter_lon, epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat, epicenter_lat=epicenter_lat,
) )
result_map = _build_prediction_map(raw_results) result_map = _build_prediction_map(raw_results) # 用于存储
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
return result_map return result_map, result_map_with_location
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测") @router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
@@ -88,7 +108,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
async with semaphore: async with semaphore:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
result_map = await loop.run_in_executor( result_map, result_map_with_location = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code, None, _predict_sync, req.point_ids, req.region_code,
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
) )
@@ -125,4 +145,4 @@ async def predict_earthquake(req: EarthquakePredictRequest):
except Exception as e: except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True) logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map)) return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
+27 -7
View File
@@ -32,12 +32,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
source_id = r["source_id"] source_id = r["source_id"]
source_type = r.get("source_type") source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get) max_hazard = max(probs, key=probs.get)
# key 格式: {source_id}_{source_type}value 为百分比概率
key = f"{source_id}_{source_type}" key = f"{source_id}_{source_type}"
result_map[key] = round(probs[max_hazard] * 100, 2) result_map[key] = round(probs[max_hazard] * 100, 2)
return result_map return result_map
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
result_map = {}
for r in results:
probs = r.get("disaster_probabilities", {})
if not probs:
continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
key = f"{source_id}_{source_type}"
result_map[key] = {
"probability": round(probs[max_hazard] * 100, 2),
"lon": r.get("lon"),
"lat": r.get("lat")
}
return result_map
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
"""获取点位列表""" """获取点位列表"""
if point_ids: if point_ids:
@@ -55,18 +74,19 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理 occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
Returns: Returns:
(结果map, 传入的条件, 实际使用的事件时间) (存储用result_map, 返回用result_map_with_location, 传入的条件, 实际使用的事件时间)
""" """
points = _fetch_points(point_ids, region_code) points = _fetch_points(point_ids, region_code)
if not points: if not points:
return {}, {}, occurred_time or datetime.now() return {}, {}, {}, occurred_time or datetime.now()
# 使用传入的时间或当前时间作为查询时间 # 使用传入的时间或当前时间作为查询时间
query_time = occurred_time or datetime.now() query_time = occurred_time or datetime.now()
model = get_rainfall_model() model = get_rainfall_model()
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time) raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
result_map = _build_prediction_map(raw_results) result_map = _build_prediction_map(raw_results) # 用于存储
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
# 存储传入的原始条件(降雨量和持续时间可能每个点不同,所以存储传入值) # 存储传入的原始条件(降雨量和持续时间可能每个点不同,所以存储传入值)
condition = { condition = {
@@ -77,7 +97,7 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time) "occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
} }
return result_map, condition, query_time return result_map, result_map_with_location, condition, query_time
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间") @router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
@@ -125,7 +145,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
async with semaphore: async with semaphore:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
result_map, condition, occurred_time = await loop.run_in_executor( result_map, result_map_with_location, condition, occurred_time = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code, None, _predict_sync, req.point_ids, req.region_code,
req.rainfall, req.duration, req.operation_type, req.occurred_time req.rainfall, req.duration, req.operation_type, req.occurred_time
) )
@@ -149,4 +169,4 @@ async def predict_rainfall(req: RainfallPredictRequest):
except Exception as e: except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True) logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map)) return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
+2 -2
View File
@@ -2,7 +2,7 @@
API 请求/响应数据模型 API 请求/响应数据模型
""" """
from datetime import datetime from datetime import datetime
from typing import List, Optional, Dict from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -51,7 +51,7 @@ class EarthquakePredictRequest(BaseModel):
class PredictData(BaseModel): class PredictData(BaseModel):
"""预测数据""" """预测数据"""
record_id: Optional[int] = Field(None, description="推理结果记录ID") record_id: Optional[int] = Field(None, description="推理结果记录ID")
list: Dict[str, float] = Field(default_factory=dict, description="预测结果列表") list: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="预测结果列表,包含概率和经纬度")
class PredictResponse(BaseModel): class PredictResponse(BaseModel):