增加概率阈值
This commit is contained in:
@@ -34,8 +34,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
return result_map
|
||||
|
||||
|
||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]:
|
||||
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||
from config import settings
|
||||
threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold)
|
||||
result_map = {}
|
||||
for r in results:
|
||||
probs = r.get("disaster_probabilities", {})
|
||||
@@ -45,9 +47,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
prob_value = round(probs[max_hazard] * 100, 2)
|
||||
# 低于阈值不返回
|
||||
if prob_value < threshold:
|
||||
continue
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = {
|
||||
"probability": round(probs[max_hazard] * 100, 2),
|
||||
"probability": prob_value,
|
||||
"lon": r.get("lon"),
|
||||
"lat": r.get("lat")
|
||||
}
|
||||
|
||||
+8
-2
@@ -37,8 +37,10 @@ def _build_prediction_map(results: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
return result_map
|
||||
|
||||
|
||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
def _build_prediction_map_with_location(results: List[Dict[str, Any]], threshold: float = 50.0) -> Dict[str, Dict[str, Any]]:
|
||||
"""将模型原始结果转换为返回格式: {id_type: {probability, lon, lat}}"""
|
||||
from config import settings
|
||||
threshold = getattr(settings, 'PREDICT_PROBABILITY_THRESHOLD', threshold)
|
||||
result_map = {}
|
||||
for r in results:
|
||||
probs = r.get("disaster_probabilities", {})
|
||||
@@ -48,9 +50,13 @@ def _build_prediction_map_with_location(results: List[Dict[str, Any]]) -> Dict[s
|
||||
source_id = r["source_id"]
|
||||
source_type = r.get("source_type")
|
||||
max_hazard = max(probs, key=probs.get)
|
||||
prob_value = round(probs[max_hazard] * 100, 2)
|
||||
# 低于阈值不返回
|
||||
if prob_value < threshold:
|
||||
continue
|
||||
key = f"{source_id}_{source_type}"
|
||||
result_map[key] = {
|
||||
"probability": round(probs[max_hazard] * 100, 2),
|
||||
"probability": prob_value,
|
||||
"lon": r.get("lon"),
|
||||
"lat": r.get("lat")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user