修改API接口
This commit is contained in:
@@ -100,19 +100,23 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
||||
record_id = None
|
||||
if result_map:
|
||||
try:
|
||||
# 存储经过默认值处理的条件(depth 默认值为 10.0)
|
||||
# 使用传入的 occurred_time,如果未传则使用当前时间
|
||||
from datetime import datetime
|
||||
occurred_time = req.occurred_time if req.occurred_time else datetime.now()
|
||||
# 存储经过默认值处理的条件
|
||||
condition = {
|
||||
"point_ids": req.point_ids,
|
||||
"region_code": req.region_code,
|
||||
"magnitude": req.magnitude,
|
||||
"depth": req.depth, # 已有默认值 10.0
|
||||
"epicenter_lon": req.epicenter_lon,
|
||||
"epicenter_lat": req.epicenter_lat
|
||||
"epicenter_lat": req.epicenter_lat,
|
||||
"occurred_time": occurred_time.isoformat() if hasattr(occurred_time, 'isoformat') else str(occurred_time)
|
||||
}
|
||||
record_id = dbn_repository.save_inference_result(
|
||||
disaster_name=req.disaster_name,
|
||||
event_type="earthquake",
|
||||
occurred_time=req.occurred_time,
|
||||
occurred_time=occurred_time,
|
||||
operation_type=req.operation_type,
|
||||
condition=condition,
|
||||
result=result_map
|
||||
|
||||
+17
-11
@@ -47,19 +47,25 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) ->
|
||||
|
||||
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
rainfall: Optional[float], duration: Optional[float],
|
||||
operation_type: str) -> tuple:
|
||||
operation_type: str, occurred_time: Optional[datetime] = None) -> tuple:
|
||||
"""
|
||||
同步执行暴雨预测(在线程池中运行)
|
||||
|
||||
Args:
|
||||
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
|
||||
|
||||
Returns:
|
||||
(结果map, 实际使用的降雨数据, 当前时间)
|
||||
(结果map, 经过默认值处理的条件, 实际使用的事件时间)
|
||||
"""
|
||||
points = _fetch_points(point_ids, region_code)
|
||||
if not points:
|
||||
return {}, {}, datetime.now()
|
||||
return {}, {}, occurred_time or datetime.now()
|
||||
|
||||
# 使用传入的时间或当前时间作为查询时间
|
||||
query_time = occurred_time or datetime.now()
|
||||
|
||||
model = get_rainfall_model()
|
||||
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
||||
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
|
||||
result_map = _build_prediction_map(raw_results)
|
||||
|
||||
# 获取实际使用的降雨数据(如果未传递,模型会从数据库查询)
|
||||
@@ -69,22 +75,22 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
# 获取第一个点的降雨数据作为参考
|
||||
from app.repositories.dbn_repository import DbnRepository
|
||||
first_point = points[0]
|
||||
rain_data = DbnRepository.get_rainfall_data_with_duration(first_point['lon'], first_point['lat'])
|
||||
rain_data = DbnRepository.get_rainfall_data_with_duration(first_point['lon'], first_point['lat'], query_time)
|
||||
if actual_rainfall is None:
|
||||
actual_rainfall = rain_data.get('accum_rain', 0.0)
|
||||
if actual_duration is None:
|
||||
actual_duration = rain_data.get('duration_hours', 0)
|
||||
|
||||
# 构建经过默认值处理的条件用于保存
|
||||
now = datetime.now()
|
||||
condition = {
|
||||
"point_ids": point_ids,
|
||||
"region_code": region_code,
|
||||
"rainfall": actual_rainfall,
|
||||
"duration": actual_duration
|
||||
"duration": actual_duration,
|
||||
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
|
||||
}
|
||||
|
||||
return result_map, condition, now
|
||||
return result_map, condition, query_time
|
||||
|
||||
|
||||
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
||||
@@ -132,9 +138,9 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
||||
async with semaphore:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
result_map, condition, now = await loop.run_in_executor(
|
||||
result_map, condition, occurred_time = await loop.run_in_executor(
|
||||
None, _predict_sync, req.point_ids, req.region_code,
|
||||
req.rainfall, req.duration, req.operation_type
|
||||
req.rainfall, req.duration, req.operation_type, req.occurred_time
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"暴雨预测失败: {e}", exc_info=True)
|
||||
@@ -147,7 +153,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
||||
record_id = dbn_repository.save_inference_result(
|
||||
disaster_name=req.disaster_name,
|
||||
event_type="rainfall",
|
||||
occurred_time=now,
|
||||
occurred_time=occurred_time,
|
||||
operation_type=req.operation_type,
|
||||
condition=condition,
|
||||
result=result_map
|
||||
|
||||
Reference in New Issue
Block a user