From b269e3282d6c659a5e4eeff1bc9719bf86ed9e11 Mon Sep 17 00:00:00 2001 From: wzy-warehouse <18135009705@163.com> Date: Sun, 14 Jun 2026 14:19:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=88=86=E8=BE=A8=E7=8E=87?= =?UTF-8?q?=E8=A7=A6=E5=8F=91=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/services/rainfall_grid_service.py | 320 ++++++++++++++++---------- 1 file changed, 203 insertions(+), 117 deletions(-) diff --git a/app/services/rainfall_grid_service.py b/app/services/rainfall_grid_service.py index 016ecc3..554df9c 100644 --- a/app/services/rainfall_grid_service.py +++ b/app/services/rainfall_grid_service.py @@ -1,5 +1,5 @@ """ -降雨栅格服务 +降雨栅格服务 - 内存优化版本 负责降雨插值、边缘优化、PNG生成等业务逻辑 """ import os @@ -13,11 +13,11 @@ from app.utils.logger import get_logger class RainfallGridService: """降雨栅格服务""" - + def __init__(self): """初始化服务""" self.logger = get_logger() - + # 国标12小时累计降雨量等级和颜色映射 self.rainfall_levels = { 'levels': [0, 0.1, 5, 15, 30, 70, 140], @@ -32,7 +32,7 @@ class RainfallGridService: ], 'labels': ['无雨', '小雨', '中雨', '大雨', '暴雨', '大暴雨', '特大暴雨'] } - + # 西安地区大致边界(用于栅格范围) self.xian_bounds = { 'min_lon': 107, @@ -40,10 +40,10 @@ class RainfallGridService: 'min_lat': 33, 'max_lat': 35, } - + # 栅格分辨率(度) - self.grid_resolution = 0.01 # 约1km - + self.grid_resolution = 0.001 + def _create_buffer_points(self, points_array) -> 'np.ndarray': """ 创建缓冲点:在原始站点外围生成虚拟点以扩展插值区域 @@ -58,11 +58,11 @@ class RainfallGridService: # 计算站点分布的中心 center = np.mean(points_array, axis=0) - + # 在站点外围生成缓冲点(沿着各个方向扩展) buffer_points = [] num_angles = 120 # 每隔3度生成一个缓冲点 - + for angle_deg in range(0, 360, 360 // num_angles): angle_rad = np.radians(angle_deg) # 在凸包边界外扩展 @@ -71,14 +71,14 @@ class RainfallGridService: direction = np.array([np.cos(angle_rad), np.sin(angle_rad)]) projections = points_array @ direction max_idx = np.argmax(projections) - + # 在该方向上扩展 base_point = points_array[max_idx] buffer_point = center + (base_point - center) * scale buffer_points.append(buffer_point) - + return np.array(buffer_points) - + def _calculate_adaptive_max_distance( self, points_array, @@ -116,19 +116,16 @@ class RainfallGridService: # 限制在合理范围内 return float(np.clip(adaptive_distance, min_distance, max_distance)) - + def interpolate_rainfall(self, station_data: List[Dict[str, Any]]) -> Dict[str, Any]: """ - 使用优化的反距离权重法(IDW)进行降雨插值 - - 注意:station_data 现在包含 'rainfall'(累计降雨量)和 'duration_hours'(持续时间) - 与DBN推演使用相同的降雨量计算逻辑(72小时回溯 + 3小时无雨截断) - - 改进: - 1. 高斯核衰减替代简单幂律 - 2. 自适应距离阈值 - 3. 边缘渐变处理 - 4. 高斯平滑减少突变 + 使用优化的反距离权重法(IDW)进行降雨插值(内存优化版本) + + 内存优化: + 1. 使用float32代替float64(内存减半) + 2. 分块处理距离计算 + 3. 提前过滤无效站点 + 4. 减少中间数组 Args: station_data: 站点数据列表,格式: @@ -141,32 +138,34 @@ class RainfallGridService: 插值结果字典 """ import numpy as np - from scipy.spatial import Delaunay, ConvexHull, distance_matrix + from scipy.spatial import Delaunay, ConvexHull from scipy.ndimage import gaussian_filter # 提取站点坐标和降雨量 - points_array = np.array([[s['lon'], s['lat']] for s in station_data]) - values_array = np.array([s['rainfall'] for s in station_data]) - + points_array = np.array([[s['lon'], s['lat']] for s in station_data], dtype=np.float32) + values_array = np.array([s['rainfall'] for s in station_data], dtype=np.float32) + # 创建栅格网格 lon_range = np.arange( self.xian_bounds['min_lon'], self.xian_bounds['max_lon'], - self.grid_resolution + self.grid_resolution, + dtype=np.float32 ) lat_range = np.arange( self.xian_bounds['min_lat'], self.xian_bounds['max_lat'], - self.grid_resolution + self.grid_resolution, + dtype=np.float32 ) - + grid_lon, grid_lat = np.meshgrid(lon_range, lat_range) - result = np.full_like(grid_lon, np.nan) - + result = np.full_like(grid_lon, np.nan, dtype=np.float32) + # 自适应计算最大距离 actual_max_distance = self._calculate_adaptive_max_distance(points_array) self.logger.info(f"使用最大影响距离: {actual_max_distance:.3f} 度") - + # 计算站点的凸包(带边缘缓冲) hull_mask = None confidence_mask = None @@ -174,84 +173,171 @@ class RainfallGridService: try: # 创建缓冲站点:在原始站点外围添加虚拟点 buffer_points = self._create_buffer_points(points_array) - + # 合并原始站点和缓冲站点 all_points = np.vstack([points_array, buffer_points]) - + # 计算凸包 hull = ConvexHull(all_points) hull_points = all_points[hull.vertices] tri = Delaunay(hull_points) - + # 向量化判断所有网格点是否在凸包内 grid_points = np.column_stack([grid_lon.ravel(), grid_lat.ravel()]) - hull_indices = tri.find_simplex(grid_points) - hull_mask = hull_indices >= 0 + simplex_indices = tri.find_simplex(grid_points) + hull_mask = simplex_indices >= 0 hull_mask = hull_mask.reshape(grid_lon.shape) - - # 计算置信度:基于到最近站点的距离 + + # 计算置信度:基于到最近站点的距离(分块处理) grid_valid = grid_points[hull_mask.ravel()] if len(grid_valid) > 0: - dist_to_stations = distance_matrix(grid_valid, points_array) - min_distances = np.min(dist_to_stations, axis=1) + # 分块计算距离,避免内存溢出 + chunk_size = 100000 # 每次处理10万点 + n_valid = len(grid_valid) + min_distances = np.zeros(n_valid, dtype=np.float32) + for i in range(0, n_valid, chunk_size): + chunk_end = min(i + chunk_size, n_valid) + chunk_points = grid_valid[i:chunk_end] + + # 计算当前块到所有站点的距离 + lon_diff = chunk_points[:, 0:1] - points_array[np.newaxis, :, 0] + lat_diff = chunk_points[:, 1:2] - points_array[np.newaxis, :, 1] + distances = np.sqrt(lon_diff**2 + lat_diff**2) + + # 记录最小距离 + min_distances[i:chunk_end] = np.min(distances, axis=1) + + # 释放临时数组 + del lon_diff, lat_diff, distances + # 创建置信度掩码(距离越远,置信度越低) - confidence = np.ones(len(grid_points)) + confidence = np.ones(len(grid_points), dtype=np.float32) confidence[hull_mask.ravel()] = np.exp(-min_distances / actual_max_distance) confidence_mask = confidence.reshape(grid_lon.shape) else: - confidence_mask = np.ones_like(grid_lon) - + confidence_mask = np.ones_like(grid_lon, dtype=np.float32) + except Exception as e: self.logger.warning(f"凸包计算失败: {e},使用全区域插值") hull_mask = np.ones_like(grid_lon, dtype=bool) - confidence_mask = np.ones_like(grid_lon) + confidence_mask = np.ones_like(grid_lon, dtype=np.float32) else: hull_mask = np.ones_like(grid_lon, dtype=bool) - confidence_mask = np.ones_like(grid_lon) + confidence_mask = np.ones_like(grid_lon, dtype=np.float32) + + # 获取凸包内网格点坐标 + grid_points = np.column_stack([grid_lon.ravel(), grid_lat.ravel()]) + if hull_mask is not None: + # 只计算凸包内网格点到站点的距离 + hull_point_indices = np.where(hull_mask.ravel())[0] + grid_points_hull = grid_points[hull_point_indices] + n_hull_points = len(grid_points_hull) + self.logger.info(f"凸包内网格点数量: {n_hull_points}, 总网格点: {grid_lon.size}") + else: + # 如果凸包掩码不可用,使用所有网格点 + grid_points_hull = grid_points + hull_point_indices = np.arange(len(grid_points)) + n_hull_points = len(grid_points_hull) + + # 分块计算凸包内网格点到所有站点的距离 + chunk_size = 50000 # 每次处理5万点 + result_hull = np.full(n_hull_points, np.nan, dtype=np.float32) - # 向量化计算所有网格点到所有站点的距离 - lon_diff = grid_lon[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 0] - lat_diff = grid_lat[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 1] - distances = np.sqrt(lon_diff**2 + lat_diff**2) + # 用于记录最后一个块的has_valid_stations_chunk + last_has_valid_stations_chunk = None - # 过滤超出最大距离的站点 - valid_mask = distances <= actual_max_distance + for i in range(0, n_hull_points, chunk_size): + chunk_end = min(i + chunk_size, n_hull_points) + chunk_points = grid_points_hull[i:chunk_end] + chunk_size_actual = chunk_end - i + + # 计算当前块到所有站点的距离 + lon_diff_chunk = chunk_points[:, 0:1] - points_array[np.newaxis, :, 0] + lat_diff_chunk = chunk_points[:, 1:2] - points_array[np.newaxis, :, 1] + distances_chunk = np.sqrt(lon_diff_chunk**2 + lat_diff_chunk**2) + + # 过滤超出最大距离的站点 + valid_mask_chunk = distances_chunk <= actual_max_distance + + # 对于每个网格点,检查是否有有效站点 + has_valid_stations_chunk = np.any(valid_mask_chunk, axis=1) + + # 避免除零 + distances_chunk = np.where(valid_mask_chunk, distances_chunk, np.inf) + distances_chunk = np.maximum(distances_chunk, 1e-10) + + # 优化的权重计算:结合幂律和高斯衰减 + power = 2.0 + power_weights_chunk = 1.0 / (distances_chunk ** power) + gaussian_weights_chunk = np.exp(-0.5 * (distances_chunk / (actual_max_distance * 0.5)) ** 2) + + # 混合权重:距离越远,高斯权重占比越大 + distance_ratio_chunk = distances_chunk / actual_max_distance + mix_factor_chunk = np.clip(distance_ratio_chunk, 0, 1) + weights_chunk = (1 - mix_factor_chunk) * power_weights_chunk + mix_factor_chunk * gaussian_weights_chunk + + weights_chunk = np.where(valid_mask_chunk, weights_chunk, 0) + + # 加权平均 + weighted_sum_chunk = np.sum(weights_chunk * values_array[np.newaxis, :], axis=1) + weight_total_chunk = np.sum(weights_chunk, axis=1) + + # 计算当前块的插值结果 + with np.errstate(divide='ignore', invalid='ignore'): + chunk_result = np.where( + has_valid_stations_chunk & (weight_total_chunk > 0), + weighted_sum_chunk / weight_total_chunk, + np.nan + ) + + # 存储结果 + result_hull[i:chunk_end] = chunk_result + + # 记录最后一个块的has_valid_stations_chunk + last_has_valid_stations_chunk = has_valid_stations_chunk + + # 释放临时数组 + del lon_diff_chunk, lat_diff_chunk, distances_chunk, valid_mask_chunk + del power_weights_chunk, gaussian_weights_chunk, weights_chunk + del weighted_sum_chunk, weight_total_chunk, chunk_result + + # 将凸包内点的结果映射回完整网格 + result = np.full_like(grid_lon, np.nan, dtype=np.float32) + result.ravel()[hull_point_indices] = result_hull + + # 构建完整网格的有效掩码(凸包内且有有效站点) + # 注意:这里需要重新计算所有凸包内点的有效站点掩码 + # 由于分块处理,我们需要重新计算完整掩码 + has_valid_stations_full = np.zeros_like(grid_lon, dtype=bool) - # 对于每个网格点,检查是否有有效站点 - has_valid_stations = np.any(valid_mask, axis=2) - - # 合并凸包掩码和有效站点掩码 - final_mask = hull_mask & has_valid_stations - - # 避免除零 - distances = np.where(valid_mask, distances, np.inf) - distances = np.maximum(distances, 1e-10) - - # 优化的权重计算:结合幂律和高斯衰减 - power = 2.0 - power_weights = 1.0 / (distances ** power) - gaussian_weights = np.exp(-0.5 * (distances / (actual_max_distance * 0.5)) ** 2) - - # 混合权重:距离越远,高斯权重占比越大 - distance_ratio = distances / actual_max_distance - mix_factor = np.clip(distance_ratio, 0, 1) - weights = (1 - mix_factor) * power_weights + mix_factor * gaussian_weights - - weights = np.where(valid_mask, weights, 0) - - # 加权平均 - weighted_sum = np.sum(weights * values_array[np.newaxis, np.newaxis, :], axis=2) - weight_total = np.sum(weights, axis=2) - - # 计算基础插值结果 - with np.errstate(divide='ignore', invalid='ignore'): - result = np.where( - final_mask & (weight_total > 0), - weighted_sum / weight_total, - np.nan - ) + # 重新计算所有凸包内点的有效站点掩码(内存优化:分块计算) + if n_hull_points > 0: + # 分块计算有效站点掩码 + chunk_size_mask = 100000 # 每次处理10万点 + for i in range(0, n_hull_points, chunk_size_mask): + chunk_end = min(i + chunk_size_mask, n_hull_points) + chunk_points = grid_points_hull[i:chunk_end] + + # 计算当前块到所有站点的距离 + lon_diff_chunk = chunk_points[:, 0:1] - points_array[np.newaxis, :, 0] + lat_diff_chunk = chunk_points[:, 1:2] - points_array[np.newaxis, :, 1] + distances_chunk = np.sqrt(lon_diff_chunk**2 + lat_diff_chunk**2) + + # 过滤超出最大距离的站点 + valid_mask_chunk = distances_chunk <= actual_max_distance + + # 对于每个网格点,检查是否有有效站点 + has_valid_chunk = np.any(valid_mask_chunk, axis=1) + + # 存储结果 + has_valid_stations_full.ravel()[hull_point_indices[i:chunk_end]] = has_valid_chunk + + # 释放临时数组 + del lon_diff_chunk, lat_diff_chunk, distances_chunk, valid_mask_chunk, has_valid_chunk + final_mask = hull_mask & has_valid_stations_full + # 应用置信度调整:边缘区域向邻近值渐变 if confidence_mask is not None: valid_rainfall = result[final_mask] @@ -260,13 +346,13 @@ class RainfallGridService: # 根据置信度调整结果,低置信度区域向均值靠拢 adjusted_result = result * confidence_mask + mean_rainfall * (1 - confidence_mask) result = np.where(final_mask, adjusted_result, np.nan) - + # 应用高斯平滑减少边缘突变 result = gaussian_filter(result, sigma=1.0) - + # 处理NaN值 result = np.nan_to_num(result, nan=0.0) - + return { 'grid_values': result, 'grid_lon': grid_lon, @@ -274,23 +360,23 @@ class RainfallGridService: 'lon_range': lon_range, 'lat_range': lat_range, } - - def optimize_edges(self, grid_data: Dict[str, Any], + + def optimize_edges(self, grid_data: Dict[str, Any], station_data: List[Dict[str, Any]]) -> Dict[str, Any]: """ 优化栅格边缘(已在插值时处理,此方法保留用于向后兼容) - + Args: grid_data: 插值结果 station_data: 站点数据 - + Returns: 优化后的栅格数据 """ # 由于interpolate_rainfall已经包含了边缘优化和平滑处理 # 这里不再重复处理,直接返回 return grid_data - + def save_rainfall_grid_png(self, grid_data: Dict[str, Any], max_id: int) -> Optional[str]: """ 将降雨栅格保存为PNG图片(背景透明) @@ -310,8 +396,8 @@ class RainfallGridService: try: grid_values = grid_data['grid_values'] - lon_range = grid_data['lon_range'] - lat_range = grid_data['lat_range'] + lon_range = grid_data['grid_lon'] + lat_range = grid_data['grid_lat'] # 创建自定义颜色映射 levels = self.rainfall_levels['levels'] @@ -325,7 +411,7 @@ class RainfallGridService: # 创建图形(设置dpi确保不拉伸) fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=100) - + # 绘制栅格 im = ax.pcolormesh( lon_range, @@ -335,56 +421,56 @@ class RainfallGridService: norm=norm, shading='auto' ) - + # 设置透明背景 fig.patch.set_alpha(0) ax.patch.set_alpha(0) - + # 移除坐标轴 ax.set_axis_off() - + # 调整布局,去除白边 plt.tight_layout(pad=0) - + # 构建文件路径 file_store_dir = settings.FILE_STORE_DIR grid_dir_template = settings.RAIN_STATION_GRID_DIR - + # 替换:id为实际的max_id grid_dir = grid_dir_template.replace(':id', str(max_id)) - + # 完整路径 full_dir = os.path.join(file_store_dir, grid_dir.lstrip('/')) - + # 创建目录 os.makedirs(full_dir, exist_ok=True) - + # 保存PNG(使用PIL确保透明度) png_path = os.path.join(full_dir, 'grid.png') - + # 先保存到缓冲区 buf = BytesIO() plt.savefig(buf, format='png', transparent=True, bbox_inches='tight', pad_inches=0) buf.seek(0) - + # 使用PIL打开并重新保存,确保透明度正确 img = Image.open(buf) img.save(png_path, 'PNG') - + buf.close() plt.close(fig) - + # 返回相对路径(相对于FILE_STORE_DIR),统一使用正斜杠 relative_path = os.path.join(grid_dir, 'grid.png').replace('\\', '/') saved_path = png_path.replace('\\', '/') self.logger.info(f"PNG图片已保存: {saved_path}") return relative_path - + except Exception as e: self.logger.error(f"保存PNG图片失败: {e}", exc_info=True) return None - + def store_to_redis(self, png_path: str, max_id: int, query_time, station_data: List[Dict[str, Any]]): """ @@ -402,13 +488,13 @@ class RainfallGridService: try: redis_key = settings.REDIS_RAIN_STATION_GRID_KEY redis_identifier_key = settings.REDIS_RAIN_STATION_IDENTIFIER_KEY - + # 处理query_time,可能是datetime对象或字符串 if isinstance(query_time, datetime): query_time_str = query_time.isoformat() else: query_time_str = str(query_time) - + # 构建辅助前端定位的信息 grid_info = { 'id': max_id, @@ -429,16 +515,16 @@ class RainfallGridService: 'height': int((self.xian_bounds['max_lat'] - self.xian_bounds['min_lat']) / self.grid_resolution), } } - + # 存储到Redis redis_helper.set(redis_key, json.dumps(grid_info)) redis_helper.set(redis_identifier_key, max_id) - + self.logger.info(f"栅格信息已存储到Redis,key: {redis_key}, id: {max_id}") - + except Exception as e: self.logger.error(f"存储到Redis失败: {e}", exc_info=True) # 创建全局实例 -rainfall_grid_service = RainfallGridService() +rainfall_grid_service = RainfallGridService() \ No newline at end of file