修改API接口
This commit is contained in:
@@ -100,19 +100,23 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
record_id = None
|
record_id = None
|
||||||
if result_map:
|
if result_map:
|
||||||
try:
|
try:
|
||||||
# 存储经过默认值处理的条件(depth 默认值为 10.0)
|
# 使用传入的 occurred_time,如果未传则使用当前时间
|
||||||
|
from datetime import datetime
|
||||||
|
occurred_time = req.occurred_time if req.occurred_time else datetime.now()
|
||||||
|
# 存储经过默认值处理的条件
|
||||||
condition = {
|
condition = {
|
||||||
"point_ids": req.point_ids,
|
"point_ids": req.point_ids,
|
||||||
"region_code": req.region_code,
|
"region_code": req.region_code,
|
||||||
"magnitude": req.magnitude,
|
"magnitude": req.magnitude,
|
||||||
"depth": req.depth, # 已有默认值 10.0
|
"depth": req.depth, # 已有默认值 10.0
|
||||||
"epicenter_lon": req.epicenter_lon,
|
"epicenter_lon": req.epicenter_lon,
|
||||||
"epicenter_lat": req.epicenter_lat
|
"epicenter_lat": req.epicenter_lat,
|
||||||
|
"occurred_time": occurred_time.isoformat() if hasattr(occurred_time, 'isoformat') else str(occurred_time)
|
||||||
}
|
}
|
||||||
record_id = dbn_repository.save_inference_result(
|
record_id = dbn_repository.save_inference_result(
|
||||||
disaster_name=req.disaster_name,
|
disaster_name=req.disaster_name,
|
||||||
event_type="earthquake",
|
event_type="earthquake",
|
||||||
occurred_time=req.occurred_time,
|
occurred_time=occurred_time,
|
||||||
operation_type=req.operation_type,
|
operation_type=req.operation_type,
|
||||||
condition=condition,
|
condition=condition,
|
||||||
result=result_map
|
result=result_map
|
||||||
|
|||||||
+17
-11
@@ -47,19 +47,25 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) ->
|
|||||||
|
|
||||||
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||||
rainfall: Optional[float], duration: Optional[float],
|
rainfall: Optional[float], duration: Optional[float],
|
||||||
operation_type: str) -> tuple:
|
operation_type: str, occurred_time: Optional[datetime] = None) -> tuple:
|
||||||
"""
|
"""
|
||||||
同步执行暴雨预测(在线程池中运行)
|
同步执行暴雨预测(在线程池中运行)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(结果map, 实际使用的降雨数据, 当前时间)
|
(结果map, 经过默认值处理的条件, 实际使用的事件时间)
|
||||||
"""
|
"""
|
||||||
points = _fetch_points(point_ids, region_code)
|
points = _fetch_points(point_ids, region_code)
|
||||||
if not points:
|
if not points:
|
||||||
return {}, {}, datetime.now()
|
return {}, {}, occurred_time or datetime.now()
|
||||||
|
|
||||||
|
# 使用传入的时间或当前时间作为查询时间
|
||||||
|
query_time = occurred_time or datetime.now()
|
||||||
|
|
||||||
model = get_rainfall_model()
|
model = get_rainfall_model()
|
||||||
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
|
||||||
result_map = _build_prediction_map(raw_results)
|
result_map = _build_prediction_map(raw_results)
|
||||||
|
|
||||||
# 获取实际使用的降雨数据(如果未传递,模型会从数据库查询)
|
# 获取实际使用的降雨数据(如果未传递,模型会从数据库查询)
|
||||||
@@ -69,22 +75,22 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
|||||||
# 获取第一个点的降雨数据作为参考
|
# 获取第一个点的降雨数据作为参考
|
||||||
from app.repositories.dbn_repository import DbnRepository
|
from app.repositories.dbn_repository import DbnRepository
|
||||||
first_point = points[0]
|
first_point = points[0]
|
||||||
rain_data = DbnRepository.get_rainfall_data_with_duration(first_point['lon'], first_point['lat'])
|
rain_data = DbnRepository.get_rainfall_data_with_duration(first_point['lon'], first_point['lat'], query_time)
|
||||||
if actual_rainfall is None:
|
if actual_rainfall is None:
|
||||||
actual_rainfall = rain_data.get('accum_rain', 0.0)
|
actual_rainfall = rain_data.get('accum_rain', 0.0)
|
||||||
if actual_duration is None:
|
if actual_duration is None:
|
||||||
actual_duration = rain_data.get('duration_hours', 0)
|
actual_duration = rain_data.get('duration_hours', 0)
|
||||||
|
|
||||||
# 构建经过默认值处理的条件用于保存
|
# 构建经过默认值处理的条件用于保存
|
||||||
now = datetime.now()
|
|
||||||
condition = {
|
condition = {
|
||||||
"point_ids": point_ids,
|
"point_ids": point_ids,
|
||||||
"region_code": region_code,
|
"region_code": region_code,
|
||||||
"rainfall": actual_rainfall,
|
"rainfall": actual_rainfall,
|
||||||
"duration": actual_duration
|
"duration": actual_duration,
|
||||||
|
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result_map, condition, now
|
return result_map, condition, query_time
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
|
||||||
@@ -132,9 +138,9 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
async with semaphore:
|
async with semaphore:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
result_map, condition, now = await loop.run_in_executor(
|
result_map, condition, occurred_time = await loop.run_in_executor(
|
||||||
None, _predict_sync, req.point_ids, req.region_code,
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
req.rainfall, req.duration, req.operation_type
|
req.rainfall, req.duration, req.operation_type, req.occurred_time
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"暴雨预测失败: {e}", exc_info=True)
|
logger.error(f"暴雨预测失败: {e}", exc_info=True)
|
||||||
@@ -147,7 +153,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
record_id = dbn_repository.save_inference_result(
|
record_id = dbn_repository.save_inference_result(
|
||||||
disaster_name=req.disaster_name,
|
disaster_name=req.disaster_name,
|
||||||
event_type="rainfall",
|
event_type="rainfall",
|
||||||
occurred_time=now,
|
occurred_time=occurred_time,
|
||||||
operation_type=req.operation_type,
|
operation_type=req.operation_type,
|
||||||
condition=condition,
|
condition=condition,
|
||||||
result=result_map
|
result=result_map
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class RainfallPredictRequest(BaseModel):
|
|||||||
description="累计降雨量(mm),不传则从气象表自动获取")
|
description="累计降雨量(mm),不传则从气象表自动获取")
|
||||||
duration: Optional[float] = Field(None, ge=0,
|
duration: Optional[float] = Field(None, ge=0,
|
||||||
description="降雨持续时间(h),不传则从气象表自动获取")
|
description="降雨持续时间(h),不传则从气象表自动获取")
|
||||||
|
occurred_time: Optional[datetime] = Field(None, description="事件发生时间,不传则为当前时间")
|
||||||
operation_type: str = Field("模拟", min_length=1, max_length=50,
|
operation_type: str = Field("模拟", min_length=1, max_length=50,
|
||||||
description="操作类型(如 '模拟', '实时监测', '应急评估')")
|
description="操作类型(如 '模拟', '实时监测', '应急评估')")
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ class EarthquakePredictRequest(BaseModel):
|
|||||||
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
|
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
|
||||||
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
|
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
|
||||||
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
|
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
|
||||||
occurred_time: datetime = Field(..., description="地震发生时间")
|
occurred_time: Optional[datetime] = Field(None, description="地震发生时间,不传则为当前时间")
|
||||||
operation_type: str = Field("模拟", min_length=1, max_length=50,
|
operation_type: str = Field("模拟", min_length=1, max_length=50,
|
||||||
description="操作类型(如 '模拟', '实时监测', '应急评估')")
|
description="操作类型(如 '模拟', '实时监测', '应急评估')")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user