diff --git a/app/api/earthquake.py b/app/api/earthquake.py index 4f49070..691cb9d 100644 --- a/app/api/earthquake.py +++ b/app/api/earthquake.py @@ -34,8 +34,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]: return result_map -def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: +def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]: """将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}""" + from config import settings + threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold) result_map = {} for r in results: probs = r.get("disaster_probabilities", {}) @@ -45,9 +47,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s source_id = r["source_id"] source_type = r.get("source_type") max_hazard = max(probs, key=probs.get) + prob_value = round(probs[max_hazard] * 100, 2) + # 低于阈值不返回 + if prob_value < threshold: + continue key = f"{source_id}_{source_type}" result_map[key] = { - "probability": round(probs[max_hazard] * 100, 2), + "probability": prob_value, "lon": r.get("lon"), "lat": r.get("lat") } diff --git a/app/api/rainfall.py b/app/api/rainfall.py index 0f00df2..768dec7 100644 --- a/app/api/rainfall.py +++ b/app/api/rainfall.py @@ -37,8 +37,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]: return result_map -def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: +def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]: """将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}""" + from config import settings + threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold) result_map = {} for r in results: probs = r.get("disaster_probabilities", {}) @@ -48,9 +50,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s source_id = r["source_id"] source_type = r.get("source_type") max_hazard = max(probs, key=probs.get) + prob_value = round(probs[max_hazard] * 100, 2) + # 低于阈值不返回 + if prob_value < threshold: + continue key = f"{source_id}_{source_type}" result_map[key] = { - "probability": round(probs[max_hazard] * 100, 2), + "probability": prob_value, "lon": r.get("lon"), "lat": r.get("lat") } diff --git a/settings.toml b/settings.toml index d6e9ac1..83ba84d 100644 --- a/settings.toml +++ b/settings.toml @@ -8,6 +8,8 @@ RAIN_STATION_GRID_DIR = "/xian/rainfall/grid/images/:id" REDIS_RAIN_STATION_GRID_KEY = "xian:rainfall:rain_station_grid" # 雨量站存储标识符的redis的key REDIS_RAIN_STATION_IDENTIFIER_KEY = "xian:rainfall:rain_station_identifier" +# 预测结果概率阈值(低于此值不返回给前端) +PREDICT_PROBABILITY_THRESHOLD = 50 # 开发环境 [development]