暴雨地震灾害链HTTP请求
This commit is contained in:
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
API 路由模块
|
||||||
|
"""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
||||||
|
def register_routers(application: FastAPI):
|
||||||
|
"""注册所有路由"""
|
||||||
|
from app.api.rainfall import router as rainfall_router
|
||||||
|
from app.api.earthquake import router as earthquake_router
|
||||||
|
|
||||||
|
application.include_router(rainfall_router)
|
||||||
|
application.include_router(earthquake_router)
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
地震灾害链预测接口
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.schemas.api_schemas import EarthquakePredictRequest, PredictResponse, PredictionItem
|
||||||
|
from app.utils.api_deps import get_earthquake_model, get_prediction_semaphore
|
||||||
|
from app.repositories.dbn_repository import dbn_repository
|
||||||
|
from app.config.paths import get_logger
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/earthquake", tags=["地震灾害链"])
|
||||||
|
logger = get_logger("api.earthquake")
|
||||||
|
|
||||||
|
SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
|
||||||
|
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
|
||||||
|
"""将模型原始结果转换为接口返回格式"""
|
||||||
|
items = []
|
||||||
|
for r in results:
|
||||||
|
probs = r.get("disaster_probabilities", {})
|
||||||
|
levels = r.get("disaster_levels", {})
|
||||||
|
|
||||||
|
if not probs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
max_hazard = max(probs, key=probs.get)
|
||||||
|
items.append(PredictionItem(
|
||||||
|
id=r["point_id"],
|
||||||
|
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
|
||||||
|
probability=round(probs[max_hazard], 4),
|
||||||
|
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"),
|
||||||
|
))
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||||
|
"""获取点位列表"""
|
||||||
|
if point_ids:
|
||||||
|
return dbn_repository.get_points_by_ids(point_ids)
|
||||||
|
return dbn_repository.get_all_points(region_code)
|
||||||
|
|
||||||
|
|
||||||
|
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||||
|
magnitude: float, depth: float,
|
||||||
|
epicenter_lon: float, epicenter_lat: float) -> List[PredictionItem]:
|
||||||
|
"""同步执行地震预测(在线程池中运行)"""
|
||||||
|
points = _fetch_points(point_ids, region_code)
|
||||||
|
if not points:
|
||||||
|
return []
|
||||||
|
|
||||||
|
model = get_earthquake_model()
|
||||||
|
results = model.predict_multiple_points(
|
||||||
|
points,
|
||||||
|
magnitude=magnitude,
|
||||||
|
depth=depth,
|
||||||
|
epicenter_lon=epicenter_lon,
|
||||||
|
epicenter_lat=epicenter_lat,
|
||||||
|
)
|
||||||
|
return _build_prediction_items(results)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
||||||
|
async def predict_earthquake(req: EarthquakePredictRequest):
|
||||||
|
"""
|
||||||
|
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级。
|
||||||
|
|
||||||
|
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
||||||
|
- **region_code**: 行政区划代码(可选,不传则不限区域)
|
||||||
|
- **magnitude**: 震级(Richter)
|
||||||
|
- **depth**: 震源深度(km),默认10km
|
||||||
|
- **epicenter_lon**: 震中经度
|
||||||
|
- **epicenter_lat**: 震中纬度
|
||||||
|
"""
|
||||||
|
semaphore = get_prediction_semaphore()
|
||||||
|
|
||||||
|
async with semaphore:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
items = await loop.run_in_executor(
|
||||||
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
|
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"地震预测失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
||||||
|
|
||||||
|
return PredictResponse(code=200, message="success", data=items)
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
暴雨灾害链预测接口
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.schemas.api_schemas import RainfallPredictRequest, PredictResponse, PredictionItem
|
||||||
|
from app.utils.api_deps import get_rainfall_model, get_prediction_semaphore
|
||||||
|
from app.repositories.dbn_repository import dbn_repository
|
||||||
|
from app.config.paths import get_logger
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/rainfall", tags=["暴雨灾害链"])
|
||||||
|
logger = get_logger("api.rainfall")
|
||||||
|
|
||||||
|
SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
|
||||||
|
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
|
||||||
|
"""将模型原始结果转换为接口返回格式"""
|
||||||
|
items = []
|
||||||
|
for r in results:
|
||||||
|
probs = r.get("disaster_probabilities", {})
|
||||||
|
levels = r.get("disaster_levels", {})
|
||||||
|
|
||||||
|
if not probs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
max_hazard = max(probs, key=probs.get)
|
||||||
|
items.append(PredictionItem(
|
||||||
|
id=r["point_id"],
|
||||||
|
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
|
||||||
|
probability=round(probs[max_hazard], 4),
|
||||||
|
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"),
|
||||||
|
))
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||||
|
"""获取点位列表"""
|
||||||
|
if point_ids:
|
||||||
|
return dbn_repository.get_points_by_ids(point_ids)
|
||||||
|
return dbn_repository.get_all_points(region_code)
|
||||||
|
|
||||||
|
|
||||||
|
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||||
|
rainfall: float, duration: float) -> List[PredictionItem]:
|
||||||
|
"""同步执行暴雨预测(在线程池中运行)"""
|
||||||
|
points = _fetch_points(point_ids, region_code)
|
||||||
|
if not points:
|
||||||
|
return []
|
||||||
|
|
||||||
|
model = get_rainfall_model()
|
||||||
|
results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
||||||
|
return _build_prediction_items(results)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
|
||||||
|
async def predict_rainfall(req: RainfallPredictRequest):
|
||||||
|
"""
|
||||||
|
根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率和等级。
|
||||||
|
|
||||||
|
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
||||||
|
- **region_code**: 行政区划代码(可选,不传则不限区域)
|
||||||
|
- **rainfall**: 累计降雨量(mm)
|
||||||
|
- **duration**: 降雨持续时间(h)
|
||||||
|
"""
|
||||||
|
semaphore = get_prediction_semaphore()
|
||||||
|
|
||||||
|
async with semaphore:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
items = await loop.run_in_executor(
|
||||||
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
|
req.rainfall, req.duration
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"暴雨预测失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
||||||
|
|
||||||
|
return PredictResponse(code=200, message="success", data=items)
|
||||||
@@ -51,12 +51,30 @@ class AppLauncher:
|
|||||||
|
|
||||||
def start():
|
def start():
|
||||||
"""启动应用服务"""
|
"""启动应用服务"""
|
||||||
|
import threading
|
||||||
|
from config import settings
|
||||||
from app.core.rainfall_manager import rainfall_manager
|
from app.core.rainfall_manager import rainfall_manager
|
||||||
from app.utils.logger import get_logger
|
from app.utils.logger import get_logger
|
||||||
from app.utils.thread_pool_manager import block_main_thread, thread_pool_manager
|
from app.utils.thread_pool_manager import block_main_thread, thread_pool_manager
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
# 启动 FastAPI 服务(守护线程)
|
||||||
|
def run_api_server():
|
||||||
|
import uvicorn
|
||||||
|
from app.core.server import create_app
|
||||||
|
api_app = create_app()
|
||||||
|
uvicorn.run(
|
||||||
|
api_app,
|
||||||
|
host=getattr(settings, "API_HOST", "127.0.0.1"),
|
||||||
|
port=int(getattr(settings, "API_PORT", 8082)),
|
||||||
|
log_level="info",
|
||||||
|
)
|
||||||
|
|
||||||
|
api_thread = threading.Thread(target=run_api_server, daemon=True, name="api-server")
|
||||||
|
api_thread.start()
|
||||||
|
logger.info(f"FastAPI 服务已启动: http://{getattr(settings, 'API_HOST', '127.0.0.1')}:{getattr(settings, 'API_PORT', 8082)}")
|
||||||
|
|
||||||
# 启动降雨站点监测
|
# 启动降雨站点监测
|
||||||
logger.info("启动降雨站点监测服务...")
|
logger.info("启动降雨站点监测服务...")
|
||||||
rainfall_manager.monitoring_rainfall_station_id('2025-09-16 20:00:00')
|
rainfall_manager.monitoring_rainfall_station_id('2025-09-16 20:00:00')
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
"""
|
||||||
|
FastAPI 服务创建与配置
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
|
||||||
|
from app.utils.api_deps import get_rainfall_model, get_earthquake_model, is_model_loaded
|
||||||
|
from app.schemas.api_schemas import HealthResponse
|
||||||
|
from app.config.paths import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("api")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期:启动时预加载模型"""
|
||||||
|
logger.info("正在预加载DBN模型...")
|
||||||
|
get_rainfall_model()
|
||||||
|
get_earthquake_model()
|
||||||
|
logger.info("DBN模型预加载完成")
|
||||||
|
yield
|
||||||
|
logger.info("应用关闭")
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""创建 FastAPI 应用实例"""
|
||||||
|
application = FastAPI(
|
||||||
|
title="西安灾害链预测服务",
|
||||||
|
description="基于动态贝叶斯网络的暴雨/地震灾害链预测API",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
@application.middleware("http")
|
||||||
|
async def log_requests(request: Request, call_next):
|
||||||
|
"""请求日志中间件"""
|
||||||
|
start = time.time()
|
||||||
|
response = await call_next(request)
|
||||||
|
elapsed = time.time() - start
|
||||||
|
logger.info(f"{request.method} {request.url.path} -> {response.status_code} ({elapsed:.3f}s)")
|
||||||
|
return response
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
from app.api import register_routers
|
||||||
|
register_routers(application)
|
||||||
|
|
||||||
|
@application.get("/health", response_model=HealthResponse, tags=["系统"])
|
||||||
|
async def health_check():
|
||||||
|
"""健康检查"""
|
||||||
|
status = is_model_loaded()
|
||||||
|
return HealthResponse(status="ok", **status)
|
||||||
|
|
||||||
|
return application
|
||||||
@@ -18,21 +18,20 @@ logger = get_logger("earthquake_dbn")
|
|||||||
class EarthquakeDBN:
|
class EarthquakeDBN:
|
||||||
"""地震灾害链DBN模型"""
|
"""地震灾害链DBN模型"""
|
||||||
|
|
||||||
# 灾害概率→离散等级的阈值映射
|
# 灾害概率→等级的阈值映射
|
||||||
HAZARD_LEVEL_THRESHOLDS = [
|
HAZARD_LEVEL_THRESHOLDS = [
|
||||||
(0.6, 'very_high'),
|
(0.7, '高'),
|
||||||
(0.4, 'high'),
|
(0.5, '较高'),
|
||||||
(0.2, 'medium'),
|
(0.3, '中'),
|
||||||
(0.05, 'low'),
|
(0.0, '低'),
|
||||||
(0.0, 'none'),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def _probability_to_level(self, prob: float) -> str:
|
def _probability_to_level(self, prob: float) -> str:
|
||||||
"""将连续概率映射到离散等级"""
|
"""将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)"""
|
||||||
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
|
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
|
||||||
if prob >= threshold:
|
if prob >= threshold:
|
||||||
return level
|
return level
|
||||||
return 'none'
|
return '低'
|
||||||
|
|
||||||
def __init__(self, config_dir: Optional[str] = None):
|
def __init__(self, config_dir: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
@@ -114,18 +113,20 @@ class EarthquakeDBN:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float) -> float:
|
def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float,
|
||||||
|
depth_km: float = 10.0) -> float:
|
||||||
"""
|
"""
|
||||||
根据震级和震中距估算地震烈度
|
根据震级、震中距和震源深度估算地震烈度
|
||||||
使用中国地震烈度衰减关系
|
|
||||||
|
|
||||||
I = 0.923 + 1.621*M - 3.494*ln(R+10)
|
I = 0.923 + 1.621*M - 3.494*ln(R+10) - ln(H/10)
|
||||||
|
|
||||||
参考:GB 18306-2015 中国地震动参数区划图
|
参考:GB 18306-2015 中国地震动参数区划图
|
||||||
|
深度修正:震源越深,地表烈度越低
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
magnitude: 震级(Richter)
|
magnitude: 震级(Richter)
|
||||||
epicenter_distance_km: 震中距(km)
|
epicenter_distance_km: 震中距(km)
|
||||||
|
depth_km: 震源深度(km),默认10km
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
估算的地震烈度(中国烈度表数值)
|
估算的地震烈度(中国烈度表数值)
|
||||||
@@ -135,6 +136,10 @@ class EarthquakeDBN:
|
|||||||
|
|
||||||
intensity = 0.923 + 1.621 * magnitude - 3.494 * math.log(epicenter_distance_km + 10)
|
intensity = 0.923 + 1.621 * magnitude - 3.494 * math.log(epicenter_distance_km + 10)
|
||||||
|
|
||||||
|
# 震源深度修正:以10km为基准,深度越大烈度衰减越多
|
||||||
|
depth_km = max(depth_km, 1.0)
|
||||||
|
intensity -= math.log(depth_km / 10.0)
|
||||||
|
|
||||||
# 限制在合理范围内
|
# 限制在合理范围内
|
||||||
return max(1.0, min(12.0, intensity))
|
return max(1.0, min(12.0, intensity))
|
||||||
|
|
||||||
@@ -191,7 +196,8 @@ class EarthquakeDBN:
|
|||||||
epicenter_distance: Optional[float] = None,
|
epicenter_distance: Optional[float] = None,
|
||||||
seismic_intensity: Optional[float] = None,
|
seismic_intensity: Optional[float] = None,
|
||||||
epicenter_lon: Optional[float] = None,
|
epicenter_lon: Optional[float] = None,
|
||||||
epicenter_lat: Optional[float] = None) -> Dict[str, Any]:
|
epicenter_lat: Optional[float] = None,
|
||||||
|
depth: float = 10.0) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
对单个点进行地震灾害预测
|
对单个点进行地震灾害预测
|
||||||
|
|
||||||
@@ -202,6 +208,7 @@ class EarthquakeDBN:
|
|||||||
seismic_intensity: 地震烈度(中国烈度表),若未提供则自动估算
|
seismic_intensity: 地震烈度(中国烈度表),若未提供则自动估算
|
||||||
epicenter_lon: 震中经度(可选,用于计算震中距)
|
epicenter_lon: 震中经度(可选,用于计算震中距)
|
||||||
epicenter_lat: 震中纬度(可选,用于计算震中距)
|
epicenter_lat: 震中纬度(可选,用于计算震中距)
|
||||||
|
depth: 震源深度(km),默认10.0
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
预测结果
|
预测结果
|
||||||
@@ -226,7 +233,7 @@ class EarthquakeDBN:
|
|||||||
|
|
||||||
# 估算地震烈度(如果未直接提供)
|
# 估算地震烈度(如果未直接提供)
|
||||||
if seismic_intensity is None:
|
if seismic_intensity is None:
|
||||||
seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance)
|
seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance, depth)
|
||||||
logger.info(f"估算地震烈度: {seismic_intensity:.1f}")
|
logger.info(f"估算地震烈度: {seismic_intensity:.1f}")
|
||||||
|
|
||||||
# 获取静态因子数据
|
# 获取静态因子数据
|
||||||
@@ -379,6 +386,52 @@ class EarthquakeDBN:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def predict_multiple_points(self, points: List[Dict[str, Any]],
|
||||||
|
magnitude: float = 6.0,
|
||||||
|
epicenter_distance: Optional[float] = None,
|
||||||
|
seismic_intensity: Optional[float] = None,
|
||||||
|
epicenter_lon: Optional[float] = None,
|
||||||
|
epicenter_lat: Optional[float] = None,
|
||||||
|
depth: float = 10.0) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
对已获取的点列表进行地震灾害预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points: 点信息列表(已从数据库获取)
|
||||||
|
magnitude: 地震震级
|
||||||
|
epicenter_distance: 震中距(km,可选)
|
||||||
|
seismic_intensity: 地震烈度(可选)
|
||||||
|
epicenter_lon: 震中经度(可选)
|
||||||
|
epicenter_lat: 震中纬度(可选)
|
||||||
|
depth: 震源深度(km),默认10.0
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预测结果列表
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for point in points:
|
||||||
|
try:
|
||||||
|
result = self.predict_single_point(
|
||||||
|
point,
|
||||||
|
magnitude=magnitude,
|
||||||
|
epicenter_distance=epicenter_distance,
|
||||||
|
seismic_intensity=seismic_intensity,
|
||||||
|
epicenter_lon=epicenter_lon,
|
||||||
|
epicenter_lat=epicenter_lat,
|
||||||
|
depth=depth
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"预测点 {point.get('id')} 失败: {e}")
|
||||||
|
results.append({
|
||||||
|
'point_id': point.get('id'),
|
||||||
|
'source_type': point.get('source_type'),
|
||||||
|
'lon': point.get('lon'),
|
||||||
|
'lat': point.get('lat'),
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
def get_model_info(self) -> Dict[str, Any]:
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
"""获取模型信息"""
|
"""获取模型信息"""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -17,21 +17,20 @@ logger = get_logger("dbn")
|
|||||||
class RainfallDBN:
|
class RainfallDBN:
|
||||||
"""暴雨灾害链DBN模型"""
|
"""暴雨灾害链DBN模型"""
|
||||||
|
|
||||||
# 灾害概率→离散等级的阈值映射
|
# 灾害概率→等级的阈值映射
|
||||||
HAZARD_LEVEL_THRESHOLDS = [
|
HAZARD_LEVEL_THRESHOLDS = [
|
||||||
(0.6, 'very_high'),
|
(0.7, '高'),
|
||||||
(0.4, 'high'),
|
(0.5, '较高'),
|
||||||
(0.2, 'medium'),
|
(0.3, '中'),
|
||||||
(0.05, 'low'),
|
(0.0, '低'),
|
||||||
(0.0, 'none'),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def _probability_to_level(self, prob: float) -> str:
|
def _probability_to_level(self, prob: float) -> str:
|
||||||
"""将连续概率映射到离散等级"""
|
"""将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)"""
|
||||||
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
|
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
|
||||||
if prob >= threshold:
|
if prob >= threshold:
|
||||||
return level
|
return level
|
||||||
return 'none'
|
return '低'
|
||||||
|
|
||||||
def __init__(self, config_dir: Optional[str] = None):
|
def __init__(self, config_dir: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
@@ -374,6 +373,40 @@ class RainfallDBN:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def predict_multiple_points(self, points: List[Dict[str, Any]],
|
||||||
|
rainfall: Optional[float] = None,
|
||||||
|
duration: Optional[float] = None,
|
||||||
|
query_time: Optional[datetime] = None) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
对已获取的点列表进行暴雨灾害预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points: 点信息列表(已从数据库获取)
|
||||||
|
rainfall: 累计降雨量(可选)
|
||||||
|
duration: 持续时间(可选)
|
||||||
|
query_time: 查询时间(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预测结果列表
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for point in points:
|
||||||
|
try:
|
||||||
|
result = self.predict_single_point(
|
||||||
|
point, rainfall=rainfall, duration=duration, query_time=query_time
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"预测点 {point.get('id')} 失败: {e}")
|
||||||
|
results.append({
|
||||||
|
'point_id': point.get('id'),
|
||||||
|
'source_type': point.get('source_type'),
|
||||||
|
'lon': point.get('lon'),
|
||||||
|
'lat': point.get('lat'),
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
def get_model_info(self) -> Dict[str, Any]:
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取模型信息
|
获取模型信息
|
||||||
|
|||||||
@@ -22,26 +22,35 @@ class DbnRepository:
|
|||||||
获取所有隐患点和风险点(从 xian_risk_factors 表)
|
获取所有隐患点和风险点(从 xian_risk_factors 表)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
region_code: 行政区划代码(区县名称),可选
|
region_code: 行政区划代码(如 '610104'),可选,匹配隐患点.county_code 和风险点.unit_code
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
点列表,每个元素包含:id, source_id, source_type, lon, lat, static_factors
|
点列表,每个元素包含:id, source_id, source_type, lon, lat, static_factors
|
||||||
"""
|
"""
|
||||||
sql = """
|
|
||||||
SELECT
|
|
||||||
id,
|
|
||||||
source_id,
|
|
||||||
source_type,
|
|
||||||
lon,
|
|
||||||
lat,
|
|
||||||
static_factors
|
|
||||||
FROM xian_risk_factors
|
|
||||||
WHERE is_delete = 0
|
|
||||||
"""
|
|
||||||
params = (region_code,) if region_code else None
|
|
||||||
|
|
||||||
if region_code:
|
if region_code:
|
||||||
sql += " AND county = %s"
|
# 通过源表的行政区划代码筛选
|
||||||
|
sql = """
|
||||||
|
SELECT rf.id, rf.source_id, rf.source_type, rf.lon, rf.lat, rf.static_factors
|
||||||
|
FROM xian_risk_factors rf
|
||||||
|
WHERE rf.is_delete = 0
|
||||||
|
AND (
|
||||||
|
(rf.source_type = 1 AND rf.source_id IN (
|
||||||
|
SELECT id FROM xian_hidden_danger_spots WHERE county_id = %s AND is_delete = 0
|
||||||
|
))
|
||||||
|
OR
|
||||||
|
(rf.source_type = 2 AND rf.source_id IN (
|
||||||
|
SELECT id FROM xian_risk_spots WHERE unit_code = %s AND is_delete = 0
|
||||||
|
))
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
params = (region_code, region_code)
|
||||||
|
else:
|
||||||
|
sql = """
|
||||||
|
SELECT id, source_id, source_type, lon, lat, static_factors
|
||||||
|
FROM xian_risk_factors
|
||||||
|
WHERE is_delete = 0
|
||||||
|
"""
|
||||||
|
params = None
|
||||||
|
|
||||||
results = db_helper.execute_query(sql, params)
|
results = db_helper.execute_query(sql, params)
|
||||||
|
|
||||||
@@ -94,6 +103,43 @@ class DbnRepository:
|
|||||||
'static_factors': result.get('static_factors') or {}
|
'static_factors': result.get('static_factors') or {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_points_by_ids(point_ids: List[int]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
批量获取点信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
point_ids: 点ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
点信息列表
|
||||||
|
"""
|
||||||
|
if not point_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
placeholders = ','.join(['%s'] * len(point_ids))
|
||||||
|
sql = f"""
|
||||||
|
SELECT
|
||||||
|
rf.id,
|
||||||
|
rf.source_id,
|
||||||
|
rf.source_type,
|
||||||
|
rf.lon,
|
||||||
|
rf.lat,
|
||||||
|
rf.static_factors
|
||||||
|
FROM xian_risk_factors rf
|
||||||
|
WHERE rf.id IN ({placeholders}) AND rf.is_delete = 0
|
||||||
|
"""
|
||||||
|
results = db_helper.execute_query(sql, tuple(point_ids))
|
||||||
|
|
||||||
|
return [{
|
||||||
|
'id': row['id'],
|
||||||
|
'source_id': row['source_id'],
|
||||||
|
'source_type': row['source_type'],
|
||||||
|
'lon': float(row['lon']) if row['lon'] else None,
|
||||||
|
'lat': float(row['lat']) if row['lat'] else None,
|
||||||
|
'static_factors': row.get('static_factors') or {}
|
||||||
|
} for row in results]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_static_factors(point_id: int) -> Dict[str, Any]:
|
def get_static_factors(point_id: int) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
API 请求/响应数据模型
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 暴雨预测
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class RainfallPredictRequest(BaseModel):
|
||||||
|
"""暴雨灾害链预测请求"""
|
||||||
|
point_ids: Optional[List[int]] = Field(None, max_length=500,
|
||||||
|
description="点位ID列表,不传则查询所有点")
|
||||||
|
region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域")
|
||||||
|
rainfall: float = Field(..., ge=0, description="累计降雨量(mm)")
|
||||||
|
duration: float = Field(..., ge=0, description="降雨持续时间(h)")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 地震预测
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class EarthquakePredictRequest(BaseModel):
|
||||||
|
"""地震灾害链预测请求"""
|
||||||
|
point_ids: Optional[List[int]] = Field(None, max_length=500,
|
||||||
|
description="点位ID列表,不传则查询所有点")
|
||||||
|
region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域")
|
||||||
|
magnitude: float = Field(..., ge=0, le=10, description="震级(Richter)")
|
||||||
|
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
|
||||||
|
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
|
||||||
|
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 通用响应
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class PredictionItem(BaseModel):
|
||||||
|
"""单个点位预测结果"""
|
||||||
|
id: int = Field(..., description="点位ID")
|
||||||
|
type: str = Field(..., description="类型: 隐患点 / 风险点")
|
||||||
|
probability: float = Field(..., description="最大灾害概率")
|
||||||
|
level: str = Field(..., description="灾害等级: 低/中/较高/高")
|
||||||
|
|
||||||
|
|
||||||
|
class PredictResponse(BaseModel):
|
||||||
|
"""预测响应"""
|
||||||
|
code: int = Field(200, description="状态码")
|
||||||
|
message: str = Field("success", description="提示信息")
|
||||||
|
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表")
|
||||||
|
|
||||||
|
|
||||||
|
class HealthResponse(BaseModel):
|
||||||
|
"""健康检查响应"""
|
||||||
|
status: str = "ok"
|
||||||
|
rainfall_model_loaded: bool = False
|
||||||
|
earthquake_model_loaded: bool = False
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
API 依赖注入
|
||||||
|
模型懒加载 + 并发控制
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.config.paths import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("api")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 并发控制:限制同时进行的预测任务数,防止资源耗尽
|
||||||
|
# ============================================================
|
||||||
|
MAX_CONCURRENT_PREDICTIONS = 8
|
||||||
|
_prediction_semaphore: Optional[asyncio.Semaphore] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_prediction_semaphore() -> asyncio.Semaphore:
|
||||||
|
"""获取预测信号量(惰性初始化,兼容事件循环)"""
|
||||||
|
global _prediction_semaphore
|
||||||
|
if _prediction_semaphore is None:
|
||||||
|
_prediction_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PREDICTIONS)
|
||||||
|
return _prediction_semaphore
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 模型单例(启动时加载一次)
|
||||||
|
# ============================================================
|
||||||
|
_rainfall_model = None
|
||||||
|
_earthquake_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_rainfall_model():
|
||||||
|
"""获取暴雨DBN模型单例"""
|
||||||
|
global _rainfall_model
|
||||||
|
if _rainfall_model is None:
|
||||||
|
from app.models.dbn.rainfall.rainfall_dbn import RainfallDBN
|
||||||
|
_rainfall_model = RainfallDBN()
|
||||||
|
logger.info("暴雨DBN模型加载完成")
|
||||||
|
return _rainfall_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_earthquake_model():
|
||||||
|
"""获取地震DBN模型单例"""
|
||||||
|
global _earthquake_model
|
||||||
|
if _earthquake_model is None:
|
||||||
|
from app.models.dbn.earthquake.earthquake_dbn import EarthquakeDBN
|
||||||
|
_earthquake_model = EarthquakeDBN()
|
||||||
|
logger.info("地震DBN模型加载完成")
|
||||||
|
return _earthquake_model
|
||||||
|
|
||||||
|
|
||||||
|
def is_model_loaded() -> dict:
|
||||||
|
"""检查模型加载状态"""
|
||||||
|
return {
|
||||||
|
"rainfall_model_loaded": _rainfall_model is not None,
|
||||||
|
"earthquake_model_loaded": _earthquake_model is not None,
|
||||||
|
}
|
||||||
@@ -6,3 +6,5 @@ scipy == 1.17.1
|
|||||||
matplotlib == 3.10.0
|
matplotlib == 3.10.0
|
||||||
Pillow == 12.2.0
|
Pillow == 12.2.0
|
||||||
pyyaml == 6.0.2
|
pyyaml == 6.0.2
|
||||||
|
fastapi == 0.136.3
|
||||||
|
uvicorn[standard] == 0.49.0
|
||||||
Reference in New Issue
Block a user