修改API接口返回值——添加经纬度
This commit is contained in:
+27
-7
@@ -29,12 +29,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
|||||||
source_id = r["source_id"]
|
source_id = r["source_id"]
|
||||||
source_type = r.get("source_type")
|
source_type = r.get("source_type")
|
||||||
max_hazard = max(probs, key=probs.get)
|
max_hazard = max(probs, key=probs.get)
|
||||||
# key 格式: {source_id}_{source_type},value 为百分比概率
|
|
||||||
key = f"{source_id}_{source_type}"
|
key = f"{source_id}_{source_type}"
|
||||||
result_map[key] = round(probs[max_hazard] * 100, 2)
|
result_map[key] = round(probs[max_hazard] * 100, 2)
|
||||||
return result_map
|
return result_map
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||||
|
result_map = {}
|
||||||
|
for r in results:
|
||||||
|
probs = r.get("disaster_probabilities", {})
|
||||||
|
if not probs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
source_id = r["source_id"]
|
||||||
|
source_type = r.get("source_type")
|
||||||
|
max_hazard = max(probs, key=probs.get)
|
||||||
|
key = f"{source_id}_{source_type}"
|
||||||
|
result_map[key] = {
|
||||||
|
"probability": round(probs[max_hazard] * 100, 2),
|
||||||
|
"lon": r.get("lon"),
|
||||||
|
"lat": r.get("lat")
|
||||||
|
}
|
||||||
|
return result_map
|
||||||
|
|
||||||
|
|
||||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||||
"""获取点位列表"""
|
"""获取点位列表"""
|
||||||
if point_ids:
|
if point_ids:
|
||||||
@@ -49,11 +68,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
同步执行地震预测(在线程池中运行)
|
同步执行地震预测(在线程池中运行)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(结果map,)
|
(存储用result_map, 返回用result_map_with_location)
|
||||||
"""
|
"""
|
||||||
points = _fetch_points(point_ids, region_code)
|
points = _fetch_points(point_ids, region_code)
|
||||||
if not points:
|
if not points:
|
||||||
return {}
|
return {}, {}
|
||||||
|
|
||||||
model = get_earthquake_model()
|
model = get_earthquake_model()
|
||||||
raw_results = model.predict_multiple_points(
|
raw_results = model.predict_multiple_points(
|
||||||
@@ -63,9 +82,10 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
epicenter_lon=epicenter_lon,
|
epicenter_lon=epicenter_lon,
|
||||||
epicenter_lat=epicenter_lat,
|
epicenter_lat=epicenter_lat,
|
||||||
)
|
)
|
||||||
result_map = _build_prediction_map(raw_results)
|
result_map = _build_prediction_map(raw_results) # 用于存储
|
||||||
|
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
|
||||||
|
|
||||||
return result_map
|
return result_map, result_map_with_location
|
||||||
|
|
||||||
|
|
||||||
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
||||||
@@ -88,7 +108,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
async with semaphore:
|
async with semaphore:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
result_map = await loop.run_in_executor(
|
result_map, result_map_with_location = await loop.run_in_executor(
|
||||||
None, _predict_sync, req.point_ids, req.region_code,
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
||||||
)
|
)
|
||||||
@@ -125,4 +145,4 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
|
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
|
||||||
|
|||||||
+27
-7
@@ -32,12 +32,31 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
|||||||
source_id = r["source_id"]
|
source_id = r["source_id"]
|
||||||
source_type = r.get("source_type")
|
source_type = r.get("source_type")
|
||||||
max_hazard = max(probs, key=probs.get)
|
max_hazard = max(probs, key=probs.get)
|
||||||
# key 格式: {source_id}_{source_type},value 为百分比概率
|
|
||||||
key = f"{source_id}_{source_type}"
|
key = f"{source_id}_{source_type}"
|
||||||
result_map[key] = round(probs[max_hazard] * 100, 2)
|
result_map[key] = round(probs[max_hazard] * 100, 2)
|
||||||
return result_map
|
return result_map
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||||
|
result_map = {}
|
||||||
|
for r in results:
|
||||||
|
probs = r.get("disaster_probabilities", {})
|
||||||
|
if not probs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
source_id = r["source_id"]
|
||||||
|
source_type = r.get("source_type")
|
||||||
|
max_hazard = max(probs, key=probs.get)
|
||||||
|
key = f"{source_id}_{source_type}"
|
||||||
|
result_map[key] = {
|
||||||
|
"probability": round(probs[max_hazard] * 100, 2),
|
||||||
|
"lon": r.get("lon"),
|
||||||
|
"lat": r.get("lat")
|
||||||
|
}
|
||||||
|
return result_map
|
||||||
|
|
||||||
|
|
||||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||||
"""获取点位列表"""
|
"""获取点位列表"""
|
||||||
if point_ids:
|
if point_ids:
|
||||||
@@ -55,18 +74,19 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
|
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(结果map, 传入的条件, 实际使用的事件时间)
|
(存储用result_map, 返回用result_map_with_location, 传入的条件, 实际使用的事件时间)
|
||||||
"""
|
"""
|
||||||
points = _fetch_points(point_ids, region_code)
|
points = _fetch_points(point_ids, region_code)
|
||||||
if not points:
|
if not points:
|
||||||
return {}, {}, occurred_time or datetime.now()
|
return {}, {}, {}, occurred_time or datetime.now()
|
||||||
|
|
||||||
# 使用传入的时间或当前时间作为查询时间
|
# 使用传入的时间或当前时间作为查询时间
|
||||||
query_time = occurred_time or datetime.now()
|
query_time = occurred_time or datetime.now()
|
||||||
|
|
||||||
model = get_rainfall_model()
|
model = get_rainfall_model()
|
||||||
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
|
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
|
||||||
result_map = _build_prediction_map(raw_results)
|
result_map = _build_prediction_map(raw_results) # 用于存储
|
||||||
|
result_map_with_location = _build_prediction_map_with_location(raw_results) # 用于返回
|
||||||
|
|
||||||
# 存储传入的原始条件(降雨量和持续时间可能每个点不同,所以存储传入值)
|
# 存储传入的原始条件(降雨量和持续时间可能每个点不同,所以存储传入值)
|
||||||
condition = {
|
condition = {
|
||||||
@@ -77,7 +97,7 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
|
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result_map, condition, query_time
|
return result_map, result_map_with_location, condition, query_time
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
||||||
@@ -125,7 +145,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
async with semaphore:
|
async with semaphore:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
result_map, condition, occurred_time = await loop.run_in_executor(
|
result_map, result_map_with_location, condition, occurred_time = await loop.run_in_executor(
|
||||||
None, _predict_sync, req.point_ids, req.region_code,
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
req.rainfall, req.duration, req.operation_type, req.occurred_time
|
req.rainfall, req.duration, req.operation_type, req.occurred_time
|
||||||
)
|
)
|
||||||
@@ -149,4 +169,4 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map))
|
return PredictResponse(code=200, message="success", data=PredictData(record_id=record_id, list=result_map_with_location))
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
API 请求/响应数据模型
|
API 请求/响应数据模型
|
||||||
"""
|
"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict, Any
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ class EarthquakePredictRequest(BaseModel):
|
|||||||
class PredictData(BaseModel):
|
class PredictData(BaseModel):
|
||||||
"""预测数据"""
|
"""预测数据"""
|
||||||
record_id: Optional[int] = Field(None, description="推理结果记录ID")
|
record_id: Optional[int] = Field(None, description="推理结果记录ID")
|
||||||
list: Dict[str, float] = Field(default_factory=dict, description="预测结果列表")
|
list: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="预测结果列表,包含概率和经纬度")
|
||||||
|
|
||||||
|
|
||||||
class PredictResponse(BaseModel):
|
class PredictResponse(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user