diff --git a/app/config/dbn/discretization.yaml b/app/config/dbn/discretization.yaml index 7e429b3..8237553 100644 --- a/app/config/dbn/discretization.yaml +++ b/app/config/dbn/discretization.yaml @@ -203,6 +203,7 @@ pipe_density: description: "供水管网密度" unit: "m/m²" # 数据: [0.0, 0.07], 约80%为0.0,90%分位数0.013,95%分位数0.023 - # 分位数: [0.0, 0.013, 0.023, 0.065] - bins: [0.0, 0.013, 0.023, 0.065] + # 分箱策略:0单独一类,其余3等分(分位数分箱) + # 分位数(非零): [0.013, 0.023, 0.065] + bins: [0.0, 0.001, 0.013, 0.023, 0.065] labels: [none, low, medium, high] diff --git a/app/models/dbn/rainfall/rainfall_dbn.py b/app/models/dbn/rainfall/rainfall_dbn.py index f367d33..47666b5 100644 --- a/app/models/dbn/rainfall/rainfall_dbn.py +++ b/app/models/dbn/rainfall/rainfall_dbn.py @@ -214,7 +214,8 @@ class RainfallDBN: def predict_single_point(self, point: Dict[str, Any], rainfall: Optional[float] = None, duration: Optional[float] = None, - query_time: Optional[datetime] = None) -> Dict[str, Any]: + query_time: Optional[datetime] = None, + rainfall_data_override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ 对单个点进行预测 @@ -223,6 +224,7 @@ class RainfallDBN: rainfall: 累计降雨量(可选) duration: 持续时间(可选) query_time: 查询时间(可选) + rainfall_data_override: 预取的降雨数据(批量模式下避免重复查询) Returns: 预测结果 @@ -242,6 +244,8 @@ class RainfallDBN: 'duration_hours': duration, 'rain_intensity': rain_intensity } + elif rainfall_data_override is not None: + rainfall_data = rainfall_data_override else: rainfall_data = DbnRepository.get_rainfall_data_with_duration(lon, lat, query_time) @@ -389,11 +393,24 @@ class RainfallDBN: Returns: 预测结果列表 """ + # 批量预取降雨数据(避免逐点N+1查询) + batch_rainfall: Optional[Dict[str, Dict[str, Any]]] = None + if (rainfall is None or duration is None) and points: + batch_points = [ + {'id': p.get('id'), 'lon': p.get('lon'), 'lat': p.get('lat')} + for p in points + ] + batch_rainfall = DbnRepository.get_rainfall_data_batch(batch_points, query_time) + logger.info(f"批量预取降雨数据完成,覆盖 {len(batch_rainfall)} 个点") + results = [] for point in points: try: + point_id = point.get('id') + override = batch_rainfall.get(point_id) if batch_rainfall else None result = self.predict_single_point( - point, rainfall=rainfall, duration=duration, query_time=query_time + point, rainfall=rainfall, duration=duration, + query_time=query_time, rainfall_data_override=override ) results.append(result) except Exception as e: diff --git a/app/repositories/dbn_repository.py b/app/repositories/dbn_repository.py index c25c64d..a1fa0ea 100644 --- a/app/repositories/dbn_repository.py +++ b/app/repositories/dbn_repository.py @@ -314,6 +314,134 @@ class DbnRepository: 'rain_intensity': rain_intensity } + # ---- 批量降雨查询(性能优化) ---- + + _cached_stations: Optional[List[Dict[str, Any]]] = None + + @classmethod + def _ensure_stations_cached(cls) -> List[Dict[str, Any]]: + """一次性加载所有气象站点坐标到内存(188个站点,约2KB)""" + if cls._cached_stations is not None: + return cls._cached_stations + sql = "SELECT DISTINCT lon, lat FROM xian_meteorology" + cls._cached_stations = db_helper.execute_query(sql) + logger.info(f"已缓存 {len(cls._cached_stations)} 个气象站点坐标") + return cls._cached_stations + + @staticmethod + def _haversine_distance(lon1: float, lat1: float, lon2: float, lat2: float) -> float: + """Haversine公式计算两点间距离(米)""" + R = 6371000 + phi1, phi2 = math.radians(lat1), math.radians(lat2) + dphi = math.radians(lat2 - lat1) + dlam = math.radians(lon2 - lon1) + a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2 + return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + + @classmethod + def _find_nearest_station(cls, lon: float, lat: float) -> Optional[Dict[str, Any]]: + """在缓存的站点中找最近的一个(纯Python,微秒级)""" + stations = cls._ensure_stations_cached() + if not stations: + return None + best = None + best_dist = float('inf') + for s in stations: + d = cls._haversine_distance(lon, lat, s['lon'], s['lat']) + if d < best_dist: + best_dist = d + best = s + if best_dist > 50000: + return None + return {'lon': best['lon'], 'lat': best['lat'], 'dist': best_dist} + + @classmethod + def get_rainfall_data_batch(cls, points: List[Dict[str, Any]], + query_time: Optional[datetime] = None) -> Dict[str, Dict[str, Any]]: + """ + 批量获取多个点的降雨数据(2次DB查询替代 N×2次) + + Args: + points: 预测点列表,每个含 {'id': str, 'lon': float, 'lat': float} + query_time: 查询时间 + + Returns: + {point_id: {accum_rain, duration_hours, rain_intensity}} + """ + if query_time is None: + query_time = datetime.now() + + # 结果模板(无数据时的默认值) + default = {'accum_rain': 0.0, 'duration_hours': 0, 'rain_intensity': 0.0} + result: Dict[str, Dict[str, Any]] = {} + + # 1. 为每个点找最近站点(纯Python,瞬间完成) + station_to_points: Dict[tuple, List[str]] = {} + for p in points: + station = cls._find_nearest_station(p['lon'], p['lat']) + if station is None: + result[p['id']] = default.copy() + continue + key = (station['lon'], station['lat']) + station_to_points.setdefault(key, []).append(p['id']) + + if not station_to_points: + return result + + # 2. 一次批量查所有站点的72小时降雨数据 + station_keys = list(station_to_points.keys()) + placeholders = ', '.join(['(%s, %s)'] * len(station_keys)) + params: List[Any] = [] + for slon, slat in station_keys: + params.extend([slon, slat]) + params.extend([query_time, query_time]) + + # noinspection SqlNoDataSourceInspection + sql = f""" + SELECT lon, lat, datetime, CAST(rainfall_1h AS DOUBLE PRECISION) as rainfall + FROM xian_meteorology + WHERE (lon, lat) IN ({placeholders}) + AND datetime BETWEEN + CAST(EXTRACT(EPOCH FROM (%s::timestamp - INTERVAL '72 hours')) AS BIGINT) + AND CAST(EXTRACT(EPOCH FROM %s::timestamp) AS BIGINT) + ORDER BY lon, lat, datetime DESC + """ + rows = db_helper.execute_query(sql, tuple(params)) + + # 3. 按站点分组,计算累计降雨量和持续时间 + from itertools import groupby + station_rainfall: Dict[tuple, Dict[str, Any]] = {} + for (slon, slat), group in groupby(rows, key=lambda r: (r['lon'], r['lat'])): + accum_rain = 0.0 + duration_hours = 0 + consecutive_no_rain = 0 + for row in group: + rainfall = float(row['rainfall']) if row['rainfall'] else 0.0 + if rainfall > 0: + accum_rain += rainfall + duration_hours += 1 + consecutive_no_rain = 0 + else: + consecutive_no_rain += 1 + if consecutive_no_rain >= 3: + break + if accum_rain > 0: + duration_hours += 1 + intensity = accum_rain / duration_hours if duration_hours > 0 else 0.0 + station_rainfall[(slon, slat)] = { + 'accum_rain': accum_rain, + 'duration_hours': duration_hours, + 'rain_intensity': intensity + } + + # 4. 分发给各预测点 + for (slon, slat), point_ids in station_to_points.items(): + rain_data = station_rainfall.get((slon, slat), default) + for pid in point_ids: + result[pid] = rain_data + + return result + # ==================== 空间查询 ==================== @staticmethod diff --git a/app/schemas/api_schemas.py b/app/schemas/api_schemas.py index 74d2ca6..618b4e6 100644 --- a/app/schemas/api_schemas.py +++ b/app/schemas/api_schemas.py @@ -14,8 +14,10 @@ class RainfallPredictRequest(BaseModel): point_ids: Optional[List[int]] = Field(None, max_length=500, description="点位ID列表,不传则查询所有点") region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域") - rainfall: float = Field(..., ge=0, description="累计降雨量(mm)") - duration: float = Field(..., ge=0, description="降雨持续时间(h)") + rainfall: Optional[float] = Field(None, ge=0, + description="累计降雨量(mm),不传则从气象表自动获取") + duration: Optional[float] = Field(None, ge=0, + description="降雨持续时间(h),不传则从气象表自动获取") # ============================================================