添加存库功能

This commit is contained in:
wzy-warehouse
2026-06-06 13:18:25 +08:00
parent cb2d8c2c54
commit 39b46b58fd
4 changed files with 144 additions and 17 deletions
+51 -7
View File
@@ -47,21 +47,40 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) ->
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
magnitude: float, depth: float, magnitude: float, depth: float,
epicenter_lon: float, epicenter_lat: float) -> List[PredictionItem]: epicenter_lon: float, epicenter_lat: float) -> tuple:
"""同步执行地震预测(在线程池中运行)""" """
同步执行地震预测(在线程池中运行)
Returns:
(预测结果列表, 原始结果)
"""
points = _fetch_points(point_ids, region_code) points = _fetch_points(point_ids, region_code)
if not points: if not points:
return [] return [], []
model = get_earthquake_model() model = get_earthquake_model()
results = model.predict_multiple_points( raw_results = model.predict_multiple_points(
points, points,
magnitude=magnitude, magnitude=magnitude,
depth=depth, depth=depth,
epicenter_lon=epicenter_lon, epicenter_lon=epicenter_lon,
epicenter_lat=epicenter_lat, epicenter_lat=epicenter_lat,
) )
return _build_prediction_items(results) items = _build_prediction_items(raw_results)
save_results = [
{
"point_id": r.get("point_id"),
"source_type": r.get("source_type"),
"lon": r.get("lon"),
"lat": r.get("lat"),
"disaster_probabilities": r.get("disaster_probabilities", {}),
"disaster_levels": r.get("disaster_levels", {})
}
for r in raw_results
]
return items, save_results
@router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测") @router.post("/predict", response_model=PredictResponse, summary="地震灾害链预测")
@@ -75,13 +94,15 @@ async def predict_earthquake(req: EarthquakePredictRequest):
- **depth**: 震源深度(km),默认10km - **depth**: 震源深度(km),默认10km
- **epicenter_lon**: 震中经度 - **epicenter_lon**: 震中经度
- **epicenter_lat**: 震中纬度 - **epicenter_lat**: 震中纬度
- **occurred_time**: 地震发生时间
- **operation_type**: 操作类型(如 '实时监测', '情景模拟', '应急评估'
""" """
semaphore = get_prediction_semaphore() semaphore = get_prediction_semaphore()
async with semaphore: async with semaphore:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
items = await loop.run_in_executor( items, save_results = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code, None, _predict_sync, req.point_ids, req.region_code,
req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat req.magnitude, req.depth, req.epicenter_lon, req.epicenter_lat
) )
@@ -89,4 +110,27 @@ async def predict_earthquake(req: EarthquakePredictRequest):
logger.error(f"地震预测失败: {e}", exc_info=True) logger.error(f"地震预测失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"预测失败: {e}") raise HTTPException(status_code=500, detail=f"预测失败: {e}")
return PredictResponse(code=200, message="success", data=items) # 保存推理结果
record_id = None
if save_results:
try:
condition = {
"point_ids": req.point_ids,
"region_code": req.region_code,
"magnitude": req.magnitude,
"depth": req.depth,
"epicenter_lon": req.epicenter_lon,
"epicenter_lat": req.epicenter_lat
}
record_id = dbn_repository.save_inference_result(
event_type="earthquake",
occurred_time=req.occurred_time,
operation_type=req.operation_type,
condition=condition,
result=save_results
)
logger.info(f"推理结果已保存,record_id={record_id}")
except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=items, record_id=record_id)
+55 -10
View File
@@ -2,6 +2,7 @@
暴雨灾害链预测接口 暴雨灾害链预测接口
""" """
import asyncio import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
@@ -46,15 +47,43 @@ def _fetch_points(point_ids: Optional[List[int]], region_code: Optional[str]) ->
def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str], def _predict_sync(point_ids: Optional[List[int]], region_code: Optional[str],
rainfall: float, duration: float) -> List[PredictionItem]: rainfall: Optional[float], duration: Optional[float],
"""同步执行暴雨预测(在线程池中运行)""" operation_type: str) -> tuple:
"""
同步执行暴雨预测(在线程池中运行)
Returns:
(预测结果列表, 原始结果, 输入条件, 当前时间)
"""
points = _fetch_points(point_ids, region_code) points = _fetch_points(point_ids, region_code)
if not points: if not points:
return [] return [], [], {}, datetime.now()
model = get_rainfall_model() model = get_rainfall_model()
results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration) raw_results = model.predict_multiple_points(points, rainfall=rainfall, duration=duration)
return _build_prediction_items(results) items = _build_prediction_items(raw_results)
# 构建条件和结果用于保存
now = datetime.now()
condition = {
"point_ids": point_ids,
"region_code": region_code,
"rainfall": rainfall,
"duration": duration
}
save_results = [
{
"point_id": r.get("point_id"),
"source_type": r.get("source_type"),
"lon": r.get("lon"),
"lat": r.get("lat"),
"disaster_probabilities": r.get("disaster_probabilities", {}),
"disaster_levels": r.get("disaster_levels", {})
}
for r in raw_results
]
return items, save_results, condition, now
@router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测") @router.post("/predict", response_model=PredictResponse, summary="暴雨灾害链预测")
@@ -64,20 +93,36 @@ async def predict_rainfall(req: RainfallPredictRequest):
- **point_ids**: 点位ID列表(可选,不传则查询所有点) - **point_ids**: 点位ID列表(可选,不传则查询所有点)
- **region_code**: 行政区划代码(可选,不传则不限区域) - **region_code**: 行政区划代码(可选,不传则不限区域)
- **rainfall**: 累计降雨量(mm) - **rainfall**: 累计降雨量(mm),不传则从气象表自动获取
- **duration**: 降雨持续时间(h) - **duration**: 降雨持续时间(h),不传则从气象表自动获取
- **operation_type**: 操作类型(如 '实时监测', '情景模拟', '应急评估'
""" """
semaphore = get_prediction_semaphore() semaphore = get_prediction_semaphore()
async with semaphore: async with semaphore:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
items = await loop.run_in_executor( items, save_results, condition, now = await loop.run_in_executor(
None, _predict_sync, req.point_ids, req.region_code, None, _predict_sync, req.point_ids, req.region_code,
req.rainfall, req.duration req.rainfall, req.duration, req.operation_type
) )
except Exception as e: except Exception as e:
logger.error(f"暴雨预测失败: {e}", exc_info=True) logger.error(f"暴雨预测失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"预测失败: {e}") raise HTTPException(status_code=500, detail=f"预测失败: {e}")
return PredictResponse(code=200, message="success", data=items) # 保存推理结果
record_id = None
if save_results:
try:
record_id = dbn_repository.save_inference_result(
event_type="rainfall",
occurred_time=now,
operation_type=req.operation_type,
condition=condition,
result=save_results
)
logger.info(f"推理结果已保存,record_id={record_id}")
except Exception as e:
logger.error(f"保存推理结果失败: {e}", exc_info=True)
return PredictResponse(code=200, message="success", data=items, record_id=record_id)
+31
View File
@@ -684,6 +684,37 @@ class DbnRepository:
return float(result['aspect']) return float(result['aspect'])
return None return None
@staticmethod
def save_inference_result(event_type: str, occurred_time, operation_type: str,
condition: dict, result: list) -> int:
"""
保存推理结果到 inference_result 表
Args:
event_type: 事件类型('rainfall''earthquake'
occurred_time: 事件发生时间
operation_type: 操作类型
condition: 输入条件(JSON
result: 预测结果列表(JSON
Returns:
新插入记录的 ID
"""
import json
sql = """
INSERT INTO inference_result (event_type, occurred_time, operation_type, condition, result)
VALUES (%s, %s, %s, %s::jsonb, %s::jsonb)
RETURNING id
"""
row = db_helper.execute_query_one(sql, (
event_type,
occurred_time,
operation_type,
json.dumps(condition, ensure_ascii=False),
json.dumps(result, ensure_ascii=False)
))
return row['id'] if row else 0
# 创建全局实例 # 创建全局实例
dbn_repository = DbnRepository() dbn_repository = DbnRepository()
+7
View File
@@ -1,6 +1,7 @@
""" """
API 请求/响应数据模型 API 请求/响应数据模型
""" """
from datetime import datetime
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -18,6 +19,8 @@ class RainfallPredictRequest(BaseModel):
description="累计降雨量(mm),不传则从气象表自动获取") description="累计降雨量(mm),不传则从气象表自动获取")
duration: Optional[float] = Field(None, ge=0, duration: Optional[float] = Field(None, ge=0,
description="降雨持续时间(h),不传则从气象表自动获取") description="降雨持续时间(h),不传则从气象表自动获取")
operation_type: str = Field("模拟", min_length=1, max_length=50,
description="操作类型(如 '模拟', '实时监测', '应急评估'")
# ============================================================ # ============================================================
@@ -33,6 +36,9 @@ class EarthquakePredictRequest(BaseModel):
depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km") depth: float = Field(10.0, gt=0, le=700, description="震源深度(km),默认10km")
epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度") epicenter_lon: float = Field(..., ge=-180, le=180, description="震中经度")
epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度") epicenter_lat: float = Field(..., ge=-90, le=90, description="震中纬度")
occurred_time: datetime = Field(..., description="地震发生时间")
operation_type: str = Field("模拟", min_length=1, max_length=50,
description="操作类型(如 '模拟', '实时监测', '应急评估'")
# ============================================================ # ============================================================
@@ -52,6 +58,7 @@ class PredictResponse(BaseModel):
code: int = Field(200, description="状态码") code: int = Field(200, description="状态码")
message: str = Field("success", description="提示信息") message: str = Field("success", description="提示信息")
data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表") data: List[PredictionItem] = Field(default_factory=list, description="预测结果列表")
record_id: Optional[int] = Field(None, description="推理结果记录ID")
class HealthResponse(BaseModel): class HealthResponse(BaseModel):