添加存库功能
This commit is contained in:
+51
-7
@@ -47,21 +47,40 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) ->
|
|||||||
|
|
||||||
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||||
magnitude: float, depth: float,
|
magnitude: float, depth: float,
|
||||||
epicenter_lon: float, epicenter_lat: float) -> List[PredictionItem]:
|
epicenter_lon: float, epicenter_lat: float) -> tuple:
|
||||||
"""同步执行地震预测(在线程池中运行)"""
|
"""
|
||||||
|
同步执行地震预测(在线程池中运行)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(预测结果列表, 原始结果)
|
||||||
|
"""
|
||||||
points = _fetch_points(point_ids, region_code)
|
points = _fetch_points(point_ids, region_code)
|
||||||
if not points:
|
if not points:
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
model = get_earthquake_model()
|
model = get_earthquake_model()
|
||||||
results = model.predict_multiple_points(
|
raw_results = model.predict_multiple_points(
|
||||||
points,
|
points,
|
||||||
magnitude=magnitude,
|
magnitude=magnitude,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
epicenter_lon=epicenter_lon,
|
epicenter_lon=epicenter_lon,
|
||||||
epicenter_lat=epicenter_lat,
|
epicenter_lat=epicenter_lat,
|
||||||
)
|
)
|
||||||
return _build_prediction_items(results)
|
items = _build_prediction_items(raw_results)
|
||||||
|
|
||||||
|
save_results = [
|
||||||
|
{
|
||||||
|
"point_id": r.get("point_id"),
|
||||||
|
"source_type": r.get("source_type"),
|
||||||
|
"lon": r.get("lon"),
|
||||||
|
"lat": r.get("lat"),
|
||||||
|
"disaster_probabilities": r.get("disaster_probabilities", {}),
|
||||||
|
"disaster_levels": r.get("disaster_levels", {})
|
||||||
|
}
|
||||||
|
for r in raw_results
|
||||||
|
]
|
||||||
|
|
||||||
|
return items, save_results
|
||||||
|
|
||||||
|
|
||||||
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
|
||||||
@@ -75,13 +94,15 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
- **depth**: 震源深度(km),默认10km
|
- **depth**: 震源深度(km),默认10km
|
||||||
- **epicenter_lon**: 震中经度
|
- **epicenter_lon**: 震中经度
|
||||||
- **epicenter_lat**: 震中纬度
|
- **epicenter_lat**: 震中纬度
|
||||||
|
- **occurred_time**: 地震发生时间
|
||||||
|
- **operation_type**: 操作类型(如 '实时监测', '情景模拟', '应急评估')
|
||||||
"""
|
"""
|
||||||
semaphore = get_prediction_semaphore()
|
semaphore = get_prediction_semaphore()
|
||||||
|
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
items = await loop.run_in_executor(
|
items, save_results = await loop.run_in_executor(
|
||||||
None, _predict_sync, req.point_ids, req.region_code,
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
|
||||||
)
|
)
|
||||||
@@ -89,4 +110,27 @@ async def predict_earthquake(req: EarthquakePredictRequest):
|
|||||||
logger.error(f"地震预测失败: {e}", exc_info=True)
|
logger.error(f"地震预测失败: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
||||||
|
|
||||||
return PredictResponse(code=200, message="success", data=items)
|
# 保存推理结果
|
||||||
|
record_id = None
|
||||||
|
if save_results:
|
||||||
|
try:
|
||||||
|
condition = {
|
||||||
|
"point_ids": req.point_ids,
|
||||||
|
"region_code": req.region_code,
|
||||||
|
"magnitude": req.magnitude,
|
||||||
|
"depth": req.depth,
|
||||||
|
"epicenter_lon": req.epicenter_lon,
|
||||||
|
"epicenter_lat": req.epicenter_lat
|
||||||
|
}
|
||||||
|
record_id = dbn_repository.save_inference_result(
|
||||||
|
event_type="earthquake",
|
||||||
|
occurred_time=req.occurred_time,
|
||||||
|
operation_type=req.operation_type,
|
||||||
|
condition=condition,
|
||||||
|
result=save_results
|
||||||
|
)
|
||||||
|
logger.info(f"推理结果已保存,record_id={record_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
return PredictResponse(code=200, message="success", data=items, record_id=record_id)
|
||||||
|
|||||||
+55
-10
@@ -2,6 +2,7 @@
|
|||||||
暴雨灾害链预测接口
|
暴雨灾害链预测接口
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
@@ -46,15 +47,43 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) ->
|
|||||||
|
|
||||||
|
|
||||||
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
|
||||||
rainfall: float, duration: float) -> List[PredictionItem]:
|
rainfall: Optional[float], duration: Optional[float],
|
||||||
"""同步执行暴雨预测(在线程池中运行)"""
|
operation_type: str) -> tuple:
|
||||||
|
"""
|
||||||
|
同步执行暴雨预测(在线程池中运行)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(预测结果列表, 原始结果, 输入条件, 当前时间)
|
||||||
|
"""
|
||||||
points = _fetch_points(point_ids, region_code)
|
points = _fetch_points(point_ids, region_code)
|
||||||
if not points:
|
if not points:
|
||||||
return []
|
return [], [], {}, datetime.now()
|
||||||
|
|
||||||
model = get_rainfall_model()
|
model = get_rainfall_model()
|
||||||
results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
|
||||||
return _build_prediction_items(results)
|
items = _build_prediction_items(raw_results)
|
||||||
|
|
||||||
|
# 构建条件和结果用于保存
|
||||||
|
now = datetime.now()
|
||||||
|
condition = {
|
||||||
|
"point_ids": point_ids,
|
||||||
|
"region_code": region_code,
|
||||||
|
"rainfall": rainfall,
|
||||||
|
"duration": duration
|
||||||
|
}
|
||||||
|
save_results = [
|
||||||
|
{
|
||||||
|
"point_id": r.get("point_id"),
|
||||||
|
"source_type": r.get("source_type"),
|
||||||
|
"lon": r.get("lon"),
|
||||||
|
"lat": r.get("lat"),
|
||||||
|
"disaster_probabilities": r.get("disaster_probabilities", {}),
|
||||||
|
"disaster_levels": r.get("disaster_levels", {})
|
||||||
|
}
|
||||||
|
for r in raw_results
|
||||||
|
]
|
||||||
|
|
||||||
|
return items, save_results, condition, now
|
||||||
|
|
||||||
|
|
||||||
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
|
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
|
||||||
@@ -64,20 +93,36 @@ async def predict_rainfall(req: RainfallPredictRequest):
|
|||||||
|
|
||||||
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
- **point_ids**: 点位ID列表(可选,不传则查询所有点)
|
||||||
- **region_code**: 行政区划代码(可选,不传则不限区域)
|
- **region_code**: 行政区划代码(可选,不传则不限区域)
|
||||||
- **rainfall**: 累计降雨量(mm)
|
- **rainfall**: 累计降雨量(mm),不传则从气象表自动获取
|
||||||
- **duration**: 降雨持续时间(h)
|
- **duration**: 降雨持续时间(h),不传则从气象表自动获取
|
||||||
|
- **operation_type**: 操作类型(如 '实时监测', '情景模拟', '应急评估')
|
||||||
"""
|
"""
|
||||||
semaphore = get_prediction_semaphore()
|
semaphore = get_prediction_semaphore()
|
||||||
|
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
items = await loop.run_in_executor(
|
items, save_results, condition, now = await loop.run_in_executor(
|
||||||
None, _predict_sync, req.point_ids, req.region_code,
|
None, _predict_sync, req.point_ids, req.region_code,
|
||||||
req.rainfall, req.duration
|
req.rainfall, req.duration, req.operation_type
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"暴雨预测失败: {e}", exc_info=True)
|
logger.error(f"暴雨预测失败: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
raise HTTPException(status_code=500, detail=f"预测失败: {e}")
|
||||||
|
|
||||||
return PredictResponse(code=200, message="success", data=items)
|
# 保存推理结果
|
||||||
|
record_id = None
|
||||||
|
if save_results:
|
||||||
|
try:
|
||||||
|
record_id = dbn_repository.save_inference_result(
|
||||||
|
event_type="rainfall",
|
||||||
|
occurred_time=now,
|
||||||
|
operation_type=req.operation_type,
|
||||||
|
condition=condition,
|
||||||
|
result=save_results
|
||||||
|
)
|
||||||
|
logger.info(f"推理结果已保存,record_id={record_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存推理结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
return PredictResponse(code=200, message="success", data=items, record_id=record_id)
|
||||||
|
|||||||
@@ -684,6 +684,37 @@ class DbnRepository:
|
|||||||
return float(result['aspect'])
|
return float(result['aspect'])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_inference_result(event_type: str, occurred_time, operation_type: str,
|
||||||
|
condition: dict, result: list) -> int:
|
||||||
|
"""
|
||||||
|
保存推理结果到 inference_result 表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: 事件类型('rainfall' 或 'earthquake')
|
||||||
|
occurred_time: 事件发生时间
|
||||||
|
operation_type: 操作类型
|
||||||
|
condition: 输入条件(JSON)
|
||||||
|
result: 预测结果列表(JSON)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
新插入记录的 ID
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
sql = """
|
||||||
|
INSERT INTO inference_result (event_type, occurred_time, operation_type, condition, result)
|
||||||
|
VALUES (%s, %s, %s, %s::jsonb, %s::jsonb)
|
||||||
|
RETURNING id
|
||||||
|
"""
|
||||||
|
row = db_helper.execute_query_one(sql, (
|
||||||
|
event_type,
|
||||||
|
occurred_time,
|
||||||
|
operation_type,
|
||||||
|
json.dumps(condition, ensure_ascii=False),
|
||||||
|
json.dumps(result, ensure_ascii=False)
|
||||||
|
))
|
||||||
|
return row['id'] if row else 0
|
||||||
|
|
||||||
|
|
||||||
# 创建全局实例
|
# 创建全局实例
|
||||||
dbn_repository = DbnRepository()
|
dbn_repository = DbnRepository()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
API 请求/响应数据模型
|
API 请求/响应数据模型
|
||||||
"""
|
"""
|
||||||
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -18,6 +19,8 @@ class RainfallPredictRequest(BaseModel):
|
|||||||
description="累计降雨量(mm),不传则从气象表自动获取")
|
description="累计降雨量(mm),不传则从气象表自动获取")
|
||||||
duration: Optional[float] = Field(None, ge=0,
|
duration: Optional[float] = Field(None, ge=0,
|
||||||
description="降雨持续时间(h),不传则从气象表自动获取")
|
description="降雨持续时间(h),不传则从气象表自动获取")
|
||||||
|
operation_type: str = Field("模拟", min_length=1, max_length=50,
|
||||||
|
description="操作类型(如 '模拟', '实时监测', '应急评估')")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -33,6 +36,9 @@ class EarthquakePredictRequest(BaseModel):
|
|||||||
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
|
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
|
||||||
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
|
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
|
||||||
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
|
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
|
||||||
|
occurred_time: datetime = Field(..., description="地震发生时间")
|
||||||
|
operation_type: str = Field("模拟", min_length=1, max_length=50,
|
||||||
|
description="操作类型(如 '模拟', '实时监测', '应急评估')")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -52,6 +58,7 @@ class PredictResponse(BaseModel):
|
|||||||
code: int = Field(200, description="状态码")
|
code: int = Field(200, description="状态码")
|
||||||
message: str = Field("success", description="提示信息")
|
message: str = Field("success", description="提示信息")
|
||||||
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表")
|
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表")
|
||||||
|
record_id: Optional[int] = Field(None, description="推理结果记录ID")
|
||||||
|
|
||||||
|
|
||||||
class HealthResponse(BaseModel):
|
class HealthResponse(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user