增加概率阈值

This commit is contained in:
wzy-warehouse
2026-06-14 16:50:03 +08:00
parent 75046c99c8
commit 615a563369
3 changed files with 18 additions and 4 deletions
+8 -2
View File
@@ -34,8 +34,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
return result_map return result_map
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]:
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}""" """将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
from config import settings
threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold)
result_map = {} result_map = {}
for r in results: for r in results:
probs = r.get("disaster_probabilities", {}) probs = r.get("disaster_probabilities", {})
@@ -45,9 +47,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s
source_id = r["source_id"] source_id = r["source_id"]
source_type = r.get("source_type") source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get) max_hazard = max(probs, key=probs.get)
prob_value = round(probs[max_hazard] * 100, 2)
# 低于阈值不返回
if prob_value < threshold:
continue
key = f"{source_id}_{source_type}" key = f"{source_id}_{source_type}"
result_map[key] = { result_map[key] = {
"probability": round(probs[max_hazard] * 100, 2), "probability": prob_value,
"lon": r.get("lon"), "lon": r.get("lon"),
"lat": r.get("lat") "lat": r.get("lat")
} }
+8 -2
View File
@@ -37,8 +37,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
return result_map return result_map
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]:
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}""" """将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
from config import settings
threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold)
result_map = {} result_map = {}
for r in results: for r in results:
probs = r.get("disaster_probabilities", {}) probs = r.get("disaster_probabilities", {})
@@ -48,9 +50,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s
source_id = r["source_id"] source_id = r["source_id"]
source_type = r.get("source_type") source_type = r.get("source_type")
max_hazard = max(probs, key=probs.get) max_hazard = max(probs, key=probs.get)
prob_value = round(probs[max_hazard] * 100, 2)
# 低于阈值不返回
if prob_value < threshold:
continue
key = f"{source_id}_{source_type}" key = f"{source_id}_{source_type}"
result_map[key] = { result_map[key] = {
"probability": round(probs[max_hazard] * 100, 2), "probability": prob_value,
"lon": r.get("lon"), "lon": r.get("lon"),
"lat": r.get("lat") "lat": r.get("lat")
} }
+2
View File
@@ -8,6 +8,8 @@ RAIN_STATION_GRID_DIR = "/xian/rainfall/grid/images/:id"
REDIS_RAIN_STATION_GRID_KEY = "xian:rainfall:rain_station_grid" REDIS_RAIN_STATION_GRID_KEY = "xian:rainfall:rain_station_grid"
# 雨量站存储标识符的redis的key # 雨量站存储标识符的redis的key
REDIS_RAIN_STATION_IDENTIFIER_KEY = "xian:rainfall:rain_station_identifier" REDIS_RAIN_STATION_IDENTIFIER_KEY = "xian:rainfall:rain_station_identifier"
# 预测结果概率阈值(低于此值不返回给前端)
PREDICT_PROBABILITY_THRESHOLD = 50
# 开发环境 # 开发环境
[development] [development]