修改接口返回值

This commit is contained in:
wzy-warehouse
2026-06-14 15:31:42 +08:00
parent 9a49764b35
commit 5e1dc585d4
3 changed files with 37 additions and 67 deletions
+18 -33
View File
@@ -18,24 +18,21 @@ SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
LEVEL_MAP = {"": "", "": "", "较高": "较高", "": ""}
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
"""将模型原始结果转换为接口返回格式"""
items = []
def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
"""将模型原始结果转换为存储格式: {id_type: 概率百分比}"""
result_map = {}
for r in results:
probs = r.get("disaster_probabilities", {})
levels = r.get("disaster_levels", {})
if not probs:
continue
source_id = r["source_id"]
source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get)
items.append(PredictionItem(
id=r["source_id"], # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
probability=round(probs[max_hazard], 4),
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), ""),
))
return items
# key 格式: {source_id}_{source_type}value 为百分比概率
key = f"{source_id}_{source_type}"
result_map[key] = round(probs[max_hazard] * 100, 2)
return result_map
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
@@ -52,11 +49,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
同步执行地震预测(在线程池中运行)
Returns:
(预测结果列表, 原始结果)
(结果map,)
"""
points = _fetch_points(point_ids, region_code)
if not points:
return [], []
return {}
model = get_earthquake_model()
raw_results = model.predict_multiple_points(
@@ -66,27 +63,15 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat,
)
items = _build_prediction_items(raw_results)
result_map = _build_prediction_map(raw_results)
save_results = [
{
"point_id": r.get("source_id"), # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
"source_type": r.get("source_type"),
"lon": r.get("lon"),
"lat": r.get("lat"),
"disaster_probabilities": r.get("disaster_probabilities", {}),
"disaster_levels": r.get("disaster_levels", {})
}
for r in raw_results
]
return items, save_results
return result_map
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
async def predict_earthquake(req: EarthquakePredictRequest):
"""
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率。
- **disaster_name**: 灾害名称
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
@@ -103,7 +88,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
async with semaphore:
loop = asyncio.get_event_loop()
try:
items, save_results = await loop.run_in_executor(
result_map = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code,
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
)
@@ -113,7 +98,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
# 保存推理结果
record_id = None
if save_results:
if result_map:
try:
condition = {
"point_ids": req.point_ids,
@@ -129,10 +114,10 @@ async def predict_earthquake(req: EarthquakePredictRequest):
occurred_time=req.occurred_time,
operation_type=req.operation_type,
condition=condition,
result=save_results
result=result_map
)
logger.info(f"推理结果已保存,record_id={record_id}")
except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=items, record_id=record_id)
return PredictResponse(code=200, message="success", data=record_id)