# File: xian_algorithm_new/app/models/dbn/rainfall/rainfall_dbn.py
# (web viewer metadata: 395 lines, 13 KiB, Python)

"""
暴雨灾害链DBN模型
实现贝叶斯网络推理,预测5类灾害概率
"""
import os
import yaml
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.utils.discretizer import discretizer
from app.repositories.dbn_repository import DbnRepository
from app.config.paths import DBN_CONFIG_DIR, get_logger
# Module-level logger, obtained from the project's get_logger helper under the name "dbn".
logger = get_logger("dbn")
class RainfallDBN:
    """Rainstorm disaster-chain DBN model.

    Loads the graph structure and the CPT parameters from two YAML files,
    builds parent/child relations between the configured nodes, and evaluates
    rule-based conditional probabilities to produce a probability and a
    discrete level for every node of the hazard layer.
    """

    # Probability -> discrete level thresholds, checked from top to bottom;
    # the first threshold reached by the probability decides the level.
    HAZARD_LEVEL_THRESHOLDS = [
        (0.6, 'very_high'),
        (0.4, 'high'),
        (0.2, 'medium'),
        (0.05, 'low'),
        (0.0, 'none'),
    ]

    # States assumed for a node with no (or an empty) entry in node_states.
    # Only the number of states is ever used.
    _DEFAULT_STATES = ('no', 'yes')

    # Model factor name -> key under which the value is stored in a point's
    # 'static_factors' dict.
    _STATIC_FACTOR_KEYS = {
        'elevation': 'dem_value',
        'slope': 'slope_value',
        'aspect': 'aspect_value',
        'soil_type': 'soil_type',
        'lithology': 'lithology',
        'landuse': 'landuse',
        'terrain': 'landform',
        'impervious': 'impervious_surface',
        'ndvi': 'vegetation_index',
        'sand_content': 'soil_sand',
        'ph': 'soil_ph',
        'soil_moisture': 'soil_moisture',
        'organic_carbon': 'organic_carbon',
        'dist_to_river': 'river_distance',
        'dist_to_fault': 'fault_distance',
        'pipe_density': 'pipe_density',
    }

    def __init__(self, config_dir: Optional[str] = None):
        """Initialise the model and build the network from the YAML configs.

        Args:
            config_dir: Directory holding the configuration files. Defaults
                to ``DBN_CONFIG_DIR`` when omitted.
        """
        if config_dir is None:
            config_dir = str(DBN_CONFIG_DIR)
        self.config_dir = config_dir
        self.graph_config = self._load_graph_config()
        self.cpt_config = self._load_cpt_config()
        # Derive node lists, parent/child relations and CPT tables.
        self._build_network()

    def _load_yaml_config(self, filename: str, label: str) -> Dict[str, Any]:
        """Load one YAML mapping from the config directory.

        Args:
            filename: File name relative to ``self.config_dir``.
            label: Human-readable file description used in log messages.

        Returns:
            The parsed mapping, or an empty dict when the file is missing,
            empty, or does not contain a mapping at its top level.
        """
        config_path = os.path.join(self.config_dir, filename)
        if not os.path.exists(config_path):
            logger.error(f"{label}不存在: {config_path}")
            return {}
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        # safe_load yields None for an empty document; anything that is not
        # a mapping would break the .get() lookups done while building the
        # network, so fall back to an empty config in both cases.
        if not isinstance(config, dict):
            logger.error(f"{label}内容为空或格式无效: {config_path}")
            return {}
        return config

    def _load_graph_config(self) -> Dict[str, Any]:
        """Load the graph-structure configuration (rainfall_dbn_graph.yaml)."""
        return self._load_yaml_config('rainfall_dbn_graph.yaml', '图结构配置文件')

    def _load_cpt_config(self) -> Dict[str, Any]:
        """Load the CPT configuration (rainfall_cpt_params.yaml)."""
        return self._load_yaml_config('rainfall_cpt_params.yaml', 'CPT配置文件')

    def _states_for(self, node: str):
        """Return the configured states of ``node``, or the binary default."""
        return self.node_states.get(node) or self._DEFAULT_STATES

    def _build_network(self):
        """Build node lists, edge relations and CPT tables from the configs.

        Sections that are absent or empty in the graph config are treated as
        empty collections. Edges that reference a node outside the three
        layers are ignored when building the parent/child maps.
        """
        layers = self.graph_config.get('layers') or {}
        self.trigger_nodes = layers.get('trigger') or []
        self.environment_nodes = layers.get('environment') or []
        self.hazard_nodes = layers.get('hazard') or []
        self.all_nodes = self.trigger_nodes + self.environment_nodes + self.hazard_nodes

        self.edges = self.graph_config.get('edges') or []
        self.node_states = self.graph_config.get('node_states') or {}

        self.children = {node: [] for node in self.all_nodes}
        self.parents = {node: [] for node in self.all_nodes}
        known_nodes = set(self.all_nodes)
        for parent, child in self.edges:
            if parent in known_nodes and child in known_nodes:
                self.children[parent].append(child)
                self.parents[child].append(parent)

        self._build_cpt_tables()

    def _build_cpt_tables(self):
        """Build the CPT table of every node.

        A node listed in the CPT config uses its configured entry as is;
        every other node gets a uniform prior over its states.
        """
        self.cpt_tables = {}
        for node in self.all_nodes:
            if node in self.cpt_config:
                self.cpt_tables[node] = self.cpt_config[node]
                continue
            state_count = len(self._states_for(node))
            self.cpt_tables[node] = {
                'type': 'prior',
                'probabilities': [1.0 / state_count] * state_count,
            }

    def _probability_to_level(self, prob: float) -> str:
        """Map a continuous probability to a discrete hazard level.

        Args:
            prob: Probability value.

        Returns:
            The level of the first threshold in ``HAZARD_LEVEL_THRESHOLDS``
            that ``prob`` reaches, or ``'none'`` when it reaches none.
        """
        for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
            if prob >= threshold:
                return level
        return 'none'

    def _get_node_probability(self, node: str, evidence: Dict[str, str]) -> List[float]:
        """Return the probability distribution of a node.

        Args:
            node: Node name.
            evidence: Mapping of node name to observed (discretised) state.

        Returns:
            The configured prior, the rule-based conditional distribution,
            a uniform distribution when the node has no CPT entry, or
            ``[0.5, 0.5]`` when the CPT type is not recognised.
        """
        cpt = self.cpt_tables.get(node)
        if not cpt:
            state_count = len(self._states_for(node))
            return [1.0 / state_count] * state_count

        cpt_type = cpt.get('type')
        if cpt_type == 'prior':
            return cpt.get('probabilities', [0.5, 0.5])
        if cpt_type == 'conditional':
            return self._evaluate_conditional_probability(node, cpt, evidence)
        return [0.5, 0.5]

    def _evaluate_conditional_probability(self, node: str, cpt: Dict[str, Any],
                                          evidence: Dict[str, str]) -> List[float]:
        """Evaluate a rule-based conditional probability.

        Rules are tried in order and the first rule whose condition matches
        the evidence wins. A rule without a ``condition`` matches any
        evidence. When no rule matches, ``default_probability`` (0.05 when
        not configured) is used.

        Args:
            node: Node name (not used by the evaluation itself).
            cpt: CPT entry holding ``rules`` and ``default_probability``.
            evidence: Mapping of node name to observed state.

        Returns:
            ``[P(no), P(yes)]``.
        """
        default_prob = cpt.get('default_probability', 0.05)
        for rule in cpt.get('rules', []):
            if self._check_condition(rule.get('condition', {}), evidence):
                probability = rule.get('probability', default_prob)
                return [1.0 - probability, probability]
        return [1.0 - default_prob, default_prob]

    def _check_condition(self, condition: Dict[str, Any], evidence: Dict[str, str]) -> bool:
        """Check whether the evidence satisfies every entry of a condition.

        Args:
            condition: Mapping of node name to a required state, or to a
                list of acceptable states.
            evidence: Mapping of node name to observed state.

        Returns:
            True when every node of the condition is present in the evidence
            with an acceptable state (an empty condition is always
            satisfied); False otherwise.
        """
        for node, required_states in condition.items():
            if node not in evidence:
                return False
            observed = evidence[node]
            if isinstance(required_states, list):
                if observed not in required_states:
                    return False
            elif observed != required_states:
                return False
        return True

    def predict_single_point(self, point: Dict[str, Any],
                             rainfall: Optional[float] = None,
                             duration: Optional[float] = None,
                             query_time: Optional[datetime] = None) -> Dict[str, Any]:
        """Predict hazard probabilities and levels for a single point.

        Args:
            point: Point record; the keys ``id``, ``lon``, ``lat``,
                ``source_type`` and ``static_factors`` are read.
            rainfall: Accumulated rainfall. Used only when ``duration`` is
                given as well.
            duration: Rainfall duration. Used only when ``rainfall`` is given
                as well.
            query_time: Time passed to the repository when rainfall data has
                to be looked up.

        Returns:
            Dict with ``point_id``, ``source_type``, ``lon``, ``lat``,
            ``disaster_probabilities`` and ``disaster_levels``; the last two
            are keyed by hazard node.
        """
        point_id = point.get('id')
        lon = point.get('lon')
        lat = point.get('lat')
        source_type = point.get('source_type')
        logger.info(f"预测点 ID={point_id}, source_type={source_type}")

        if rainfall is not None and duration is not None:
            # Caller-supplied values override the repository lookup; a
            # non-positive duration yields zero intensity instead of a
            # division error.
            rainfall_data = {
                'accum_rain': rainfall,
                'duration_hours': duration,
                'rain_intensity': rainfall / duration if duration > 0 else 0.0,
            }
        else:
            rainfall_data = DbnRepository.get_rainfall_data_with_duration(lon, lat, query_time)

        # 'static_factors' may be missing or explicitly None; both mean that
        # no static factor is known for this point.
        raw_factors = point.get('static_factors') or {}

        all_factors = {
            'rain_intensity': rainfall_data.get('rain_intensity', 0.0),
            'duration': rainfall_data.get('duration_hours', 0),
            'accum_rain': rainfall_data.get('accum_rain', 0.0),
        }
        for factor, raw_key in self._STATIC_FACTOR_KEYS.items():
            all_factors[factor] = raw_factors.get(raw_key)

        # Discretise the factors into evidence states, then run inference.
        evidence = discretizer.discretize_all_factors(all_factors)
        hazard_results = self._run_inference(evidence)

        return {
            'point_id': point_id,
            'source_type': source_type,
            'lon': lon,
            'lat': lat,
            'disaster_probabilities': {
                hazard: outcome['probability'] for hazard, outcome in hazard_results.items()
            },
            'disaster_levels': {
                hazard: outcome['level'] for hazard, outcome in hazard_results.items()
            },
        }

    def _run_inference(self, evidence: Dict[str, str]) -> Dict[str, Any]:
        """Compute the occurrence probability and level of every hazard node.

        Args:
            evidence: Mapping of node name to observed state.

        Returns:
            Dict keyed by hazard node; each value holds ``probability``
            (rounded to 4 decimals) and ``level`` (derived from the
            unrounded probability).
        """
        hazard_probabilities = {}
        for hazard_node in self.hazard_nodes:
            prob_dist = self._get_node_probability(hazard_node, evidence)
            # The second entry of the distribution is the occurrence
            # probability; a shorter distribution counts as 0.0.
            prob = prob_dist[1] if len(prob_dist) >= 2 else 0.0
            hazard_probabilities[hazard_node] = {
                'probability': round(prob, 4),
                'level': self._probability_to_level(prob),
            }
        return hazard_probabilities

    def predict(self, region_code: Optional[str] = None,
                rainfall: Optional[float] = None,
                duration: Optional[float] = None,
                timestamp: Optional[datetime] = None) -> List[Dict[str, Any]]:
        """Predict hazard probabilities for every point of a region.

        Args:
            region_code: Administrative region code passed to the repository.
            rainfall: Accumulated rainfall applied to every point (optional).
            duration: Rainfall duration applied to every point (optional).
            timestamp: Query time forwarded to the per-point prediction.

        Returns:
            One result dict per point, in repository order. A point whose
            prediction raised gets a dict with an ``error`` message instead
            of probabilities. An empty list is returned when the repository
            yields no points.
        """
        points = DbnRepository.get_all_points(region_code)
        if not points:
            logger.warning(f"没有找到点数据,region_code={region_code}")
            return []
        logger.info(f"共找到 {len(points)} 个点")

        results = []
        for point in points:
            try:
                result = self.predict_single_point(
                    point,
                    rainfall=rainfall,
                    duration=duration,
                    query_time=timestamp,
                )
            except Exception as e:
                # Per-point boundary: one failing point must not abort the
                # whole batch, so the failure is logged and reported inline.
                logger.error(f"预测点 {point.get('id')} 失败: {e}")
                result = {
                    'point_id': point.get('id'),
                    'source_type': point.get('source_type'),
                    'lon': point.get('lon'),
                    'lat': point.get('lat'),
                    'error': str(e),
                }
            results.append(result)
        return results

    def get_model_info(self) -> Dict[str, Any]:
        """Return the node lists, edges and node states of the model.

        Returns:
            Dict with ``trigger_nodes``, ``environment_nodes``,
            ``hazard_nodes``, ``edges`` and ``node_states``.
        """
        return {
            'trigger_nodes': self.trigger_nodes,
            'environment_nodes': self.environment_nodes,
            'hazard_nodes': self.hazard_nodes,
            'edges': self.edges,
            'node_states': self.node_states,
        }
# Create the shared module-level instance. No config_dir is passed, so the
# constructor falls back to DBN_CONFIG_DIR; the configs are loaded when this
# module is imported.
rainfall_dbn = RainfallDBN()