修改接口返回值
This commit is contained in:
+18
-33
@@ -18,24 +18,21 @@ SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
|
|||||||
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
||||||
|
|
||||||
|
|
||||||
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
|
def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||||
"""将模型原始结果转换为接口返回格式"""
|
"""将模型原始结果转换为存储格式: {id_type: 概率百分比}"""
|
||||||
items = []
|
result_map = {}
|
||||||
for r in results:
|
for r in results:
|
||||||
probs = r.get("disaster_probabilities", {})
|
probs = r.get("disaster_probabilities", {})
|
||||||
levels = r.get("disaster_levels", {})
|
|
||||||
|
|
||||||
if not probs:
|
if not probs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
source_id = r["source_id"]
|
||||||
|
source_type = r.get("source_type")
|
||||||
max_hazard = max(probs, key=probs.get)
|
max_hazard = max(probs, key=probs.get)
|
||||||
items.append(PredictionItem(
|
# key 格式: {source_id}_{source_type},value 为百分比概率
|
||||||
id=r["source_id"], # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
|
key = f"{source_id}_{source_type}"
|
||||||
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
|
result_map[key] = round(probs[max_hazard] * 100, 2)
|
||||||
probability=round(probs[max_hazard], 4),
|
return result_map
|
||||||
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"),
|
|
||||||
))
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||||
@@ -52,11 +49,11 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
同步执行地震预测(在线程池中运行)
|
同步执行地震预测(在线程池中运行)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(预测结果列表, 原始结果)
|
(结果map,)
|
||||||
"""
|
"""
|
||||||
points = _fetch_points(point_ids, region_code)
|
points = _fetch_points(point_ids, region_code)
|
||||||
if not points:
|
if not points:
|
||||||
return [], []
|
return {}
|
||||||
|
|
||||||
model = get_earthquake_model()
|
model = get_earthquake_model()
|
||||||
raw_results = model.predict_multiple_points(
|
raw_results = model.predict_multiple_points(
|
||||||
@@ -66,27 +63,15 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
epicenter_lon=epicenter_lon,
|
epicenter_lon=epicenter_lon,
|
||||||
epicenter_lat=epicenter_lat,
|
epicenter_lat=epicenter_lat,
|
||||||
)
|
)
|
||||||
items = _build_prediction_items(raw_results)
|
result_map = _build_prediction_map(raw_results)
|
||||||
|
|
||||||
save_results = [
|
return result_map
|
||||||
{
|
|
||||||
"point_id": r.get("source_id"), # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
|
|
||||||
"source_type": r.get("source_type"),
|
|
||||||
"lon": r.get("lon"),
|
|
||||||
"lat": r.get("lat"),
|
|
||||||
"disaster_probabilities": r.get("disaster_probabilities", {}),
|
|
||||||
"disaster_levels": r.get("disaster_levels", {})
|
|
||||||
}
|
|
||||||
for r in raw_results
|
|
||||||
]
|
|
||||||
|
|
||||||
return items, save_results
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
||||||
async def predict_earthquake(req: EarthquakePredictRequest):
|
async def predict_earthquake(req: EarthquakePredictRequest):
|
||||||
"""
|
"""
|
||||||
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级。
|
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率。
|
||||||
|
|
||||||
- **disaster_name**: 灾害名称
|
- **disaster_name**: 灾害名称
|
||||||
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
||||||
@@ -103,7 +88,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
async with semaphore:
|
async with semaphore:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
items, save_results = await loop.run_in_executor(
|
result_map = await loop.run_in_executor(
|
||||||
None, _predict_sync, req.point_ids, req.region_code,
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
||||||
)
|
)
|
||||||
@@ -113,7 +98,7 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
|
|
||||||
# 保存推理结果
|
# 保存推理结果
|
||||||
record_id = None
|
record_id = None
|
||||||
if save_results:
|
if result_map:
|
||||||
try:
|
try:
|
||||||
condition = {
|
condition = {
|
||||||
"point_ids": req.point_ids,
|
"point_ids": req.point_ids,
|
||||||
@@ -129,10 +114,10 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
occurred_time=req.occurred_time,
|
occurred_time=req.occurred_time,
|
||||||
operation_type=req.operation_type,
|
operation_type=req.operation_type,
|
||||||
condition=condition,
|
condition=condition,
|
||||||
result=save_results
|
result=result_map
|
||||||
)
|
)
|
||||||
logger.info(f"推理结果已保存,record_id={record_id}")
|
logger.info(f"推理结果已保存,record_id={record_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
return PredictResponse(code=200, message="success", data=items, record_id=record_id)
|
return PredictResponse(code=200, message="success", data=record_id)
|
||||||
|
|||||||
+18
-32
@@ -21,24 +21,21 @@ SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
|
|||||||
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
||||||
|
|
||||||
|
|
||||||
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
|
def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||||
"""将模型原始结果转换为接口返回格式"""
|
"""将模型原始结果转换为存储格式: {id_type: 概率百分比}"""
|
||||||
items = []
|
result_map = {}
|
||||||
for r in results:
|
for r in results:
|
||||||
probs = r.get("disaster_probabilities", {})
|
probs = r.get("disaster_probabilities", {})
|
||||||
levels = r.get("disaster_levels", {})
|
|
||||||
|
|
||||||
if not probs:
|
if not probs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
source_id = r["source_id"]
|
||||||
|
source_type = r.get("source_type")
|
||||||
max_hazard = max(probs, key=probs.get)
|
max_hazard = max(probs, key=probs.get)
|
||||||
items.append(PredictionItem(
|
# key 格式: {source_id}_{source_type},value 为百分比概率
|
||||||
id=r["source_id"], # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
|
key = f"{source_id}_{source_type}"
|
||||||
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
|
result_map[key] = round(probs[max_hazard] * 100, 2)
|
||||||
probability=round(probs[max_hazard], 4),
|
return result_map
|
||||||
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"),
|
|
||||||
))
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||||
@@ -55,15 +52,15 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
同步执行暴雨预测(在线程池中运行)
|
同步执行暴雨预测(在线程池中运行)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(预测结果列表, 原始结果, 输入条件, 当前时间)
|
(结果map, 输入条件, 当前时间)
|
||||||
"""
|
"""
|
||||||
points = _fetch_points(point_ids, region_code)
|
points = _fetch_points(point_ids, region_code)
|
||||||
if not points:
|
if not points:
|
||||||
return [], [], {}, datetime.now()
|
return {}, {}, datetime.now()
|
||||||
|
|
||||||
model = get_rainfall_model()
|
model = get_rainfall_model()
|
||||||
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
||||||
items = _build_prediction_items(raw_results)
|
result_map = _build_prediction_map(raw_results)
|
||||||
|
|
||||||
# 构建条件和结果用于保存
|
# 构建条件和结果用于保存
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
@@ -73,19 +70,8 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
"rainfall": rainfall,
|
"rainfall": rainfall,
|
||||||
"duration": duration
|
"duration": duration
|
||||||
}
|
}
|
||||||
save_results = [
|
|
||||||
{
|
|
||||||
"point_id": r.get("source_id"), # 使用 source_id(隐患点/风险点ID)而非 xian_risk_factors.id
|
|
||||||
"source_type": r.get("source_type"),
|
|
||||||
"lon": r.get("lon"),
|
|
||||||
"lat": r.get("lat"),
|
|
||||||
"disaster_probabilities": r.get("disaster_probabilities", {}),
|
|
||||||
"disaster_levels": r.get("disaster_levels", {})
|
|
||||||
}
|
|
||||||
for r in raw_results
|
|
||||||
]
|
|
||||||
|
|
||||||
return items, save_results, condition, now
|
return result_map, condition, now
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
||||||
@@ -119,7 +105,7 @@ async def update_monitoring_time(req: UpdateMonitoringTimeRequest):
|
|||||||
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
|
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
|
||||||
async def predict_rainfall(req: RainfallPredictRequest):
|
async def predict_rainfall(req: RainfallPredictRequest):
|
||||||
"""
|
"""
|
||||||
根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率和等级。
|
根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率。
|
||||||
|
|
||||||
- **disaster_name**: 灾害名称
|
- **disaster_name**: 灾害名称
|
||||||
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
||||||
@@ -133,7 +119,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
async with semaphore:
|
async with semaphore:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
items, save_results, condition, now = await loop.run_in_executor(
|
result_map, condition, now = await loop.run_in_executor(
|
||||||
None, _predict_sync, req.point_ids, req.region_code,
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
req.rainfall, req.duration, req.operation_type
|
req.rainfall, req.duration, req.operation_type
|
||||||
)
|
)
|
||||||
@@ -143,7 +129,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
|
|
||||||
# 保存推理结果
|
# 保存推理结果
|
||||||
record_id = None
|
record_id = None
|
||||||
if save_results:
|
if result_map:
|
||||||
try:
|
try:
|
||||||
record_id = dbn_repository.save_inference_result(
|
record_id = dbn_repository.save_inference_result(
|
||||||
disaster_name=req.disaster_name,
|
disaster_name=req.disaster_name,
|
||||||
@@ -151,10 +137,10 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
occurred_time=now,
|
occurred_time=now,
|
||||||
operation_type=req.operation_type,
|
operation_type=req.operation_type,
|
||||||
condition=condition,
|
condition=condition,
|
||||||
result=save_results
|
result=result_map
|
||||||
)
|
)
|
||||||
logger.info(f"推理结果已保存,record_id={record_id}")
|
logger.info(f"推理结果已保存,record_id={record_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
return PredictResponse(code=200, message="success", data=items, record_id=record_id)
|
return PredictResponse(code=200, message="success", data=record_id)
|
||||||
|
|||||||
@@ -59,8 +59,7 @@ class PredictResponse(BaseModel):
|
|||||||
"""预测响应"""
|
"""预测响应"""
|
||||||
code: int = Field(200, description="状态码")
|
code: int = Field(200, description="状态码")
|
||||||
message: str = Field("success", description="提示信息")
|
message: str = Field("success", description="提示信息")
|
||||||
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表")
|
data: Optional[int] = Field(None, description="推理结果记录ID")
|
||||||
record_id: Optional[int] = Field(None, description="推理结果记录ID")
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateMonitoringTimeRequest(BaseModel):
|
class UpdateMonitoringTimeRequest(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user