优化参数

This commit is contained in:
wzy-warehouse
2026-06-06 11:29:08 +08:00
parent 9c3b0575d2
commit cb2d8c2c54
4 changed files with 154 additions and 6 deletions
+19 -2
View File
@@ -214,7 +214,8 @@ class RainfallDBN:
def predict_single_point(self, point: Dict[str, Any],
rainfall: Optional[float] = None,
duration: Optional[float] = None,
query_time: Optional[datetime] = None) -> Dict[str, Any]:
query_time: Optional[datetime] = None,
rainfall_data_override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
对单个点进行预测
@@ -223,6 +224,7 @@ class RainfallDBN:
rainfall: 累计降雨量(可选)
duration: 持续时间(可选)
query_time: 查询时间(可选)
rainfall_data_override: 预取的降雨数据(批量模式下避免重复查询)
Returns:
预测结果
@@ -242,6 +244,8 @@ class RainfallDBN:
'duration_hours': duration,
'rain_intensity': rain_intensity
}
elif rainfall_data_override is not None:
rainfall_data = rainfall_data_override
else:
rainfall_data = DbnRepository.get_rainfall_data_with_duration(lon, lat, query_time)
@@ -389,11 +393,24 @@ class RainfallDBN:
Returns:
预测结果列表
"""
# 批量预取降雨数据(避免逐点N+1查询)
batch_rainfall: Optional[Dict[str, Dict[str, Any]]] = None
if (rainfall is None or duration is None) and points:
batch_points = [
{'id': p.get('id'), 'lon': p.get('lon'), 'lat': p.get('lat')}
for p in points
]
batch_rainfall = DbnRepository.get_rainfall_data_batch(batch_points, query_time)
logger.info(f"批量预取降雨数据完成,覆盖 {len(batch_rainfall)} 个点")
results = []
for point in points:
try:
point_id = point.get('id')
override = batch_rainfall.get(point_id) if batch_rainfall else None
result = self.predict_single_point(
point, rainfall=rainfall, duration=duration, query_time=query_time
point, rainfall=rainfall, duration=duration,
query_time=query_time, rainfall_data_override=override
)
results.append(result)
except Exception as e: