修改接口返回值

This commit is contained in:
wzy-warehouse
2026-06-14 15:31:42 +08:00
parent 9a49764b35
commit 5e1dc585d4
3 changed files with 37 additions and 67 deletions
+18 -33
View File
@@ -18,24 +18,21 @@ SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
LEVEL_MAP = {"": "", "": "", "较高": "较高", "": ""} LEVEL_MAP = {"": "", "": "", "较高": "较高", "": ""}
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]: def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
"""将模型原始结果转换为接口返回格式""" """将模型原始结果转换为存储格式: {id_type: 概率百分比}"""
items = [] result_map = {}
for r in results: for r in results:
probs = r.get("disaster_probabilities", {}) probs = r.get("disaster_probabilities", {})
levels = r.get("disaster_levels", {})
if not probs: if not probs:
continue continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get) max_hazard = max(probs, key=probs.get)
items.append(PredictionItem( # key 格式: {source_id}_{source_type}value 为百分比概率
id=r["source_id"], # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id key = f"{source_id}_{source_type}"
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"), result_map[key] = round(probs[max_hazard] * 100, 2)
probability=round(probs[max_hazard], 4), return result_map
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), ""),
))
return items
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
@@ -52,11 +49,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
同步执行地震预测(在线程池中运行) 同步执行地震预测(在线程池中运行)
Returns: Returns:
(预测结果列表, 原始结果) (结果map,)
""" """
points = _fetch_points(point_ids, region_code) points = _fetch_points(point_ids, region_code)
if not points: if not points:
return [], [] return {}
model = get_earthquake_model() model = get_earthquake_model()
raw_results = model.predict_multiple_points( raw_results = model.predict_multiple_points(
@@ -66,27 +63,15 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
epicenter_lon=epicenter_lon, epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat, epicenter_lat=epicenter_lat,
) )
items = _build_prediction_items(raw_results) result_map = _build_prediction_map(raw_results)
save_results = [ return result_map
{
"point_id": r.get("source_id"), # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
"source_type": r.get("source_type"),
"lon": r.get("lon"),
"lat": r.get("lat"),
"disaster_probabilities": r.get("disaster_probabilities", {}),
"disaster_levels": r.get("disaster_levels", {})
}
for r in raw_results
]
return items, save_results
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测") @router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
async def predict_earthquake(req: EarthquakePredictRequest): async def predict_earthquake(req: EarthquakePredictRequest):
""" """
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级 根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率。
- **disaster_name**: 灾害名称 - **disaster_name**: 灾害名称
- **point_ids**: 点位ID列表(可选,不传则查询所有点) - **point_ids**: 点位ID列表(可选,不传则查询所有点)
@@ -103,7 +88,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
async with semaphore: async with semaphore:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
items, save_results = await loop.run_in_executor( result_map = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code, None, _predict_sync, req.point_ids, req.region_code,
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
) )
@@ -113,7 +98,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
# 保存推理结果 # 保存推理结果
record_id = None record_id = None
if save_results: if result_map:
try: try:
condition = { condition = {
"point_ids": req.point_ids, "point_ids": req.point_ids,
@@ -129,10 +114,10 @@ async def predict_earthquake(req: EarthquakePredictRequest):
occurred_time=req.occurred_time, occurred_time=req.occurred_time,
operation_type=req.operation_type, operation_type=req.operation_type,
condition=condition, condition=condition,
result=save_results result=result_map
) )
logger.info(f"推理结果已保存,record_id={record_id}") logger.info(f"推理结果已保存,record_id={record_id}")
except Exception as e: except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True) logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=items, record_id=record_id) return PredictResponse(code=200, message="success", data=record_id)
+18 -32
View File
@@ -21,24 +21,21 @@ SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
LEVEL_MAP = {"": "", "": "", "较高": "较高", "": ""} LEVEL_MAP = {"": "", "": "", "较高": "较高", "": ""}
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]: def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
"""将模型原始结果转换为接口返回格式""" """将模型原始结果转换为存储格式: {id_type: 概率百分比}"""
items = [] result_map = {}
for r in results: for r in results:
probs = r.get("disaster_probabilities", {}) probs = r.get("disaster_probabilities", {})
levels = r.get("disaster_levels", {})
if not probs: if not probs:
continue continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get) max_hazard = max(probs, key=probs.get)
items.append(PredictionItem( # key 格式: {source_id}_{source_type}value 为百分比概率
id=r["source_id"], # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id key = f"{source_id}_{source_type}"
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"), result_map[key] = round(probs[max_hazard] * 100, 2)
probability=round(probs[max_hazard], 4), return result_map
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), ""),
))
return items
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]: def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
@@ -55,15 +52,15 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
同步执行暴雨预测(在线程池中运行) 同步执行暴雨预测(在线程池中运行)
Returns: Returns:
(预测结果列表, 原始结果, 输入条件, 当前时间) (结果map, 输入条件, 当前时间)
""" """
points = _fetch_points(point_ids, region_code) points = _fetch_points(point_ids, region_code)
if not points: if not points:
return [], [], {}, datetime.now() return {}, {}, datetime.now()
model = get_rainfall_model() model = get_rainfall_model()
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration) raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
items = _build_prediction_items(raw_results) result_map = _build_prediction_map(raw_results)
# 构建条件和结果用于保存 # 构建条件和结果用于保存
now = datetime.now() now = datetime.now()
@@ -73,19 +70,8 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
"rainfall": rainfall, "rainfall": rainfall,
"duration": duration "duration": duration
} }
save_results = [
{
"point_id": r.get("source_id"), # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
"source_type": r.get("source_type"),
"lon": r.get("lon"),
"lat": r.get("lat"),
"disaster_probabilities": r.get("disaster_probabilities", {}),
"disaster_levels": r.get("disaster_levels", {})
}
for r in raw_results
]
return items, save_results, condition, now return result_map, condition, now
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间") @router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
@@ -119,7 +105,7 @@ async def update_monitoring_time(req: UpdateMonitoringTimeRequest):
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测") @router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
async def predict_rainfall(req: RainfallPredictRequest): async def predict_rainfall(req: RainfallPredictRequest):
""" """
根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率和等级 根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率。
- **disaster_name**: 灾害名称 - **disaster_name**: 灾害名称
- **point_ids**: 点位ID列表(可选,不传则查询所有点) - **point_ids**: 点位ID列表(可选,不传则查询所有点)
@@ -133,7 +119,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
async with semaphore: async with semaphore:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
items, save_results, condition, now = await loop.run_in_executor( result_map, condition, now = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code, None, _predict_sync, req.point_ids, req.region_code,
req.rainfall, req.duration, req.operation_type req.rainfall, req.duration, req.operation_type
) )
@@ -143,7 +129,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
# 保存推理结果 # 保存推理结果
record_id = None record_id = None
if save_results: if result_map:
try: try:
record_id = dbn_repository.save_inference_result( record_id = dbn_repository.save_inference_result(
disaster_name=req.disaster_name, disaster_name=req.disaster_name,
@@ -151,10 +137,10 @@ async def predict_rainfall(req: RainfallPredictRequest):
occurred_time=now, occurred_time=now,
operation_type=req.operation_type, operation_type=req.operation_type,
condition=condition, condition=condition,
result=save_results result=result_map
) )
logger.info(f"推理结果已保存,record_id={record_id}") logger.info(f"推理结果已保存,record_id={record_id}")
except Exception as e: except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True) logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=items, record_id=record_id) return PredictResponse(code=200, message="success", data=record_id)
+1 -2
View File
@@ -59,8 +59,7 @@ class PredictResponse(BaseModel):
"""预测响应""" """预测响应"""
code: int = Field(200, description="状态码") code: int = Field(200, description="状态码")
message: str = Field("success", description="提示信息") message: str = Field("success", description="提示信息")
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表") data: Optional[int] = Field(None, description="推理结果记录ID")
record_id: Optional[int] = Field(None, description="推理结果记录ID")
class UpdateMonitoringTimeRequest(BaseModel): class UpdateMonitoringTimeRequest(BaseModel):