暴雨地震灾害链HTTP请求

This commit is contained in:
wzy-warehouse
2026-06-06 08:38:19 +08:00
parent 844fa7d719
commit eddbdaca1f
11 changed files with 551 additions and 38 deletions
+13
View File
@@ -0,0 +1,13 @@
"""
API 路由模块
"""
from fastapi import FastAPI
def register_routers(application: FastAPI):
"""注册所有路由"""
from app.api.rainfall import router as rainfall_router
from app.api.earthquake import router as earthquake_router
application.include_router(rainfall_router)
application.include_router(earthquake_router)
+92
View File
@@ -0,0 +1,92 @@
"""
地震灾害链预测接口
"""
import asyncio
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException
from app.schemas.api_schemas import EarthquakePredictRequest, PredictResponse, PredictionItem
from app.utils.api_deps import get_earthquake_model, get_prediction_semaphore
from app.repositories.dbn_repository import dbn_repository
from app.config.paths import get_logger
router = APIRouter(prefix="/earthquake", tags=["地震灾害链"])
logger = get_logger("api.earthquake")
SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
LEVEL_MAP = {"": "", "": "", "较高": "较高", "": ""}
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
"""将模型原始结果转换为接口返回格式"""
items = []
for r in results:
probs = r.get("disaster_probabilities", {})
levels = r.get("disaster_levels", {})
if not probs:
continue
max_hazard = max(probs, key=probs.get)
items.append(PredictionItem(
id=r["point_id"],
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
probability=round(probs[max_hazard], 4),
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), ""),
))
return items
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
"""获取点位列表"""
if point_ids:
return dbn_repository.get_points_by_ids(point_ids)
return dbn_repository.get_all_points(region_code)
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
magnitude: float, depth: float,
epicenter_lon: float, epicenter_lat: float) -> List[PredictionItem]:
"""同步执行地震预测(在线程池中运行)"""
points = _fetch_points(point_ids, region_code)
if not points:
return []
model = get_earthquake_model()
results = model.predict_multiple_points(
points,
magnitude=magnitude,
depth=depth,
epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat,
)
return _build_prediction_items(results)
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
async def predict_earthquake(req: EarthquakePredictRequest):
"""
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级。
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
- **region_code**: 行政区划代码(可选,不传则不限区域)
- **magnitude**: 震级(Richter)
- **depth**: 震源深度(km),默认10km
- **epicenter_lon**: 震中经度
- **epicenter_lat**: 震中纬度
"""
semaphore = get_prediction_semaphore()
async with semaphore:
loop = asyncio.get_event_loop()
try:
items = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code,
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
)
except Exception as e:
logger.error(f"地震预测失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
return PredictResponse(code=200, message="success", data=items)
+83
View File
@@ -0,0 +1,83 @@
"""
暴雨灾害链预测接口
"""
import asyncio
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException
from app.schemas.api_schemas import RainfallPredictRequest, PredictResponse, PredictionItem
from app.utils.api_deps import get_rainfall_model, get_prediction_semaphore
from app.repositories.dbn_repository import dbn_repository
from app.config.paths import get_logger
router = APIRouter(prefix="/rainfall", tags=["暴雨灾害链"])
logger = get_logger("api.rainfall")
SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
LEVEL_MAP = {"": "", "": "", "较高": "较高", "": ""}
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
"""将模型原始结果转换为接口返回格式"""
items = []
for r in results:
probs = r.get("disaster_probabilities", {})
levels = r.get("disaster_levels", {})
if not probs:
continue
max_hazard = max(probs, key=probs.get)
items.append(PredictionItem(
id=r["point_id"],
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
probability=round(probs[max_hazard], 4),
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), ""),
))
return items
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
"""获取点位列表"""
if point_ids:
return dbn_repository.get_points_by_ids(point_ids)
return dbn_repository.get_all_points(region_code)
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
rainfall: float, duration: float) -> List[PredictionItem]:
"""同步执行暴雨预测(在线程池中运行)"""
points = _fetch_points(point_ids, region_code)
if not points:
return []
model = get_rainfall_model()
results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
return _build_prediction_items(results)
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
async def predict_rainfall(req: RainfallPredictRequest):
"""
根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率和等级。
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
- **region_code**: 行政区划代码(可选,不传则不限区域)
- **rainfall**: 累计降雨量(mm)
- **duration**: 降雨持续时间(h)
"""
semaphore = get_prediction_semaphore()
async with semaphore:
loop = asyncio.get_event_loop()
try:
items = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code,
req.rainfall, req.duration
)
except Exception as e:
logger.error(f"暴雨预测失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
return PredictResponse(code=200, message="success", data=items)
+18
View File
@@ -51,12 +51,30 @@ class AppLauncher:
def start():
"""启动应用服务"""
import threading
from config import settings
from app.core.rainfall_manager import rainfall_manager
from app.utils.logger import get_logger
from app.utils.thread_pool_manager import block_main_thread, thread_pool_manager
logger = get_logger()
# 启动 FastAPI 服务(守护线程)
def run_api_server():
import uvicorn
from app.core.server import create_app
api_app = create_app()
uvicorn.run(
api_app,
host=getattr(settings, "API_HOST", "127.0.0.1"),
port=int(getattr(settings, "API_PORT", 8082)),
log_level="info",
)
api_thread = threading.Thread(target=run_api_server, daemon=True, name="api-server")
api_thread.start()
logger.info(f"FastAPI 服务已启动: http://{getattr(settings, 'API_HOST', '127.0.0.1')}:{getattr(settings, 'API_PORT', 8082)}")
# 启动降雨站点监测
logger.info("启动降雨站点监测服务...")
rainfall_manager.monitoring_rainfall_station_id('2025-09-16 20:00:00')
+55
View File
@@ -0,0 +1,55 @@
"""
FastAPI 服务创建与配置
"""
import time
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from app.utils.api_deps import get_rainfall_model, get_earthquake_model, is_model_loaded
from app.schemas.api_schemas import HealthResponse
from app.config.paths import get_logger
logger = get_logger("api")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期:启动时预加载模型"""
logger.info("正在预加载DBN模型...")
get_rainfall_model()
get_earthquake_model()
logger.info("DBN模型预加载完成")
yield
logger.info("应用关闭")
def create_app() -> FastAPI:
"""创建 FastAPI 应用实例"""
application = FastAPI(
title="西安灾害链预测服务",
description="基于动态贝叶斯网络的暴雨/地震灾害链预测API",
version="1.0.0",
lifespan=lifespan,
)
@application.middleware("http")
async def log_requests(request: Request, call_next):
"""请求日志中间件"""
start = time.time()
response = await call_next(request)
elapsed = time.time() - start
logger.info(f"{request.method} {request.url.path} -> {response.status_code} ({elapsed:.3f}s)")
return response
# 注册路由
from app.api import register_routers
register_routers(application)
@application.get("/health", response_model=HealthResponse, tags=["系统"])
async def health_check():
"""健康检查"""
status = is_model_loaded()
return HealthResponse(status="ok", **status)
return application
+67 -14
View File
@@ -18,21 +18,20 @@ logger = get_logger("earthquake_dbn")
class EarthquakeDBN:
"""地震灾害链DBN模型"""
# 灾害概率→离散等级的阈值映射
# 灾害概率→等级的阈值映射
HAZARD_LEVEL_THRESHOLDS = [
(0.6, 'very_high'),
(0.4, 'high'),
(0.2, 'medium'),
(0.05, 'low'),
(0.0, 'none'),
(0.7, ''),
(0.5, '较高'),
(0.3, ''),
(0.0, ''),
]
def _probability_to_level(self, prob: float) -> str:
"""将连续概率映射到离散等级"""
"""将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)"""
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
if prob >= threshold:
return level
return 'none'
return ''
def __init__(self, config_dir: Optional[str] = None):
"""
@@ -114,18 +113,20 @@ class EarthquakeDBN:
}
@staticmethod
def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float) -> float:
def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float,
depth_km: float = 10.0) -> float:
"""
根据震级震中距估算地震烈度
使用中国地震烈度衰减关系
根据震级震中距和震源深度估算地震烈度
I = 0.923 + 1.621*M - 3.494*ln(R+10)
I = 0.923 + 1.621*M - 3.494*ln(R+10) - ln(H/10)
参考:GB 18306-2015 中国地震动参数区划图
深度修正:震源越深,地表烈度越低
Args:
magnitude: 震级(Richter
epicenter_distance_km: 震中距(km
depth_km: 震源深度(km),默认10km
Returns:
估算的地震烈度(中国烈度表数值)
@@ -135,6 +136,10 @@ class EarthquakeDBN:
intensity = 0.923 + 1.621 * magnitude - 3.494 * math.log(epicenter_distance_km + 10)
# 震源深度修正:以10km为基准,深度越大烈度衰减越多
depth_km = max(depth_km, 1.0)
intensity -= math.log(depth_km / 10.0)
# 限制在合理范围内
return max(1.0, min(12.0, intensity))
@@ -191,7 +196,8 @@ class EarthquakeDBN:
epicenter_distance: Optional[float] = None,
seismic_intensity: Optional[float] = None,
epicenter_lon: Optional[float] = None,
epicenter_lat: Optional[float] = None) -> Dict[str, Any]:
epicenter_lat: Optional[float] = None,
depth: float = 10.0) -> Dict[str, Any]:
"""
对单个点进行地震灾害预测
@@ -202,6 +208,7 @@ class EarthquakeDBN:
seismic_intensity: 地震烈度(中国烈度表),若未提供则自动估算
epicenter_lon: 震中经度(可选,用于计算震中距)
epicenter_lat: 震中纬度(可选,用于计算震中距)
depth: 震源深度(km),默认10.0
Returns:
预测结果
@@ -226,7 +233,7 @@ class EarthquakeDBN:
# 估算地震烈度(如果未直接提供)
if seismic_intensity is None:
seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance)
seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance, depth)
logger.info(f"估算地震烈度: {seismic_intensity:.1f}")
# 获取静态因子数据
@@ -379,6 +386,52 @@ class EarthquakeDBN:
return results
def predict_multiple_points(self, points: List[Dict[str, Any]],
magnitude: float = 6.0,
epicenter_distance: Optional[float] = None,
seismic_intensity: Optional[float] = None,
epicenter_lon: Optional[float] = None,
epicenter_lat: Optional[float] = None,
depth: float = 10.0) -> List[Dict[str, Any]]:
"""
对已获取的点列表进行地震灾害预测
Args:
points: 点信息列表(已从数据库获取)
magnitude: 地震震级
epicenter_distance: 震中距(km,可选)
seismic_intensity: 地震烈度(可选)
epicenter_lon: 震中经度(可选)
epicenter_lat: 震中纬度(可选)
depth: 震源深度(km),默认10.0
Returns:
预测结果列表
"""
results = []
for point in points:
try:
result = self.predict_single_point(
point,
magnitude=magnitude,
epicenter_distance=epicenter_distance,
seismic_intensity=seismic_intensity,
epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat,
depth=depth
)
results.append(result)
except Exception as e:
logger.error(f"预测点 {point.get('id')} 失败: {e}")
results.append({
'point_id': point.get('id'),
'source_type': point.get('source_type'),
'lon': point.get('lon'),
'lat': point.get('lat'),
'error': str(e)
})
return results
def get_model_info(self) -> Dict[str, Any]:
"""获取模型信息"""
return {
+41 -8
View File
@@ -17,21 +17,20 @@ logger = get_logger("dbn")
class RainfallDBN:
"""暴雨灾害链DBN模型"""
# 灾害概率→离散等级的阈值映射
# 灾害概率→等级的阈值映射
HAZARD_LEVEL_THRESHOLDS = [
(0.6, 'very_high'),
(0.4, 'high'),
(0.2, 'medium'),
(0.05, 'low'),
(0.0, 'none'),
(0.7, ''),
(0.5, '较高'),
(0.3, ''),
(0.0, ''),
]
def _probability_to_level(self, prob: float) -> str:
"""将连续概率映射到离散等级"""
"""将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)"""
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
if prob >= threshold:
return level
return 'none'
return ''
def __init__(self, config_dir: Optional[str] = None):
"""
@@ -374,6 +373,40 @@ class RainfallDBN:
return results
def predict_multiple_points(self, points: List[Dict[str, Any]],
rainfall: Optional[float] = None,
duration: Optional[float] = None,
query_time: Optional[datetime] = None) -> List[Dict[str, Any]]:
"""
对已获取的点列表进行暴雨灾害预测
Args:
points: 点信息列表(已从数据库获取)
rainfall: 累计降雨量(可选)
duration: 持续时间(可选)
query_time: 查询时间(可选)
Returns:
预测结果列表
"""
results = []
for point in points:
try:
result = self.predict_single_point(
point, rainfall=rainfall, duration=duration, query_time=query_time
)
results.append(result)
except Exception as e:
logger.error(f"预测点 {point.get('id')} 失败: {e}")
results.append({
'point_id': point.get('id'),
'source_type': point.get('source_type'),
'lon': point.get('lon'),
'lat': point.get('lat'),
'error': str(e)
})
return results
def get_model_info(self) -> Dict[str, Any]:
"""
获取模型信息
+61 -15
View File
@@ -22,26 +22,35 @@ class DbnRepository:
获取所有隐患点和风险点(从 xian_risk_factors 表)
Args:
region_code: 行政区划代码(区县名称),可选
region_code: 行政区划代码('610104'),可选,匹配隐患点.county_code 和风险点.unit_code
Returns:
点列表,每个元素包含:id, source_id, source_type, lon, lat, static_factors
"""
sql = """
SELECT
id,
source_id,
source_type,
lon,
lat,
static_factors
FROM xian_risk_factors
WHERE is_delete = 0
"""
params = (region_code,) if region_code else None
if region_code:
sql += " AND county = %s"
# 通过源表的行政区划代码筛选
sql = """
SELECT rf.id, rf.source_id, rf.source_type, rf.lon, rf.lat, rf.static_factors
FROM xian_risk_factors rf
WHERE rf.is_delete = 0
AND (
(rf.source_type = 1 AND rf.source_id IN (
SELECT id FROM xian_hidden_danger_spots WHERE county_id = %s AND is_delete = 0
))
OR
(rf.source_type = 2 AND rf.source_id IN (
SELECT id FROM xian_risk_spots WHERE unit_code = %s AND is_delete = 0
))
)
"""
params = (region_code, region_code)
else:
sql = """
SELECT id, source_id, source_type, lon, lat, static_factors
FROM xian_risk_factors
WHERE is_delete = 0
"""
params = None
results = db_helper.execute_query(sql, params)
@@ -94,6 +103,43 @@ class DbnRepository:
'static_factors': result.get('static_factors') or {}
}
@staticmethod
def get_points_by_ids(point_ids: List[int]) -> List[Dict[str, Any]]:
"""
批量获取点信息
Args:
point_ids: 点ID列表
Returns:
点信息列表
"""
if not point_ids:
return []
placeholders = ','.join(['%s'] * len(point_ids))
sql = f"""
SELECT
rf.id,
rf.source_id,
rf.source_type,
rf.lon,
rf.lat,
rf.static_factors
FROM xian_risk_factors rf
WHERE rf.id IN ({placeholders}) AND rf.is_delete = 0
"""
results = db_helper.execute_query(sql, tuple(point_ids))
return [{
'id': row['id'],
'source_id': row['source_id'],
'source_type': row['source_type'],
'lon': float(row['lon']) if row['lon'] else None,
'lat': float(row['lat']) if row['lat'] else None,
'static_factors': row.get('static_factors') or {}
} for row in results]
@staticmethod
def get_static_factors(point_id: int) -> Dict[str, Any]:
"""
+59
View File
@@ -0,0 +1,59 @@
"""
API 请求/响应数据模型
"""
from typing import List, Optional
from pydantic import BaseModel, Field
# ============================================================
# 暴雨预测
# ============================================================
class RainfallPredictRequest(BaseModel):
"""暴雨灾害链预测请求"""
point_ids: Optional[List[int]] = Field(None, max_length=500,
description="点位ID列表,不传则查询所有点")
region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域")
rainfall: float = Field(..., ge=0, description="累计降雨量(mm)")
duration: float = Field(..., ge=0, description="降雨持续时间(h)")
# ============================================================
# 地震预测
# ============================================================
class EarthquakePredictRequest(BaseModel):
"""地震灾害链预测请求"""
point_ids: Optional[List[int]] = Field(None, max_length=500,
description="点位ID列表,不传则查询所有点")
region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域")
magnitude: float = Field(..., ge=0, le=10, description="震级(Richter)")
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
# ============================================================
# 通用响应
# ============================================================
class PredictionItem(BaseModel):
"""单个点位预测结果"""
id: int = Field(..., description="点位ID")
type: str = Field(..., description="类型: 隐患点 / 风险点")
probability: float = Field(..., description="最大灾害概率")
level: str = Field(..., description="灾害等级: 低/中/较高/高")
class PredictResponse(BaseModel):
"""预测响应"""
code: int = Field(200, description="状态码")
message: str = Field("success", description="提示信息")
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表")
class HealthResponse(BaseModel):
"""健康检查响应"""
status: str = "ok"
rainfall_model_loaded: bool = False
earthquake_model_loaded: bool = False
+59
View File
@@ -0,0 +1,59 @@
"""
API 依赖注入
模型懒加载 + 并发控制
"""
import asyncio
from typing import Optional
from app.config.paths import get_logger
logger = get_logger("api")
# ============================================================
# 并发控制:限制同时进行的预测任务数,防止资源耗尽
# ============================================================
MAX_CONCURRENT_PREDICTIONS = 8
_prediction_semaphore: Optional[asyncio.Semaphore] = None
def get_prediction_semaphore() -> asyncio.Semaphore:
"""获取预测信号量(惰性初始化,兼容事件循环)"""
global _prediction_semaphore
if _prediction_semaphore is None:
_prediction_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PREDICTIONS)
return _prediction_semaphore
# ============================================================
# 模型单例(启动时加载一次)
# ============================================================
_rainfall_model = None
_earthquake_model = None
def get_rainfall_model():
"""获取暴雨DBN模型单例"""
global _rainfall_model
if _rainfall_model is None:
from app.models.dbn.rainfall.rainfall_dbn import RainfallDBN
_rainfall_model = RainfallDBN()
logger.info("暴雨DBN模型加载完成")
return _rainfall_model
def get_earthquake_model():
"""获取地震DBN模型单例"""
global _earthquake_model
if _earthquake_model is None:
from app.models.dbn.earthquake.earthquake_dbn import EarthquakeDBN
_earthquake_model = EarthquakeDBN()
logger.info("地震DBN模型加载完成")
return _earthquake_model
def is_model_loaded() -> dict:
"""检查模型加载状态"""
return {
"rainfall_model_loaded": _rainfall_model is not None,
"earthquake_model_loaded": _earthquake_model is not None,
}
+2
View File
@@ -6,3 +6,5 @@ scipy == 1.17.1
matplotlib == 3.10.0
Pillow == 12.2.0
pyyaml == 6.0.2
fastapi == 0.136.3
uvicorn[standard] == 0.49.0