增加概率阈值
This commit is contained in:
@@ -34,8 +34,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
|||||||
return result_map
|
return result_map
|
||||||
|
|
||||||
|
|
||||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]:
|
||||||
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||||
|
from config import settings
|
||||||
|
threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold)
|
||||||
result_map = {}
|
result_map = {}
|
||||||
for r in results:
|
for r in results:
|
||||||
probs = r.get("disaster_probabilities", {})
|
probs = r.get("disaster_probabilities", {})
|
||||||
@@ -45,9 +47,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s
|
|||||||
source_id = r["source_id"]
|
source_id = r["source_id"]
|
||||||
source_type = r.get("source_type")
|
source_type = r.get("source_type")
|
||||||
max_hazard = max(probs, key=probs.get)
|
max_hazard = max(probs, key=probs.get)
|
||||||
|
prob_value = round(probs[max_hazard] * 100, 2)
|
||||||
|
# 低于阈值不返回
|
||||||
|
if prob_value < threshold:
|
||||||
|
continue
|
||||||
key = f"{source_id}_{source_type}"
|
key = f"{source_id}_{source_type}"
|
||||||
result_map[key] = {
|
result_map[key] = {
|
||||||
"probability": round(probs[max_hazard] * 100, 2),
|
"probability": prob_value,
|
||||||
"lon": r.get("lon"),
|
"lon": r.get("lon"),
|
||||||
"lat": r.get("lat")
|
"lat": r.get("lat")
|
||||||
}
|
}
|
||||||
|
|||||||
+8
-2
@@ -37,8 +37,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
|||||||
return result_map
|
return result_map
|
||||||
|
|
||||||
|
|
||||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]:
|
||||||
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||||
|
from config import settings
|
||||||
|
threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold)
|
||||||
result_map = {}
|
result_map = {}
|
||||||
for r in results:
|
for r in results:
|
||||||
probs = r.get("disaster_probabilities", {})
|
probs = r.get("disaster_probabilities", {})
|
||||||
@@ -48,9 +50,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s
|
|||||||
source_id = r["source_id"]
|
source_id = r["source_id"]
|
||||||
source_type = r.get("source_type")
|
source_type = r.get("source_type")
|
||||||
max_hazard = max(probs, key=probs.get)
|
max_hazard = max(probs, key=probs.get)
|
||||||
|
prob_value = round(probs[max_hazard] * 100, 2)
|
||||||
|
# 低于阈值不返回
|
||||||
|
if prob_value < threshold:
|
||||||
|
continue
|
||||||
key = f"{source_id}_{source_type}"
|
key = f"{source_id}_{source_type}"
|
||||||
result_map[key] = {
|
result_map[key] = {
|
||||||
"probability": round(probs[max_hazard] * 100, 2),
|
"probability": prob_value,
|
||||||
"lon": r.get("lon"),
|
"lon": r.get("lon"),
|
||||||
"lat": r.get("lat")
|
"lat": r.get("lat")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ RAIN_STATION_GRID_DIR = "/xian/rainfall/grid/images/:id"
|
|||||||
REDIS_RAIN_STATION_GRID_KEY = "xian:rainfall:rain_station_grid"
|
REDIS_RAIN_STATION_GRID_KEY = "xian:rainfall:rain_station_grid"
|
||||||
# 雨量站存储标识符的redis的key
|
# 雨量站存储标识符的redis的key
|
||||||
REDIS_RAIN_STATION_IDENTIFIER_KEY = "xian:rainfall:rain_station_identifier"
|
REDIS_RAIN_STATION_IDENTIFIER_KEY = "xian:rainfall:rain_station_identifier"
|
||||||
|
# 预测结果概率阈值(低于此值不返回给前端)
|
||||||
|
PREDICT_PROBABILITY_THRESHOLD = 50
|
||||||
|
|
||||||
# 开发环境
|
# 开发环境
|
||||||
[development]
|
[development]
|
||||||
|
|||||||
Reference in New Issue
Block a user