修改API接口

This commit is contained in:
wzy-warehouse
2026-06-14 16:29:01 +08:00
parent 028b7989ef
commit de118dc57b
3 changed files with 26 additions and 15 deletions
+7 -3
View File
@@ -100,19 +100,23 @@ async def predict_earthquake(req: EarthquakePredictRequest):
record_id = None record_id = None
if result_map: if result_map:
try: try:
# 存储经过默认值处理的条件(depth 默认值为 10.0) # 使用传入的 occurred_time,如果未传则使用当前时间
from datetime import datetime
occurred_time = req.occurred_time if req.occurred_time else datetime.now()
# 存储经过默认值处理的条件
condition = { condition = {
"point_ids": req.point_ids, "point_ids": req.point_ids,
"region_code": req.region_code, "region_code": req.region_code,
"magnitude": req.magnitude, "magnitude": req.magnitude,
"depth": req.depth, # 已有默认值 10.0 "depth": req.depth, # 已有默认值 10.0
"epicenter_lon": req.epicenter_lon, "epicenter_lon": req.epicenter_lon,
"epicenter_lat": req.epicenter_lat "epicenter_lat": req.epicenter_lat,
"occurred_time": occurred_time.isoformat() if hasattr(occurred_time, 'isoformat') else str(occurred_time)
} }
record_id = dbn_repository.save_inference_result( record_id = dbn_repository.save_inference_result(
disaster_name=req.disaster_name, disaster_name=req.disaster_name,
event_type="earthquake", event_type="earthquake",
occurred_time=req.occurred_time, occurred_time=occurred_time,
operation_type=req.operation_type, operation_type=req.operation_type,
condition=condition, condition=condition,
result=result_map result=result_map
+17 -11
View File
@@ -47,19 +47,25 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) ->
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
rainfall: Optional[float], duration: Optional[float], rainfall: Optional[float], duration: Optional[float],
operation_type: str) -> tuple: operation_type: str, occurred_time: Optional[datetime] = None) -> tuple:
""" """
同步执行暴雨预测(在线程池中运行) 同步执行暴雨预测(在线程池中运行)
Args:
occurred_time: 事件发生时间,用于查询降雨数据和DBN推理
Returns: Returns:
(结果map, 实际使用的降雨数据, 当前时间) (结果map, 经过默认值处理的条件, 实际使用的事件时间)
""" """
points = _fetch_points(point_ids, region_code) points = _fetch_points(point_ids, region_code)
if not points: if not points:
return {}, {}, datetime.now() return {}, {}, occurred_time or datetime.now()
# 使用传入的时间或当前时间作为查询时间
query_time = occurred_time or datetime.now()
model = get_rainfall_model() model = get_rainfall_model()
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration) raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration, query_time=query_time)
result_map = _build_prediction_map(raw_results) result_map = _build_prediction_map(raw_results)
# 获取实际使用的降雨数据(如果未传递,模型会从数据库查询) # 获取实际使用的降雨数据(如果未传递,模型会从数据库查询)
@@ -69,22 +75,22 @@ def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
# 获取第一个点的降雨数据作为参考 # 获取第一个点的降雨数据作为参考
from app.repositories.dbn_repository import DbnRepository from app.repositories.dbn_repository import DbnRepository
first_point = points[0] first_point = points[0]
rain_data = DbnRepository.get_rainfall_data_with_duration(first_point['lon'], first_point['lat']) rain_data = DbnRepository.get_rainfall_data_with_duration(first_point['lon'], first_point['lat'], query_time)
if actual_rainfall is None: if actual_rainfall is None:
actual_rainfall = rain_data.get('accum_rain', 0.0) actual_rainfall = rain_data.get('accum_rain', 0.0)
if actual_duration is None: if actual_duration is None:
actual_duration = rain_data.get('duration_hours', 0) actual_duration = rain_data.get('duration_hours', 0)
# 构建经过默认值处理的条件用于保存 # 构建经过默认值处理的条件用于保存
now = datetime.now()
condition = { condition = {
"point_ids": point_ids, "point_ids": point_ids,
"region_code": region_code, "region_code": region_code,
"rainfall": actual_rainfall, "rainfall": actual_rainfall,
"duration": actual_duration "duration": actual_duration,
"occurred_time": query_time.isoformat() if hasattr(query_time, 'isoformat') else str(query_time)
} }
return result_map, condition, now return result_map, condition, query_time
@router.post("/update-monitoring-time", summary="更新降雨监测查询时间") @router.post("/update-monitoring-time", summary="更新降雨监测查询时间")
@@ -132,9 +138,9 @@ async def predict_rainfall(req: RainfallPredictRequest):
async with semaphore: async with semaphore:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
result_map, condition, now = await loop.run_in_executor( result_map, condition, occurred_time = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code, None, _predict_sync, req.point_ids, req.region_code,
req.rainfall, req.duration, req.operation_type req.rainfall, req.duration, req.operation_type, req.occurred_time
) )
except Exception as e: except Exception as e:
logger.error(f"暴雨预测失败: {e}", exc_info=True) logger.error(f"暴雨预测失败: {e}", exc_info=True)
@@ -147,7 +153,7 @@ async def predict_rainfall(req: RainfallPredictRequest):
record_id = dbn_repository.save_inference_result( record_id = dbn_repository.save_inference_result(
disaster_name=req.disaster_name, disaster_name=req.disaster_name,
event_type="rainfall", event_type="rainfall",
occurred_time=now, occurred_time=occurred_time,
operation_type=req.operation_type, operation_type=req.operation_type,
condition=condition, condition=condition,
result=result_map result=result_map
+2 -1
View File
@@ -20,6 +20,7 @@ class RainfallPredictRequest(BaseModel):
description="累计降雨量(mm),不传则从气象表自动获取") description="累计降雨量(mm),不传则从气象表自动获取")
duration: Optional[float] = Field(None, ge=0, duration: Optional[float] = Field(None, ge=0,
description="降雨持续时间(h),不传则从气象表自动获取") description="降雨持续时间(h),不传则从气象表自动获取")
occurred_time: Optional[datetime] = Field(None, description="事件发生时间,不传则为当前时间")
operation_type: str = Field("模拟", min_length=1, max_length=50, operation_type: str = Field("模拟", min_length=1, max_length=50,
description="操作类型(如 '模拟', '实时监测', '应急评估'") description="操作类型(如 '模拟', '实时监测', '应急评估'")
@@ -38,7 +39,7 @@ class EarthquakePredictRequest(BaseModel):
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km") depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度") epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度") epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
occurred_time: datetime = Field(..., description="地震发生时间") occurred_time: Optional[datetime] = Field(None, description="地震发生时间,不传则为当前时间")
operation_type: str = Field("模拟", min_length=1, max_length=50, operation_type: str = Field("模拟", min_length=1, max_length=50,
description="操作类型(如 '模拟', '实时监测', '应急评估'") description="操作类型(如 '模拟', '实时监测', '应急评估'")