暴雨地震灾害链HTTP请求
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
API 路由模块
|
||||
"""
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def register_routers(application: FastAPI):
|
||||
"""注册所有路由"""
|
||||
from app.api.rainfall import router as rainfall_router
|
||||
from app.api.earthquake import router as earthquake_router
|
||||
|
||||
application.include_router(rainfall_router)
|
||||
application.include_router(earthquake_router)
|
||||
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
地震灾害链预测接口
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.api_schemas import EarthquakePredictRequest, PredictResponse, PredictionItem
|
||||
from app.utils.api_deps import get_earthquake_model, get_prediction_semaphore
|
||||
from app.repositories.dbn_repository import dbn_repository
|
||||
from app.config.paths import get_logger
|
||||
|
||||
router = APIRouter(prefix="/earthquake", tags=["地震灾害链"])
|
||||
logger = get_logger("api.earthquake")
|
||||
|
||||
SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
|
||||
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
||||
|
||||
|
||||
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
|
||||
"""将模型原始结果转换为接口返回格式"""
|
||||
items = []
|
||||
for r in results:
|
||||
probs = r.get("disaster_probabilities", {})
|
||||
levels = r.get("disaster_levels", {})
|
||||
|
||||
if not probs:
|
||||
continue
|
||||
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
items.append(PredictionItem(
|
||||
id=r["point_id"],
|
||||
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
|
||||
probability=round(probs[max_hazard], 4),
|
||||
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"),
|
||||
))
|
||||
return items
|
||||
|
||||
|
||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||
"""获取点位列表"""
|
||||
if point_ids:
|
||||
return dbn_repository.get_points_by_ids(point_ids)
|
||||
return dbn_repository.get_all_points(region_code)
|
||||
|
||||
|
||||
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
magnitude: float, depth: float,
|
||||
epicenter_lon: float, epicenter_lat: float) -> List[PredictionItem]:
|
||||
"""同步执行地震预测(在线程池中运行)"""
|
||||
points = _fetch_points(point_ids, region_code)
|
||||
if not points:
|
||||
return []
|
||||
|
||||
model = get_earthquake_model()
|
||||
results = model.predict_multiple_points(
|
||||
points,
|
||||
magnitude=magnitude,
|
||||
depth=depth,
|
||||
epicenter_lon=epicenter_lon,
|
||||
epicenter_lat=epicenter_lat,
|
||||
)
|
||||
return _build_prediction_items(results)
|
||||
|
||||
|
||||
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
||||
async def predict_earthquake(req: EarthquakePredictRequest):
|
||||
"""
|
||||
根据震级、震源深度和震中位置,批量预测隐患点/风险点的次生灾害概率和等级。
|
||||
|
||||
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
||||
- **region_code**: 行政区划代码(可选,不传则不限区域)
|
||||
- **magnitude**: 震级(Richter)
|
||||
- **depth**: 震源深度(km),默认10km
|
||||
- **epicenter_lon**: 震中经度
|
||||
- **epicenter_lat**: 震中纬度
|
||||
"""
|
||||
semaphore = get_prediction_semaphore()
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
items = await loop.run_in_executor(
|
||||
None, _predict_sync, req.point_ids, req.region_code,
|
||||
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"地震预测失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
||||
|
||||
return PredictResponse(code=200, message="success", data=items)
|
||||
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
暴雨灾害链预测接口
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.api_schemas import RainfallPredictRequest, PredictResponse, PredictionItem
|
||||
from app.utils.api_deps import get_rainfall_model, get_prediction_semaphore
|
||||
from app.repositories.dbn_repository import dbn_repository
|
||||
from app.config.paths import get_logger
|
||||
|
||||
router = APIRouter(prefix="/rainfall", tags=["暴雨灾害链"])
|
||||
logger = get_logger("api.rainfall")
|
||||
|
||||
SOURCE_TYPE_MAP = {1: "隐患点", 2: "风险点"}
|
||||
LEVEL_MAP = {"低": "低", "中": "中", "较高": "较高", "高": "高"}
|
||||
|
||||
|
||||
def _build_prediction_items(results: List[Dict[str, Any]]) -> List[PredictionItem]:
|
||||
"""将模型原始结果转换为接口返回格式"""
|
||||
items = []
|
||||
for r in results:
|
||||
probs = r.get("disaster_probabilities", {})
|
||||
levels = r.get("disaster_levels", {})
|
||||
|
||||
if not probs:
|
||||
continue
|
||||
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
items.append(PredictionItem(
|
||||
id=r["point_id"],
|
||||
type=SOURCE_TYPE_MAP.get(r.get("source_type"), "未知"),
|
||||
probability=round(probs[max_hazard], 4),
|
||||
level=LEVEL_MAP.get(levels.get(max_hazard, "none"), "无"),
|
||||
))
|
||||
return items
|
||||
|
||||
|
||||
def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) -> List[Dict[str, Any]]:
|
||||
"""获取点位列表"""
|
||||
if point_ids:
|
||||
return dbn_repository.get_points_by_ids(point_ids)
|
||||
return dbn_repository.get_all_points(region_code)
|
||||
|
||||
|
||||
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||
rainfall: float, duration: float) -> List[PredictionItem]:
|
||||
"""同步执行暴雨预测(在线程池中运行)"""
|
||||
points = _fetch_points(point_ids, region_code)
|
||||
if not points:
|
||||
return []
|
||||
|
||||
model = get_rainfall_model()
|
||||
results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
||||
return _build_prediction_items(results)
|
||||
|
||||
|
||||
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
|
||||
async def predict_rainfall(req: RainfallPredictRequest):
|
||||
"""
|
||||
根据降雨量和持续时间,批量预测隐患点/风险点的灾害概率和等级。
|
||||
|
||||
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
||||
- **region_code**: 行政区划代码(可选,不传则不限区域)
|
||||
- **rainfall**: 累计降雨量(mm)
|
||||
- **duration**: 降雨持续时间(h)
|
||||
"""
|
||||
semaphore = get_prediction_semaphore()
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
items = await loop.run_in_executor(
|
||||
None, _predict_sync, req.point_ids, req.region_code,
|
||||
req.rainfall, req.duration
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"暴雨预测失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
||||
|
||||
return PredictResponse(code=200, message="success", data=items)
|
||||
@@ -51,12 +51,30 @@ class AppLauncher:
|
||||
|
||||
def start():
|
||||
"""启动应用服务"""
|
||||
import threading
|
||||
from config import settings
|
||||
from app.core.rainfall_manager import rainfall_manager
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.thread_pool_manager import block_main_thread, thread_pool_manager
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# 启动 FastAPI 服务(守护线程)
|
||||
def run_api_server():
|
||||
import uvicorn
|
||||
from app.core.server import create_app
|
||||
api_app = create_app()
|
||||
uvicorn.run(
|
||||
api_app,
|
||||
host=getattr(settings, "API_HOST", "127.0.0.1"),
|
||||
port=int(getattr(settings, "API_PORT", 8082)),
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
api_thread = threading.Thread(target=run_api_server, daemon=True, name="api-server")
|
||||
api_thread.start()
|
||||
logger.info(f"FastAPI 服务已启动: http://{getattr(settings, 'API_HOST', '127.0.0.1')}:{getattr(settings, 'API_PORT', 8082)}")
|
||||
|
||||
# 启动降雨站点监测
|
||||
logger.info("启动降雨站点监测服务...")
|
||||
rainfall_manager.monitoring_rainfall_station_id('2025-09-16 20:00:00')
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
FastAPI 服务创建与配置
|
||||
"""
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from app.utils.api_deps import get_rainfall_model, get_earthquake_model, is_model_loaded
|
||||
from app.schemas.api_schemas import HealthResponse
|
||||
from app.config.paths import get_logger
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期:启动时预加载模型"""
|
||||
logger.info("正在预加载DBN模型...")
|
||||
get_rainfall_model()
|
||||
get_earthquake_model()
|
||||
logger.info("DBN模型预加载完成")
|
||||
yield
|
||||
logger.info("应用关闭")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""创建 FastAPI 应用实例"""
|
||||
application = FastAPI(
|
||||
title="西安灾害链预测服务",
|
||||
description="基于动态贝叶斯网络的暴雨/地震灾害链预测API",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@application.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
"""请求日志中间件"""
|
||||
start = time.time()
|
||||
response = await call_next(request)
|
||||
elapsed = time.time() - start
|
||||
logger.info(f"{request.method} {request.url.path} -> {response.status_code} ({elapsed:.3f}s)")
|
||||
return response
|
||||
|
||||
# 注册路由
|
||||
from app.api import register_routers
|
||||
register_routers(application)
|
||||
|
||||
@application.get("/health", response_model=HealthResponse, tags=["系统"])
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
status = is_model_loaded()
|
||||
return HealthResponse(status="ok", **status)
|
||||
|
||||
return application
|
||||
@@ -18,21 +18,20 @@ logger = get_logger("earthquake_dbn")
|
||||
class EarthquakeDBN:
|
||||
"""地震灾害链DBN模型"""
|
||||
|
||||
# 灾害概率→离散等级的阈值映射
|
||||
# 灾害概率→等级的阈值映射
|
||||
HAZARD_LEVEL_THRESHOLDS = [
|
||||
(0.6, 'very_high'),
|
||||
(0.4, 'high'),
|
||||
(0.2, 'medium'),
|
||||
(0.05, 'low'),
|
||||
(0.0, 'none'),
|
||||
(0.7, '高'),
|
||||
(0.5, '较高'),
|
||||
(0.3, '中'),
|
||||
(0.0, '低'),
|
||||
]
|
||||
|
||||
def _probability_to_level(self, prob: float) -> str:
|
||||
"""将连续概率映射到离散等级"""
|
||||
"""将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)"""
|
||||
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
|
||||
if prob >= threshold:
|
||||
return level
|
||||
return 'none'
|
||||
return '低'
|
||||
|
||||
def __init__(self, config_dir: Optional[str] = None):
|
||||
"""
|
||||
@@ -114,18 +113,20 @@ class EarthquakeDBN:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float) -> float:
|
||||
def estimate_seismic_intensity(magnitude: float, epicenter_distance_km: float,
|
||||
depth_km: float = 10.0) -> float:
|
||||
"""
|
||||
根据震级和震中距估算地震烈度
|
||||
使用中国地震烈度衰减关系
|
||||
根据震级、震中距和震源深度估算地震烈度
|
||||
|
||||
I = 0.923 + 1.621*M - 3.494*ln(R+10)
|
||||
I = 0.923 + 1.621*M - 3.494*ln(R+10) - ln(H/10)
|
||||
|
||||
参考:GB 18306-2015 中国地震动参数区划图
|
||||
深度修正:震源越深,地表烈度越低
|
||||
|
||||
Args:
|
||||
magnitude: 震级(Richter)
|
||||
epicenter_distance_km: 震中距(km)
|
||||
depth_km: 震源深度(km),默认10km
|
||||
|
||||
Returns:
|
||||
估算的地震烈度(中国烈度表数值)
|
||||
@@ -135,6 +136,10 @@ class EarthquakeDBN:
|
||||
|
||||
intensity = 0.923 + 1.621 * magnitude - 3.494 * math.log(epicenter_distance_km + 10)
|
||||
|
||||
# 震源深度修正:以10km为基准,深度越大烈度衰减越多
|
||||
depth_km = max(depth_km, 1.0)
|
||||
intensity -= math.log(depth_km / 10.0)
|
||||
|
||||
# 限制在合理范围内
|
||||
return max(1.0, min(12.0, intensity))
|
||||
|
||||
@@ -191,7 +196,8 @@ class EarthquakeDBN:
|
||||
epicenter_distance: Optional[float] = None,
|
||||
seismic_intensity: Optional[float] = None,
|
||||
epicenter_lon: Optional[float] = None,
|
||||
epicenter_lat: Optional[float] = None) -> Dict[str, Any]:
|
||||
epicenter_lat: Optional[float] = None,
|
||||
depth: float = 10.0) -> Dict[str, Any]:
|
||||
"""
|
||||
对单个点进行地震灾害预测
|
||||
|
||||
@@ -202,6 +208,7 @@ class EarthquakeDBN:
|
||||
seismic_intensity: 地震烈度(中国烈度表),若未提供则自动估算
|
||||
epicenter_lon: 震中经度(可选,用于计算震中距)
|
||||
epicenter_lat: 震中纬度(可选,用于计算震中距)
|
||||
depth: 震源深度(km),默认10.0
|
||||
|
||||
Returns:
|
||||
预测结果
|
||||
@@ -226,7 +233,7 @@ class EarthquakeDBN:
|
||||
|
||||
# 估算地震烈度(如果未直接提供)
|
||||
if seismic_intensity is None:
|
||||
seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance)
|
||||
seismic_intensity = self.estimate_seismic_intensity(magnitude, epicenter_distance, depth)
|
||||
logger.info(f"估算地震烈度: {seismic_intensity:.1f}")
|
||||
|
||||
# 获取静态因子数据
|
||||
@@ -379,6 +386,52 @@ class EarthquakeDBN:
|
||||
|
||||
return results
|
||||
|
||||
def predict_multiple_points(self, points: List[Dict[str, Any]],
|
||||
magnitude: float = 6.0,
|
||||
epicenter_distance: Optional[float] = None,
|
||||
seismic_intensity: Optional[float] = None,
|
||||
epicenter_lon: Optional[float] = None,
|
||||
epicenter_lat: Optional[float] = None,
|
||||
depth: float = 10.0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
对已获取的点列表进行地震灾害预测
|
||||
|
||||
Args:
|
||||
points: 点信息列表(已从数据库获取)
|
||||
magnitude: 地震震级
|
||||
epicenter_distance: 震中距(km,可选)
|
||||
seismic_intensity: 地震烈度(可选)
|
||||
epicenter_lon: 震中经度(可选)
|
||||
epicenter_lat: 震中纬度(可选)
|
||||
depth: 震源深度(km),默认10.0
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
results = []
|
||||
for point in points:
|
||||
try:
|
||||
result = self.predict_single_point(
|
||||
point,
|
||||
magnitude=magnitude,
|
||||
epicenter_distance=epicenter_distance,
|
||||
seismic_intensity=seismic_intensity,
|
||||
epicenter_lon=epicenter_lon,
|
||||
epicenter_lat=epicenter_lat,
|
||||
depth=depth
|
||||
)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"预测点 {point.get('id')} 失败: {e}")
|
||||
results.append({
|
||||
'point_id': point.get('id'),
|
||||
'source_type': point.get('source_type'),
|
||||
'lon': point.get('lon'),
|
||||
'lat': point.get('lat'),
|
||||
'error': str(e)
|
||||
})
|
||||
return results
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""获取模型信息"""
|
||||
return {
|
||||
|
||||
@@ -17,21 +17,20 @@ logger = get_logger("dbn")
|
||||
class RainfallDBN:
|
||||
"""暴雨灾害链DBN模型"""
|
||||
|
||||
# 灾害概率→离散等级的阈值映射
|
||||
# 灾害概率→等级的阈值映射
|
||||
HAZARD_LEVEL_THRESHOLDS = [
|
||||
(0.6, 'very_high'),
|
||||
(0.4, 'high'),
|
||||
(0.2, 'medium'),
|
||||
(0.05, 'low'),
|
||||
(0.0, 'none'),
|
||||
(0.7, '高'),
|
||||
(0.5, '较高'),
|
||||
(0.3, '中'),
|
||||
(0.0, '低'),
|
||||
]
|
||||
|
||||
def _probability_to_level(self, prob: float) -> str:
|
||||
"""将连续概率映射到离散等级"""
|
||||
"""将连续概率映射到风险等级:低(<30%) / 中(30-50%) / 较高(50-70%) / 高(70%+)"""
|
||||
for threshold, level in self.HAZARD_LEVEL_THRESHOLDS:
|
||||
if prob >= threshold:
|
||||
return level
|
||||
return 'none'
|
||||
return '低'
|
||||
|
||||
def __init__(self, config_dir: Optional[str] = None):
|
||||
"""
|
||||
@@ -374,6 +373,40 @@ class RainfallDBN:
|
||||
|
||||
return results
|
||||
|
||||
def predict_multiple_points(self, points: List[Dict[str, Any]],
|
||||
rainfall: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
query_time: Optional[datetime] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
对已获取的点列表进行暴雨灾害预测
|
||||
|
||||
Args:
|
||||
points: 点信息列表(已从数据库获取)
|
||||
rainfall: 累计降雨量(可选)
|
||||
duration: 持续时间(可选)
|
||||
query_time: 查询时间(可选)
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
results = []
|
||||
for point in points:
|
||||
try:
|
||||
result = self.predict_single_point(
|
||||
point, rainfall=rainfall, duration=duration, query_time=query_time
|
||||
)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"预测点 {point.get('id')} 失败: {e}")
|
||||
results.append({
|
||||
'point_id': point.get('id'),
|
||||
'source_type': point.get('source_type'),
|
||||
'lon': point.get('lon'),
|
||||
'lat': point.get('lat'),
|
||||
'error': str(e)
|
||||
})
|
||||
return results
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取模型信息
|
||||
|
||||
@@ -22,26 +22,35 @@ class DbnRepository:
|
||||
获取所有隐患点和风险点(从 xian_risk_factors 表)
|
||||
|
||||
Args:
|
||||
region_code: 行政区划代码(区县名称),可选
|
||||
region_code: 行政区划代码(如 '610104'),可选,匹配隐患点.county_code 和风险点.unit_code
|
||||
|
||||
Returns:
|
||||
点列表,每个元素包含:id, source_id, source_type, lon, lat, static_factors
|
||||
"""
|
||||
sql = """
|
||||
SELECT
|
||||
id,
|
||||
source_id,
|
||||
source_type,
|
||||
lon,
|
||||
lat,
|
||||
static_factors
|
||||
FROM xian_risk_factors
|
||||
WHERE is_delete = 0
|
||||
"""
|
||||
params = (region_code,) if region_code else None
|
||||
|
||||
if region_code:
|
||||
sql += " AND county = %s"
|
||||
# 通过源表的行政区划代码筛选
|
||||
sql = """
|
||||
SELECT rf.id, rf.source_id, rf.source_type, rf.lon, rf.lat, rf.static_factors
|
||||
FROM xian_risk_factors rf
|
||||
WHERE rf.is_delete = 0
|
||||
AND (
|
||||
(rf.source_type = 1 AND rf.source_id IN (
|
||||
SELECT id FROM xian_hidden_danger_spots WHERE county_id = %s AND is_delete = 0
|
||||
))
|
||||
OR
|
||||
(rf.source_type = 2 AND rf.source_id IN (
|
||||
SELECT id FROM xian_risk_spots WHERE unit_code = %s AND is_delete = 0
|
||||
))
|
||||
)
|
||||
"""
|
||||
params = (region_code, region_code)
|
||||
else:
|
||||
sql = """
|
||||
SELECT id, source_id, source_type, lon, lat, static_factors
|
||||
FROM xian_risk_factors
|
||||
WHERE is_delete = 0
|
||||
"""
|
||||
params = None
|
||||
|
||||
results = db_helper.execute_query(sql, params)
|
||||
|
||||
@@ -94,6 +103,43 @@ class DbnRepository:
|
||||
'static_factors': result.get('static_factors') or {}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_points_by_ids(point_ids: List[int]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量获取点信息
|
||||
|
||||
Args:
|
||||
point_ids: 点ID列表
|
||||
|
||||
Returns:
|
||||
点信息列表
|
||||
"""
|
||||
if not point_ids:
|
||||
return []
|
||||
|
||||
placeholders = ','.join(['%s'] * len(point_ids))
|
||||
sql = f"""
|
||||
SELECT
|
||||
rf.id,
|
||||
rf.source_id,
|
||||
rf.source_type,
|
||||
rf.lon,
|
||||
rf.lat,
|
||||
rf.static_factors
|
||||
FROM xian_risk_factors rf
|
||||
WHERE rf.id IN ({placeholders}) AND rf.is_delete = 0
|
||||
"""
|
||||
results = db_helper.execute_query(sql, tuple(point_ids))
|
||||
|
||||
return [{
|
||||
'id': row['id'],
|
||||
'source_id': row['source_id'],
|
||||
'source_type': row['source_type'],
|
||||
'lon': float(row['lon']) if row['lon'] else None,
|
||||
'lat': float(row['lat']) if row['lat'] else None,
|
||||
'static_factors': row.get('static_factors') or {}
|
||||
} for row in results]
|
||||
|
||||
@staticmethod
|
||||
def get_static_factors(point_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
API 请求/响应数据模型
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 暴雨预测
|
||||
# ============================================================
|
||||
|
||||
class RainfallPredictRequest(BaseModel):
|
||||
"""暴雨灾害链预测请求"""
|
||||
point_ids: Optional[List[int]] = Field(None, max_length=500,
|
||||
description="点位ID列表,不传则查询所有点")
|
||||
region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域")
|
||||
rainfall: float = Field(..., ge=0, description="累计降雨量(mm)")
|
||||
duration: float = Field(..., ge=0, description="降雨持续时间(h)")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 地震预测
|
||||
# ============================================================
|
||||
|
||||
class EarthquakePredictRequest(BaseModel):
|
||||
"""地震灾害链预测请求"""
|
||||
point_ids: Optional[List[int]] = Field(None, max_length=500,
|
||||
description="点位ID列表,不传则查询所有点")
|
||||
region_code: Optional[str] = Field(None, description="行政区划代码(如 '610104'),不传则不限区域")
|
||||
magnitude: float = Field(..., ge=0, le=10, description="震级(Richter)")
|
||||
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
|
||||
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
|
||||
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 通用响应
|
||||
# ============================================================
|
||||
|
||||
class PredictionItem(BaseModel):
|
||||
"""单个点位预测结果"""
|
||||
id: int = Field(..., description="点位ID")
|
||||
type: str = Field(..., description="类型: 隐患点 / 风险点")
|
||||
probability: float = Field(..., description="最大灾害概率")
|
||||
level: str = Field(..., description="灾害等级: 低/中/较高/高")
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
"""预测响应"""
|
||||
code: int = Field(200, description="状态码")
|
||||
message: str = Field("success", description="提示信息")
|
||||
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""健康检查响应"""
|
||||
status: str = "ok"
|
||||
rainfall_model_loaded: bool = False
|
||||
earthquake_model_loaded: bool = False
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
API 依赖注入
|
||||
模型懒加载 + 并发控制
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from app.config.paths import get_logger
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
# ============================================================
|
||||
# 并发控制:限制同时进行的预测任务数,防止资源耗尽
|
||||
# ============================================================
|
||||
MAX_CONCURRENT_PREDICTIONS = 8
|
||||
_prediction_semaphore: Optional[asyncio.Semaphore] = None
|
||||
|
||||
|
||||
def get_prediction_semaphore() -> asyncio.Semaphore:
|
||||
"""获取预测信号量(惰性初始化,兼容事件循环)"""
|
||||
global _prediction_semaphore
|
||||
if _prediction_semaphore is None:
|
||||
_prediction_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PREDICTIONS)
|
||||
return _prediction_semaphore
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 模型单例(启动时加载一次)
|
||||
# ============================================================
|
||||
_rainfall_model = None
|
||||
_earthquake_model = None
|
||||
|
||||
|
||||
def get_rainfall_model():
|
||||
"""获取暴雨DBN模型单例"""
|
||||
global _rainfall_model
|
||||
if _rainfall_model is None:
|
||||
from app.models.dbn.rainfall.rainfall_dbn import RainfallDBN
|
||||
_rainfall_model = RainfallDBN()
|
||||
logger.info("暴雨DBN模型加载完成")
|
||||
return _rainfall_model
|
||||
|
||||
|
||||
def get_earthquake_model():
|
||||
"""获取地震DBN模型单例"""
|
||||
global _earthquake_model
|
||||
if _earthquake_model is None:
|
||||
from app.models.dbn.earthquake.earthquake_dbn import EarthquakeDBN
|
||||
_earthquake_model = EarthquakeDBN()
|
||||
logger.info("地震DBN模型加载完成")
|
||||
return _earthquake_model
|
||||
|
||||
|
||||
def is_model_loaded() -> dict:
|
||||
"""检查模型加载状态"""
|
||||
return {
|
||||
"rainfall_model_loaded": _rainfall_model is not None,
|
||||
"earthquake_model_loaded": _earthquake_model is not None,
|
||||
}
|
||||
@@ -6,3 +6,5 @@ scipy == 1.17.1
|
||||
matplotlib == 3.10.0
|
||||
Pillow == 12.2.0
|
||||
pyyaml == 6.0.2
|
||||
fastapi == 0.136.3
|
||||
uvicorn[standard] == 0.49.0
|
||||
Reference in New Issue
Block a user