初始化代码

This commit is contained in:
wzy-warehouse
2026-05-08 15:42:32 +08:00
parent 7d18effcfe
commit 4ef23fec7c
26 changed files with 140 additions and 2263 deletions
-22
View File
@@ -1,22 +0,0 @@
# 开发环境配置
ENVIRONMENT=development
# 数据库配置
DB_HOST=47.92.216.173
DB_PORT=7654
DB_USER=postgres
DB_PASSWORD=zhangsan
DB_NAME=xian_new
# FastAPI配置
API_HOST=127.0.0.1
API_PORT=8082
# 应用配置
APP_NAME=西安项目算法
APP_VERSION=1.0.0
DEBUG=True
# 日志配置
LOG_LEVEL=DEBUG
LOG_DIR=logs
-22
View File
@@ -1,22 +0,0 @@
# 生产环境配置
ENVIRONMENT=production
# 数据库配置
DB_HOST=10.22.245.138
DB_PORT=54321
DB_USER=zaihailian
DB_PASSWORD=XAYJ@gis2603
DB_NAME=xianDC
# FastAPI配置
API_HOST=127.0.0.1
API_PORT=8081
# 应用配置
APP_NAME=西安项目算法
APP_VERSION=1.0.0
DEBUG=False
# 日志配置
LOG_LEVEL=WARNING
LOG_DIR=logs
+3
View File
@@ -51,3 +51,6 @@ htmlcov/
# Jupyter Notebook
.ipynb_checkpoints
# Ignore dynaconf secret files
.secrets.*
-173
View File
@@ -1,173 +0,0 @@
"""
降雨数据API Controller - 路由层
"""
from fastapi import APIRouter, HTTPException, Query
from datetime import datetime
from typing import Optional
from app.services.rainfall_service import RainfallService
from app.schemas.rainfall import (
RainfallGridRequest,
RainfallGridResponse,
StationsResponse
)
from app.utils.logger import setup_logging
logger = setup_logging()
router = APIRouter(
prefix="/rainfall",
tags=["降雨数据"],
responses={404: {"description": "Not found"}}
)
# 创建服务实例
rainfall_service = RainfallService()
@router.post("/grid", response_model=RainfallGridResponse, summary="获取降雨栅格数据")
async def get_rainfall_grid(request: RainfallGridRequest):
"""
获取指定时间的降雨栅格数据
使用反距离权重插值(IDW)方法,将站点降雨数据插值为连续栅格,
返回适合Cesium渲染的GeoJSON格式数据。
Args:
request: 包含时间、分辨率和持续时间的请求
Returns:
GeoJSON格式的栅格数据
"""
try:
# 解析时间,如果未提供则使用当前时间
now = datetime.now()
query_time = datetime.fromisoformat(request.time) if request.time else now
# 验证duration参数
if request.duration not in [12, 24]:
raise ValueError("duration参数必须为12或24")
# 调用服务层生成栅格(自动查询前12小时或24小时数据)
geojson_data = rainfall_service.generate_rainfall_grid(
query_time=query_time,
resolution=request.resolution,
duration=request.duration
)
if not geojson_data:
return RainfallGridResponse(
code=404,
message="未找到降雨数据",
data=None
)
return RainfallGridResponse(
code=200,
message="降雨栅格数据生成成功",
data=geojson_data
)
except ValueError as e:
logger.error(f"时间格式错误: {e}")
raise HTTPException(status_code=400, detail=f"时间格式错误: {str(e)}")
except Exception as e:
logger.error(f"生成降雨栅格失败: {e}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"生成降雨栅格失败: {str(e)}")
@router.get("/stations", response_model=StationsResponse, summary="获取雨量站点数据")
async def get_rainfall_stations(
time: str = Query(..., description="查询时间 ISO格式(自动查询前12小时或24小时数据)"),
duration: int = Query(12, description="持续时间(小时),可选12或24", ge=12, le=24)
):
"""
获取指定时间的雨量站点原始数据
Args:
time: 查询时间
duration: 持续时间(12或24小时)
Returns:
站点列表,包含经纬度和降雨量
"""
try:
query_time = datetime.fromisoformat(time)
# 调用服务层获取站点数据(自动查询前12小时或24小时数据)
stations = rainfall_service.get_stations_data(
query_time=query_time,
duration=duration
)
return StationsResponse(
code=200,
message="查询成功",
data=stations
)
except ValueError as e:
logger.error(f"时间格式错误: {e}")
raise HTTPException(status_code=400, detail=f"时间格式错误: {str(e)}")
except Exception as e:
logger.error(f"查询站点数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/point", summary="查询指定点位的降雨量")
async def get_rainfall_at_point(
longitude: float,
latitude: float,
time: Optional[str] = None,
duration: int = Query(12, description="持续时间(小时),可选12或24", ge=12, le=24)
):
"""
查询指定经纬度位置的降雨量
Args:
longitude: 经度
latitude: 纬度
time: 查询时间(可选,默认当前时间,自动查询前12小时或24小时数据)
duration: 持续时间(12或24小时)
Returns:
该点位的降雨量信息
"""
try:
from app.services.rainfall_service import RainfallService
# 解析时间
now = datetime.now()
query_time = datetime.fromisoformat(time) if time else now
# 验证duration参数
if duration not in [12, 24]:
raise ValueError("duration参数必须为12或24")
# 调用服务层查询(自动查询前12小时或24小时数据)
service = RainfallService()
rainfall_info = service.get_rainfall_at_point(
longitude=longitude,
latitude=latitude,
query_time=query_time,
duration=duration
)
if not rainfall_info:
return {
"code": 200,
"message": "未找到该点位的降雨数据",
"data": None
}
return {
"code": 200,
"message": "查询成功",
"data": rainfall_info
}
except Exception as e:
logger.error(f"查询点位降雨量失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
-48
View File
@@ -1,48 +0,0 @@
"""
基础配置类
"""
from pydantic_settings import BaseSettings
from enum import Enum
class EnvironmentEnum(str, Enum):
"""环境枚举"""
DEVELOPMENT = "development"
PRODUCTION = "production"
class BaseConfig(BaseSettings):
"""基础配置类"""
# 应用基本信息
APP_NAME: str = "西安项目算法"
APP_VERSION: str = "1.0.0"
ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.DEVELOPMENT
# 调试模式
DEBUG: bool = True
# API配置
API_HOST: str = "127.0.0.1" # 默认只监听本地
API_PORT: int = 8000
# CORS配置(默认只允许localhost
CORS_ORIGINS: list = ["http://localhost", "http://127.0.0.1"]
# 日志配置
LOG_LEVEL: str = "INFO"
LOG_DIR: str = "logs"
class Config:
env_file = ".env.development" # 默认使用开发环境配置
case_sensitive = True
@property
def is_development(self) -> bool:
"""是否为开发环境"""
return self.ENVIRONMENT == EnvironmentEnum.DEVELOPMENT
@property
def is_production(self) -> bool:
"""是否为生产环境"""
return self.ENVIRONMENT == EnvironmentEnum.PRODUCTION
-37
View File
@@ -1,37 +0,0 @@
"""
数据库配置
"""
from .base_config import BaseConfig
class DatabaseConfig(BaseConfig):
"""数据库配置类"""
# PostgreSQL配置
DB_HOST: str = "localhost"
DB_PORT: int = 5432
DB_USER: str = "postgres"
DB_PASSWORD: str = "postgres"
DB_NAME: str = "test_db"
# 连接池配置
DB_POOL_SIZE: int = 10
DB_MAX_OVERFLOW: int = 20
DB_POOL_TIMEOUT: int = 30
DB_POOL_RECYCLE: int = 3600
@property
def DATABASE_URL(self) -> str:
"""构建数据库连接URL"""
return (
f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}"
f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
)
@property
def ASYNC_DATABASE_URL(self) -> str:
"""构建异步数据库连接URL"""
return (
f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASSWORD}"
f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
)
-22
View File
@@ -1,22 +0,0 @@
"""
开发环境配置
"""
from .database_config import DatabaseConfig
from .base_config import EnvironmentEnum
class DevelopmentConfig(DatabaseConfig):
"""开发环境配置 - 只覆盖与生产不同的配置"""
ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.DEVELOPMENT
# 调试和日志
DEBUG: bool = True
LOG_LEVEL: str = "DEBUG"
# 开发特性
RELOAD: bool = True
class Config:
env_file = ".env.development"
case_sensitive = True
-26
View File
@@ -1,26 +0,0 @@
"""
生产环境配置
"""
from .database_config import DatabaseConfig
from .base_config import EnvironmentEnum
class ProductionConfig(DatabaseConfig):
"""生产环境配置 - 只覆盖性能和安全相关配置"""
ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.PRODUCTION
# 调试和日志
DEBUG: bool = False
LOG_LEVEL: str = "WARNING"
# 生产特性
RELOAD: bool = False
# 数据库连接池(生产环境需要更大)
DB_POOL_SIZE: int = 20
DB_MAX_OVERFLOW: int = 40
class Config:
env_file = ".env.production"
case_sensitive = True
-65
View File
@@ -1,65 +0,0 @@
"""
配置加载器 - 根据环境自动加载对应配置
"""
import os
from typing import Type
from .base_config import BaseConfig, EnvironmentEnum
from .development import DevelopmentConfig
from .production import ProductionConfig
def get_config_class(environment: str = None) -> Type[BaseConfig]:
"""根据环境获取配置类
Args:
environment: 环境名称 (development/production)
Returns:
对应的配置类
"""
if environment is None:
environment = os.getenv("ENVIRONMENT", "development")
config_map = {
EnvironmentEnum.DEVELOPMENT: DevelopmentConfig,
EnvironmentEnum.PRODUCTION: ProductionConfig,
}
try:
env_enum = EnvironmentEnum(environment)
return config_map[env_enum]
except ValueError:
print(f"警告: 未知环境 '{environment}',使用默认开发环境配置")
return DevelopmentConfig
def load_config(environment: str = None) -> BaseConfig:
"""加载配置
Args:
environment: 环境名称
Returns:
配置实例
"""
config_class = get_config_class(environment)
return config_class()
# 全局配置实例(延迟加载)
_config_instance = None
def get_settings() -> BaseConfig:
"""获取全局配置实例(单例模式)"""
global _config_instance
if _config_instance is None:
environment = os.getenv("ENVIRONMENT", None)
_config_instance = load_config(environment)
return _config_instance
def reload_config(environment: str = None):
"""重新加载配置"""
global _config_instance
_config_instance = load_config(environment)
-255
View File
@@ -1,255 +0,0 @@
"""
数据库连接管理 - 使用SQLAlchemy 2.0
"""
import logging
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, Session, DeclarativeBase
from sqlalchemy.pool import QueuePool
from typing import List, Dict, Any
from contextlib import contextmanager
from app.config.settings import get_settings
from app.utils.logger import setup_logging
# 初始化日志
logger = setup_logging()
# 关闭SQLAlchemy引擎日志,只保留关键信息
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.pool').setLevel(logging.WARNING)
# 获取配置
settings = get_settings()
# 创建数据库引擎
engine = create_engine(
settings.DATABASE_URL,
poolclass=QueuePool,
pool_size=settings.DB_POOL_SIZE,
max_overflow=settings.DB_MAX_OVERFLOW,
pool_timeout=settings.DB_POOL_TIMEOUT,
pool_recycle=settings.DB_POOL_RECYCLE,
pool_pre_ping=True,
echo=False, # 关闭SQL语句打印
connect_args={
"options": "-c client_encoding=UTF8"
}
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 声明基类
Base = DeclarativeBase()
class DatabaseManager:
"""数据库管理器 - 提供通用的CRUD操作"""
def __init__(self):
self.engine = engine
self.SessionLocal = SessionLocal
@contextmanager
def get_session(self) -> Session:
"""获取数据库会话的上下文管理器"""
session = self.SessionLocal()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error(f"数据库操作失败: {e}")
raise
finally:
session.close()
def init_db(self):
"""初始化数据库,创建所有表"""
try:
Base.metadata.create_all(bind=self.engine)
logger.info("数据库表创建成功")
except Exception as e:
logger.error(f"数据库表创建失败: {e}")
raise
def test_connection(self) -> bool:
"""测试数据库连接"""
try:
with self.get_session() as session:
session.execute(text("SELECT 1"))
logger.info("数据库连接测试成功")
return True
except Exception as e:
logger.error(f"数据库连接测试失败: {e}")
return False
def execute_raw_sql(self, sql: str, params: dict = None) -> List[Dict[str, Any]]:
"""执行原生SQL查询
Args:
sql: SQL语句
params: 参数字典
Returns:
查询结果列表
"""
with self.get_session() as session:
result = session.execute(text(sql), params or {})
if result.returns_rows:
columns = result.keys()
return [dict(zip(columns, row)) for row in result.fetchall()]
return []
def insert(self, table_name: str, data: Dict[str, Any]) -> int:
"""插入单条记录
Args:
table_name: 表名
data: 要插入的数据字典
Returns:
插入的行数
"""
columns = ", ".join(data.keys())
placeholders = ", ".join([f":{key}" for key in data.keys()])
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
with self.get_session() as session:
result = session.execute(text(sql), data)
return result.rowcount
def insert_many(self, table_name: str, data_list: List[Dict[str, Any]]) -> int:
"""批量插入记录
Args:
table_name: 表名
data_list: 数据字典列表
Returns:
插入的行数
"""
if not data_list:
return 0
columns = ", ".join(data_list[0].keys())
placeholders = ", ".join([f":{key}" for key in data_list[0].keys()])
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
with self.get_session() as session:
result = session.execute(text(sql), data_list)
return result.rowcount
def select(
self,
table_name: str,
conditions: Dict[str, Any] = None,
columns: List[str] = None,
limit: int = None,
offset: int = None,
order_by: str = None
) -> List[Dict[str, Any]]:
"""查询记录
Args:
table_name: 表名
conditions: 查询条件字典
columns: 要查询的列列表,None表示查询所有列
limit: 限制返回行数
offset: 偏移量
order_by: 排序字段
Returns:
查询结果列表
"""
col_str = ", ".join(columns) if columns else "*"
sql = f"SELECT {col_str} FROM {table_name}"
params = {}
if conditions:
where_clauses = []
for key, value in conditions.items():
where_clauses.append(f"{key} = :{key}")
params[key] = value
sql += " WHERE " + " AND ".join(where_clauses)
if order_by:
sql += f" ORDER BY {order_by}"
if limit:
sql += f" LIMIT :limit"
params["limit"] = limit
if offset:
sql += f" OFFSET :offset"
params["offset"] = offset
return self.execute_raw_sql(sql, params)
def update(self, table_name: str, data: Dict[str, Any],
conditions: Dict[str, Any]) -> int:
"""更新记录
Args:
table_name: 表名
data: 要更新的数据字典
conditions: 更新条件字典
Returns:
更新的行数
"""
set_clauses = [f"{key} = :{key}" for key in data.keys()]
where_clauses = [f"{key} = :cond_{key}" for key in conditions.keys()]
sql = f"UPDATE {table_name} SET {', '.join(set_clauses)} WHERE {' AND '.join(where_clauses)}"
# 合并参数,避免键名冲突
params = {**data, **{f"cond_{key}": value for key, value in conditions.items()}}
with self.get_session() as session:
result = session.execute(text(sql), params)
return result.rowcount
def delete(self, table_name: str, conditions: Dict[str, Any]) -> int:
"""删除记录
Args:
table_name: 表名
conditions: 删除条件字典
Returns:
删除的行数
"""
where_clauses = [f"{key} = :{key}" for key in conditions.keys()]
sql = f"DELETE FROM {table_name} WHERE {' AND '.join(where_clauses)}"
with self.get_session() as session:
result = session.execute(text(sql), conditions)
return result.rowcount
def count(self, table_name: str, conditions: Dict[str, Any] = None) -> int:
"""统计记录数
Args:
table_name: 表名
conditions: 统计条件字典
Returns:
记录数
"""
sql = f"SELECT COUNT(*) as count FROM {table_name}"
params = {}
if conditions:
where_clauses = []
for key, value in conditions.items():
where_clauses.append(f"{key} = :{key}")
params[key] = value
sql += " WHERE " + " AND ".join(where_clauses)
result = self.execute_raw_sql(sql, params)
return result[0]["count"] if result else 0
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
-110
View File
@@ -1,110 +0,0 @@
"""
FastAPI应用主文件
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config.settings import get_settings
from app.core.database import db_manager
from app.utils.logger import setup_logging
from app.api.rainfall import router as rainfall_router
# 初始化日志
logger = setup_logging()
# 获取配置
settings = get_settings()
def create_application() -> FastAPI:
"""创建FastAPI应用实例"""
application = FastAPI(
title=settings.APP_NAME,
version=settings.APP_VERSION,
debug=settings.DEBUG,
description="基于FastAPI的现代化Web应用框架"
)
# 配置CORS
application.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS if hasattr(settings, 'CORS_ORIGINS') else ["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
return application
# 创建应用实例
app = create_application()
# ==================== 启动和关闭事件 ====================
@app.on_event("startup")
async def startup_event():
"""应用启动时执行"""
logger.info(f"正在启动 {settings.APP_NAME} v{settings.APP_VERSION}")
logger.info(f"环境: {settings.ENVIRONMENT}")
logger.info(f"数据库连接: {settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}")
# 测试数据库连接
if db_manager.test_connection():
logger.info("数据库连接成功")
else:
logger.warning("数据库连接失败,请检查配置")
@app.on_event("shutdown")
async def shutdown_event():
"""应用关闭时执行"""
logger.info("应用正在关闭...")
# ==================== 根路径和健康检查 ====================
@app.get("/", tags=["基础"])
async def root():
"""根路径 - 欢迎信息"""
return {
"app": settings.APP_NAME,
"version": settings.APP_VERSION,
"environment": settings.ENVIRONMENT.value,
"status": "running"
}
@app.get("/health", tags=["基础"])
async def health_check():
"""健康检查接口"""
try:
# 检查数据库连接
db_manager.execute_raw_sql("SELECT 1")
db_status = "connected"
except Exception as e:
db_status = f"error: {str(e)}"
return {
"status": "healthy",
"database": db_status,
"app_version": settings.APP_VERSION,
"environment": settings.ENVIRONMENT.value
}
# ==================== 注册路由 ====================
app.include_router(rainfall_router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.API_HOST,
port=settings.API_PORT,
reload=settings.RELOAD if hasattr(settings, 'RELOAD') else settings.DEBUG
)
-55
View File
@@ -1,55 +0,0 @@
"""
降雨数据Repository - 数据访问层
"""
from typing import List, Dict, Any
from datetime import datetime
from app.core.database import db_manager
from app.utils.logger import setup_logging
logger = setup_logging()
class RainfallRepository:
"""降雨数据仓储类"""
@staticmethod
def query_stations_rainfall(
query_time: datetime,
duration: int = 12
) -> List[Dict[str, Any]]:
"""
查询指定时间的站点降雨数据(自动查询前12小时或24小时)
Args:
query_time: 查询时间
duration: 持续时间(12或24小时)
Returns:
站点降雨数据列表
"""
sql = f"""
SELECT
lon,
lat,
SUM(rainfall_1h::numeric) AS rainfall
FROM xian_meteorology
WHERE datetime BETWEEN (
to_char(timestamp :query_time - interval '{duration} hours', 'YYYYMMDDHH24MISS')
)::bigint AND (
to_char(timestamp :query_time, 'YYYYMMDDHH24MISS')
)::bigint
GROUP BY lon, lat
"""
params = {
"query_time": query_time
}
try:
result = db_manager.execute_raw_sql(sql, params)
logger.info(f"查询到 {len(result)} 个站点数据({duration}小时)")
return result
except Exception as e:
logger.error(f"查询站点降雨数据失败: {e}")
raise
-64
View File
@@ -1,64 +0,0 @@
"""
降雨数据相关的Pydantic Schemas
"""
from pydantic import BaseModel, Field
from typing import Optional, List
class RainfallGridRequest(BaseModel):
"""降雨栅格请求模型"""
time: Optional[str] = Field(
None,
alias="time",
description="查询时间 ISO格式,默认为当前时间(自动查询前12小时或24小时数据)",
example="2024-01-01T12:00:00"
)
resolution: float = Field(
0.01,
alias="resolution",
description="栅格分辨率(度)",
gt=0,
le=0.1
)
duration: int = Field(
12,
alias="duration",
description="持续时间(小时),可选12或24",
ge=12,
le=24
)
class Config:
populate_by_name = True # 允许同时使用字段名和别名
class StationData(BaseModel):
"""站点数据模型"""
lon: float
lat: float
rainfall: float
class GridMetadata(BaseModel):
"""栅格元数据"""
start_time: str
end_time: str
district_id: int
resolution: float
station_count: int
grid_size: List[int]
bounds: dict
class RainfallGridResponse(BaseModel):
"""降雨栅格响应模型 - 符合前端 ApiResponse 结构"""
code: int = Field(200, description="状态码")
message: str = Field(..., description="响应消息")
data: Optional[dict] = Field(None, description="响应数据")
class StationsResponse(BaseModel):
"""站点数据响应模型 - 符合前端 ApiResponse 结构"""
code: int = Field(200, description="状态码")
message: str = Field(..., description="响应消息")
data: List[StationData] = Field(default_factory=list, description="站点数据列表")
-605
View File
@@ -1,605 +0,0 @@
"""
降雨数据Service - 业务逻辑层
"""
import numpy as np
from scipy.spatial import Delaunay, ConvexHull
from typing import List, Dict, Any, Tuple, Optional
from datetime import datetime
from app.repositories.rainfall_repository import RainfallRepository
from app.utils.logger import setup_logging
logger = setup_logging()
class InterpolationService:
"""插值服务类"""
@staticmethod
def _create_buffer_points(
points_array: np.ndarray
) -> np.ndarray:
"""
创建缓冲点:在原始站点外围生成虚拟点以扩展插值区域
Args:
points_array: 原始站点坐标数组
Returns:
缓冲点坐标数组
"""
# 计算站点分布的中心
center = np.mean(points_array, axis=0)
# 计算站点到中心的最大距离
distances_from_center = np.sqrt(np.sum((points_array - center) ** 2, axis=1))
np.max(distances_from_center)
# 在站点外围生成缓冲点(沿着各个方向扩展)
buffer_points = []
num_angles = 360 # 每隔1度生成一个缓冲点
for angle_deg in range(0, 360, 360 // num_angles):
angle_rad = np.radians(angle_deg)
# 在凸包边界外扩展
for scale in [1.05, 1.1, 1.15]:
# 找到该方向上最远的站点
direction = np.array([np.cos(angle_rad), np.sin(angle_rad)])
projections = points_array @ direction
max_idx = np.argmax(projections)
# 在该方向上扩展
base_point = points_array[max_idx]
buffer_point = center + (base_point - center) * scale
buffer_points.append(buffer_point)
return np.array(buffer_points)
@staticmethod
def gaussian_smoothing(
grid_data: np.ndarray,
sigma: float = 1.5
) -> np.ndarray:
"""
高斯平滑滤波,减少边缘突变
Args:
grid_data: 栅格数据
sigma: 高斯核标准差
Returns:
平滑后的栅格数据
"""
from scipy.ndimage import gaussian_filter
# 只对有效数据进行平滑
valid_mask = ~np.isnan(grid_data)
if not np.any(valid_mask):
return grid_data
# 填充NaN值以便平滑
filled_data = grid_data.copy()
mean_val = np.nanmean(grid_data)
filled_data[~valid_mask] = mean_val
# 应用高斯滤波
smoothed = gaussian_filter(filled_data, sigma=sigma)
# 恢复原始NaN区域
result = np.where(valid_mask, smoothed, np.nan)
return result
@staticmethod
def calculate_adaptive_max_distance(
points_array: np.ndarray,
base_distance: float = 0.3,
min_distance: float = 0.15,
max_distance: float = 0.5
) -> float:
"""
根据站点密度自适应计算最大影响距离
Args:
points_array: 站点坐标数组
base_distance: 基础距离
min_distance: 最小距离
max_distance: 最大距离
Returns:
自适应的最大影响距离
"""
if len(points_array) < 3:
return base_distance
# 计算站点间的平均距离
from scipy.spatial import distance_matrix
dist_matrix = distance_matrix(points_array, points_array)
# 排除对角线(自身距离为0
np.fill_diagonal(dist_matrix, np.inf)
avg_distance = np.mean(np.min(dist_matrix, axis=1))
# 根据平均距离调整max_distance
adaptive_distance = avg_distance * 3 # 约3倍平均站点间距
# 限制在合理范围内
return np.clip(adaptive_distance, min_distance, max_distance)
@staticmethod
def inverse_distance_weighting(
points: List[Tuple[float, float]],
values: List[float],
grid_lon: np.ndarray,
grid_lat: np.ndarray,
power: float = 2.0,
max_distance: float = None,
use_adaptive_distance: bool = True,
apply_smoothing: bool = True,
smoothing_sigma: float = 1.0
) -> np.ndarray:
"""
反距离权重插值 (IDW) - 优化版本
改进:
1. 高斯核衰减替代简单幂律
2. 自适应距离阈值
3. 边缘渐变处理
4. 高斯平滑减少突变
Args:
points: 已知点坐标 [(lon, lat), ...]
values: 已知点的值 [rainfall, ...]
grid_lon: 网格经度数组
grid_lat: 网格纬度数组
power: 距离幂次(基础值)
max_distance: 最大影响距离(度),None则自适应计算
use_adaptive_distance: 是否使用自适应距离
apply_smoothing: 是否应用平滑
smoothing_sigma: 平滑强度
Returns:
插值后的栅格数据,无效区域为 NaN
"""
points_array = np.array(points)
values_array = np.array(values)
# 创建网格
lon_grid, lat_grid = np.meshgrid(grid_lon, grid_lat)
result = np.full_like(lon_grid, np.nan)
# 自适应计算最大距离
if use_adaptive_distance or max_distance is None:
actual_max_distance = InterpolationService.calculate_adaptive_max_distance(
points_array
)
if max_distance is not None:
actual_max_distance = min(actual_max_distance, max_distance)
else:
actual_max_distance = max_distance
logger.info(f"使用最大影响距离: {actual_max_distance:.3f}")
# 计算站点的凸包(带边缘缓冲)
hull_mask = None
confidence_mask = None # 置信度掩码
if len(points_array) >= 3:
try:
# 创建缓冲站点:在原始站点外围添加虚拟点
buffer_points = InterpolationService._create_buffer_points(
points_array
)
# 合并原始站点和缓冲站点
all_points = np.vstack([points_array, buffer_points])
# 计算凸包
hull = ConvexHull(all_points)
hull_points = all_points[hull.vertices]
tri = Delaunay(hull_points)
# 向量化判断所有网格点是否在凸包内
grid_points = np.column_stack([lon_grid.ravel(), lat_grid.ravel()])
hull_indices = tri.find_simplex(grid_points)
hull_mask = hull_indices >= 0
hull_mask = hull_mask.reshape(lon_grid.shape)
# 计算置信度:基于到最近站点的距离
# 在凸包内但远离站点的区域降低置信度
from scipy.spatial import distance_matrix
grid_valid = grid_points[hull_mask.ravel()]
if len(grid_valid) > 0:
dist_to_stations = distance_matrix(grid_valid, points_array)
min_distances = np.min(dist_to_stations, axis=1)
# 创建置信度掩码(距离越远,置信度越低)
confidence = np.ones(len(grid_points))
confidence[hull_mask.ravel()] = np.exp(-min_distances / actual_max_distance)
confidence_mask = confidence.reshape(lon_grid.shape)
else:
confidence_mask = np.ones_like(lon_grid)
except Exception as e:
logger.warning(f"凸包计算失败: {e},使用全区域插值")
hull_mask = np.ones_like(lon_grid, dtype=bool)
confidence_mask = np.ones_like(lon_grid)
else:
hull_mask = np.ones_like(lon_grid, dtype=bool)
confidence_mask = np.ones_like(lon_grid)
# 向量化计算所有网格点到所有站点的距离
lon_diff = lon_grid[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 0]
lat_diff = lat_grid[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 1]
distances = np.sqrt(lon_diff**2 + lat_diff**2)
# 过滤超出最大距离的站点
valid_mask = distances <= actual_max_distance
# 对于每个网格点,检查是否有有效站点
has_valid_stations = np.any(valid_mask, axis=2)
# 合并凸包掩码和有效站点掩码
final_mask = hull_mask & has_valid_stations
# 避免除零
distances = np.where(valid_mask, distances, np.inf)
distances = np.maximum(distances, 1e-10)
# 优化的权重计算:结合幂律和高斯衰减
# 近处使用幂律,远处使用高斯衰减使过渡更平滑
power_weights = 1.0 / (distances ** power)
gaussian_weights = np.exp(-0.5 * (distances / (actual_max_distance * 0.5)) ** 2)
# 混合权重:距离越远,高斯权重占比越大
distance_ratio = distances / actual_max_distance
mix_factor = np.clip(distance_ratio, 0, 1)
weights = (1 - mix_factor) * power_weights + mix_factor * gaussian_weights
weights = np.where(valid_mask, weights, 0)
# 加权平均
weighted_sum = np.sum(weights * values_array[np.newaxis, np.newaxis, :], axis=2)
weight_total = np.sum(weights, axis=2)
# 计算基础插值结果(使用 errstate 忽略预期的除零警告,np.where 已安全过滤)
with np.errstate(divide='ignore', invalid='ignore'):
result = np.where(
final_mask & (weight_total > 0),
weighted_sum / weight_total,
np.nan
)
# 应用置信度调整:边缘区域向邻近值渐变
if confidence_mask is not None:
# 计算全局平均降雨量作为边缘区域的基准
valid_rainfall = result[final_mask]
if len(valid_rainfall) > 0:
mean_rainfall = np.mean(valid_rainfall)
# 边缘区域向平均值渐变
result = np.where(
final_mask,
result,
np.nan
)
# 根据置信度调整结果,低置信度区域向均值靠拢
adjusted_result = result * confidence_mask + mean_rainfall * (1 - confidence_mask)
result = np.where(final_mask, adjusted_result, np.nan)
# 应用高斯平滑减少边缘突变
if apply_smoothing:
result = InterpolationService.gaussian_smoothing(result, sigma=smoothing_sigma)
return result
@staticmethod
def get_rainfall_color(rainfall: float, duration: int = 12) -> str:
"""
根据降雨量获取颜色(按照国标)
Args:
rainfall: 降雨量(mm)
duration: 持续时间(12或24小时)
Returns:
颜色字符串 "rgba(r,g,b,a)"
"""
# 国标降雨等级颜色映射
if rainfall < 0.1:
return "rgba(200,200,200,0)" # 透明 - 微量降雨(零星小雨)
elif rainfall < 5 if duration == 12 else 9.9:
return "rgba(0,0,255,0.4)" # 浅蓝 - 小雨
elif rainfall < 15 if duration == 12 else 25:
return "rgba(0,255,255,0.5)" # 青色 - 中雨
elif rainfall < 30 if duration == 12 else 50:
return "rgba(0,255,0,0.6)" # 绿色 - 大雨
elif rainfall < 70 if duration == 12 else 100:
return "rgba(255,255,0,0.7)" # 黄色 - 暴雨
elif rainfall < 140 if duration == 12 else 250:
return "rgba(255,165,0,0.8)" # 橙色 - 大暴雨
else:
return "rgba(255,0,0,0.9)" # 红色 - 特大暴雨
class GeoJSONService:
"""GeoJSON生成服务"""
@staticmethod
def create_feature_collection(
grid_metadata: Dict[str, Any],
rainfall_array: np.ndarray,
grid_lon: np.ndarray,
grid_lat: np.ndarray,
duration: int = 12
) -> Dict[str, Any]:
"""
创建GeoJSON FeatureCollection用于Cesium渲染
Args:
grid_metadata: 栅格元数据
rainfall_array: 降雨量数组
grid_lon: 经度网格
grid_lat: 纬度网格
duration: 持续时间(12或24小时)
Returns:
GeoJSON格式的FeatureCollection
"""
features = []
# 将栅格数据转换为矩形要素
for i in range(len(grid_lat) - 1):
for j in range(len(grid_lon) - 1):
rainfall_value = float(rainfall_array[i, j])
# 跳过无数据的区域
if np.isnan(rainfall_value) or rainfall_value < 0:
continue
# 创建矩形多边形
lon_min = float(grid_lon[j])
lon_max = float(grid_lon[j + 1])
lat_min = float(grid_lat[i])
lat_max = float(grid_lat[i + 1])
feature = {
"type": "Feature",
"geometry": {
"type": "Polygon",
"coordinates": [[
[lon_min, lat_min],
[lon_max, lat_min],
[lon_max, lat_max],
[lon_min, lat_max],
[lon_min, lat_min]
]]
},
"properties": {
"rainfall": round(rainfall_value, 2),
"level": RainfallService._get_rainfall_level(rainfall_value, duration),
"color": InterpolationService.get_rainfall_color(rainfall_value, duration)
}
}
features.append(feature)
return {
"type": "FeatureCollection",
"features": features,
"metadata": {
"resolution": grid_metadata['resolution'],
"grid_size": [len(grid_lon), len(grid_lat)],
"bounds": {
"min_lon": float(grid_lon.min()),
"max_lon": float(grid_lon.max()),
"min_lat": float(grid_lat.min()),
"max_lat": float(grid_lat.max())
}
}
}
class RainfallService:
"""降雨数据业务服务类"""
def __init__(self):
self.repository = RainfallRepository()
self.interpolation_service = InterpolationService()
self.geojson_service = GeoJSONService()
def get_stations_data(
self,
query_time: datetime,
duration: int = 12
) -> List[Dict[str, Any]]:
"""
获取站点降雨数据
Args:
query_time: 查询时间(自动查询前12小时或24小时数据)
duration: 持续时间(12或24小时)
Returns:
站点数据列表
"""
return self.repository.query_stations_rainfall(query_time, duration)
def generate_rainfall_grid(
self,
query_time: datetime,
resolution: float = 0.01,
duration: int = 12
) -> Dict[str, Any]:
"""
生成降雨栅格数据
Args:
query_time: 查询时间(自动查询前12小时或24小时数据)
resolution: 栅格分辨率
duration: 持续时间(12或24小时)
Returns:
GeoJSON格式的栅格数据
"""
logger.info(f"查询降雨数据: {query_time}, 持续时间: {duration}小时")
# 查询站点数据(自动查询前12小时或24小时数据)
stations_data = self.get_stations_data(query_time, duration)
if not stations_data:
return None
# 提取站点坐标和降雨量(过滤空值)
valid_stations = [row for row in stations_data if row['rainfall'] is not None]
if not valid_stations:
logger.warning("所有站点的降雨量数据均为空")
return None
points = [(row['lon'], row['lat']) for row in valid_stations]
values = [float(row['rainfall']) for row in valid_stations]
# 确定栅格范围
lon_min, lon_max = 107, 110
lat_min, lat_max = 33, 35
# 创建栅格网格
num_lon = int((lon_max - lon_min) / resolution) + 1
num_lat = int((lat_max - lat_min) / resolution) + 1
grid_lon = np.linspace(lon_min, lon_max, num_lon)
grid_lat = np.linspace(lat_min, lat_max, num_lat)
logger.info(f"生成栅格: {num_lon}x{num_lat}, 分辨率: {resolution}")
# 执行IDW插值(优化版本:自适应距离、混合权重、平滑处理)
rainfall_grid = self.interpolation_service.inverse_distance_weighting(
points=points,
values=values,
grid_lon=grid_lon,
grid_lat=grid_lat,
power=2.0,
max_distance=0.35, # 最大影响距离0.35度(约35公里)
use_adaptive_distance=True, # 启用自适应距离
apply_smoothing=True, # 启用平滑处理
smoothing_sigma=1.2 # 平滑强度
)
# 创建栅格元数据
grid_metadata = {
"query_time": query_time.isoformat(),
"resolution": resolution,
"station_count": len(stations_data),
"grid_size": [num_lon, num_lat]
}
# 转换为GeoJSON格式
geojson_data = self.geojson_service.create_feature_collection(
grid_metadata, rainfall_grid, grid_lon, grid_lat, duration
)
logger.info("降雨栅格数据生成成功")
return geojson_data
def get_rainfall_at_point(
self,
longitude: float,
latitude: float,
query_time: datetime,
duration: int = 12
) -> Optional[Dict[str, Any]]:
"""
查询指定点位的降雨量(使用IDW插值)
Args:
longitude: 经度
latitude: 纬度
query_time: 查询时间(自动查询前12小时或24小时数据)
duration: 持续时间(12或24小时)
Returns:
点位降雨量信息
"""
# 获取站点数据(自动查询前12小时或24小时数据)
stations_data = self.get_stations_data(query_time, duration)
if not stations_data:
return None
# 提取站点坐标和降雨量
points = [(row['lon'], row['lat']) for row in stations_data]
values = [float(row['rainfall']) for row in stations_data]
# 使用IDW插值计算该点的降雨量
target_point = np.array([[longitude, latitude]])
points_array = np.array(points)
# 计算距离
distances = np.sqrt(
(points_array[:, 0] - longitude) ** 2 +
(points_array[:, 1] - latitude) ** 2
)
# 避免除零
min_dist = 1e-10
distances = np.maximum(distances, min_dist)
# IDW公式
power = 2.0
weights = 1.0 / (distances ** power)
rainfall_value = np.sum(weights * values) / np.sum(weights)
# 返回结果
return {
"longitude": longitude,
"latitude": latitude,
"rainfall": round(float(rainfall_value), 2),
"level": self._get_rainfall_level(rainfall_value, duration),
"color": InterpolationService.get_rainfall_color(rainfall_value, duration),
"station_count": len(stations_data),
"query_time": query_time.isoformat(),
"duration": duration
}
@staticmethod
def _get_rainfall_level(rainfall: float, duration: int = 12) -> str:
"""
获取降雨等级(按照国标)
Args:
rainfall: 降雨量(mm)
duration: 持续时间(12或24小时)
Returns:
降雨等级字符串
"""
if duration == 12:
# 12小时降雨等级标准
if rainfall < 0.1:
return "微量降雨"
elif rainfall < 5.0:
return "小雨"
elif rainfall < 15.0:
return "中雨"
elif rainfall < 30.0:
return "大雨"
elif rainfall < 70.0:
return "暴雨"
elif rainfall < 140.0:
return "大暴雨"
else:
return "特大暴雨"
else: # 24小时
# 24小时降雨等级标准
if rainfall < 0.1:
return "微量降雨"
elif rainfall < 10.0:
return "小雨"
elif rainfall < 25.0:
return "中雨"
elif rainfall < 50.0:
return "大雨"
elif rainfall < 100.0:
return "暴雨"
elif rainfall < 250.0:
return "大暴雨"
else:
return "特大暴雨"
+66 -81
View File
@@ -1,104 +1,89 @@
"""
日志配置工具
日志工具
支持按天分割、自动清理过期日志
"""
import logging
import sys
from pathlib import Path
from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler
from datetime import datetime, timedelta
class LoggerConfig:
"""日志配置类"""
def __init__(self, log_dir: str = "logs", log_level: str = "INFO"):
self.log_dir = Path(log_dir)
self.log_level = getattr(logging, log_level.upper(), logging.INFO)
self.ensure_log_directory()
def ensure_log_directory(self):
"""确保日志目录存在"""
if not self.log_dir.exists():
self.log_dir.mkdir(parents=True, exist_ok=True)
def setup_logger(
self,
name: str = "app",
log_file: str = None,
max_bytes: int = 10 * 1024 * 1024, # 10MB
backup_count: int = 5,
use_timed_rotation: bool = False
) -> logging.Logger:
"""配置日志记录器
class LoggerManager:
"""日志管理器"""
_loggers = {}
@classmethod
def get_logger(cls, name: str = "algorithm", log_dir: str = "logs") -> logging.Logger:
"""
获取日志记录器
Args:
name: 日志记录器名称
log_file: 日志文件名(None则使用时间戳)
max_bytes: 单个日志文件最大大小
backup_count: 保留的备份文件数量
use_timed_rotation: 是否使用按时间轮转
name: 日志名称
log_dir: 日志目录
Returns:
配置好的Logger实例
logging.Logger 实例
"""
if name in cls._loggers:
return cls._loggers[name]
# 创建日志目录
log_path = Path(log_dir)
log_path.mkdir(parents=True, exist_ok=True)
# 创建 logger
logger = logging.getLogger(name)
logger.setLevel(self.log_level)
# 避免重复添加handler
logger.setLevel(logging.DEBUG)
# 避免重复添加 handler
if logger.handlers:
cls._loggers[name] = logger
return logger
# 创建formatter
# 日志格式
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
'%(asctime)s [%(threadName)s] %(levelname)-5s %(name)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# 控制台handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(self.log_level)
# 控制台 Handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 文件handler
if log_file is None:
timestamp = datetime.now().strftime("%Y%m%d")
log_file = f"app_{timestamp}.log"
log_path = self.log_dir / log_file
if use_timed_rotation:
# 按时间轮转(每天)
file_handler = TimedRotatingFileHandler(
log_path,
when='midnight',
interval=1,
backupCount=backup_count,
encoding='utf-8'
)
else:
# 按大小轮转
file_handler = RotatingFileHandler(
log_path,
maxBytes=max_bytes,
backupCount=backup_count,
encoding='utf-8'
)
file_handler.setLevel(self.log_level)
# 文件 Handler - 按天分割
log_file = log_path / f"{name}.log"
file_handler = TimedRotatingFileHandler(
filename=str(log_file),
when='midnight',
interval=1,
backupCount=7,
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
# 设置日志文件命名格式
file_handler.suffix = "%Y-%m-%d.log"
logger.addHandler(file_handler)
cls._loggers[name] = logger
return logger
# 便捷函数
def get_logger(name: str = "algorithm", log_dir: str = "logs") -> logging.Logger:
"""
获取日志记录器的便捷函数
@staticmethod
def get_logger(name: str = "app") -> logging.Logger:
"""获取logger实例"""
return logging.getLogger(name)
# 全局日志配置
def setup_logging(log_dir: str = "logs", log_level: str = "INFO") -> logging.Logger:
"""设置全局日志"""
logger_config = LoggerConfig(log_dir, log_level)
return logger_config.setup_logger()
Args:
name: 日志名称
log_dir: 日志目录
Returns:
logging.Logger 实例
"""
return LoggerManager.get_logger(name, log_dir)
-137
View File
@@ -1,137 +0,0 @@
"""
平台检测工具 - 兼容Windows和Linux
"""
import sys
import os
import platform
from pathlib import Path
from typing import Tuple
class PlatformDetector:
"""平台检测器"""
def __init__(self):
self.os_name = os.name
self.platform_system = platform.system()
self.platform_release = platform.release()
self.platform_version = platform.version()
self.python_version = sys.version_info
self.is_windows = self.platform_system == "Windows"
self.is_linux = self.platform_system == "Linux"
self.is_macos = self.platform_system == "Darwin"
def get_platform_info(self) -> dict:
"""获取平台信息"""
return {
"os_name": self.os_name,
"platform_system": self.platform_system,
"platform_release": self.platform_release,
"platform_version": self.platform_version,
"python_version": f"{self.python_version.major}.{self.python_version.minor}.{self.python_version.micro}",
"is_windows": self.is_windows,
"is_linux": self.is_linux,
"is_macos": self.is_macos,
}
def check_python_version(self, min_version: Tuple[int, int] = (3, 13)) -> bool:
"""检查Python版本是否满足要求
Args:
min_version: 最低版本要求 (major, minor)
Returns:
是否满足版本要求
"""
current = (self.python_version.major, self.python_version.minor)
return current >= min_version
def get_path_separator(self) -> str:
"""获取路径分隔符"""
return "\\" if self.is_windows else "/"
def normalize_path(self, path: str) -> str:
"""标准化路径"""
return Path(path).as_posix()
def get_project_root(self) -> Path:
"""获取项目根目录"""
return Path(__file__).parent.parent.parent
def ensure_directory(self, path: Path, create: bool = True) -> bool:
"""确保目录存在
Args:
path: 目录路径
create: 是否自动创建
Returns:
目录是否存在
"""
if path.exists():
return True
if create:
try:
path.mkdir(parents=True, exist_ok=True)
return True
except Exception as e:
print(f"创建目录失败 {path}: {e}")
return False
return False
def get_env_file_path(self) -> Path:
"""获取环境配置文件路径"""
return self.get_project_root() / ".env"
def format_command_for_platform(self, command: str) -> str:
"""根据平台格式化命令
Args:
command: 原始命令
Returns:
适合当前平台的命令
"""
if self.is_windows:
# Windows特定命令转换
if command.startswith("ls "):
return command.replace("ls ", "dir ")
elif command.startswith("cat "):
return command.replace("cat ", "type ")
elif command.startswith("rm "):
return command.replace("rm ", "del ")
elif command.startswith("cp "):
return command.replace("cp ", "copy ")
elif command.startswith("mv "):
return command.replace("mv ", "move ")
else:
# Linux/Mac特定命令转换
if command.startswith("dir "):
return command.replace("dir ", "ls ")
elif command.startswith("type "):
return command.replace("type ", "cat ")
elif command.startswith("del "):
return command.replace("del ", "rm ")
elif command.startswith("copy "):
return command.replace("copy ", "cp ")
elif command.startswith("move "):
return command.replace("move ", "mv ")
return command
def print_platform_banner(self):
"""打印平台信息横幅"""
info = self.get_platform_info()
print("=" * 60)
print(" 系统信息")
print("=" * 60)
print(f" 操作系统: {info['platform_system']} {info['platform_release']}")
print(f" Python版本: {info['python_version']}")
print(f" 项目根目录: {self.get_project_root()}")
print("=" * 60)
# 全局平台检测器实例
platform_detector = PlatformDetector()
+11
View File
@@ -0,0 +1,11 @@
from dynaconf import Dynaconf
settings = Dynaconf(
# 配置文件
settings_files=['settings.toml', '.secrets.toml'],
# 环境变量
environments=True,
# 环境切换变量
env_switcher="ENV_FOR_DYNACONF",
)
+1 -8
View File
@@ -1,8 +1 @@
fastapi==0.136.1
uvicorn==0.46.0
sqlalchemy==2.0.49
psycopg2-binary==2.9.12
pydantic==2.13.3
pydantic-settings==2.14.0
numpy==2.4.4
scipy==1.17.1
dynaconf == 3.2.13
-74
View File
@@ -1,74 +0,0 @@
@echo off
REM Windows startup script - Development environment (background mode)
echo ========================================
echo Starting Development Environment
echo ========================================
REM Change to project root directory
cd /d "%~dp0.."
REM Verify we are in the correct directory
if not exist "start.py" (
echo Error: Cannot find start.py in current directory: %CD%
echo Please ensure you are running this script from the project directory
pause
exit /b 1
)
echo Current directory: %CD%
REM Check and create virtual environment if not exists
if not exist ".venv\Scripts\activate.bat" (
echo Virtual environment not found, creating...
python -m venv .venv
if errorlevel 1 (
echo Error: Failed to create virtual environment
echo Please ensure Python is installed and accessible
pause
exit /b 1
)
echo Virtual environment created successfully
)
REM Activate virtual environment
call .venv\Scripts\activate.bat
if errorlevel 1 (
echo Error: Failed to activate virtual environment
pause
exit /b 1
)
echo Virtual environment activated
REM Upgrade pip and install dependencies
echo Checking dependencies...
.venv\Scripts\python.exe -m pip install --upgrade pip -q
.venv\Scripts\python.exe -m pip install -r requirements.txt -q
if errorlevel 1 (
echo Error: Failed to install dependencies
pause
exit /b 1
)
echo Dependencies installed successfully
REM Create logs directory if not exists
if not exist "logs" mkdir logs
REM Start application in background using start command
REM Create a temporary startup script with clean environment
set TEMP_START_SCRIPT=%TEMP%\start_xian_app_%RANDOM%.bat
(
echo @echo off
echo cd /d %CD%
echo set ENVIRONMENT=development
echo .venv\Scripts\python.exe start.py
) > "%TEMP_START_SCRIPT%"
start "Xian Algorithm Dev" cmd /k "title Xian Algorithm Dev && call "%TEMP_START_SCRIPT%""
echo.
echo Application started in background
echo To view logs, check: logs\app_*.log
echo To stop the application, run: scripts\stop.bat
echo ========================================
REM Keep the window open briefly to show any immediate errors
timeout /t 3 /nobreak >nul
-52
View File
@@ -1,52 +0,0 @@
#!/bin/bash
# Linux/Mac startup script - Development environment (background mode)
echo "========================================"
echo " Starting Development Environment"
echo "========================================"
# Change to project root directory
cd "$(dirname "$0")/.." || exit
# Check and create virtual environment if not exists
if [ ! -f ".venv/bin/activate" ]; then
echo "Virtual environment not found, creating..."
python3 -m venv .venv
if [ $? -ne 0 ]; then
echo "Error: Failed to create virtual environment"
echo "Please ensure Python 3 is installed and accessible"
exit 1
fi
echo "Virtual environment created successfully"
fi
# Activate virtual environment
source .venv/bin/activate
echo "Virtual environment activated"
# Upgrade pip and install dependencies
echo "Checking dependencies..."
pip install --upgrade pip -q
pip install -r requirements.txt -q
if [ $? -ne 0 ]; then
echo "Error: Failed to install dependencies"
exit 1
fi
echo "Dependencies installed successfully"
# Set environment variable
export ENVIRONMENT=development
# Create logs directory if not exists
mkdir -p logs
# Start application in background
nohup python start.py > logs/app_dev.log 2>&1 &
APP_PID=$!
echo $APP_PID > scripts/app_dev.pid
echo ""
echo "Application started in background (PID: $APP_PID)"
echo "To view logs: tail -f logs/app_dev.log"
echo "To stop the application, run: bash scripts/stop.sh"
echo "========================================"
-52
View File
@@ -1,52 +0,0 @@
@echo off
REM Windows startup script - Production environment (background mode)
echo ========================================
echo Starting Production Environment
echo ========================================
REM Change to project root directory
cd /d "%~dp0.."
REM Check and create virtual environment if not exists
if not exist ".venv\Scripts\activate.bat" (
echo Virtual environment not found, creating...
python -m venv .venv
if errorlevel 1 (
echo Error: Failed to create virtual environment
echo Please ensure Python is installed and accessible
pause
exit /b 1
)
echo Virtual environment created successfully
)
REM Activate virtual environment
call .venv\Scripts\activate.bat
echo Virtual environment activated
REM Upgrade pip and install dependencies
echo Checking dependencies...
.venv\Scripts\python.exe -m pip install --upgrade pip -q
.venv\Scripts\python.exe -m pip install -r requirements.txt -q
if errorlevel 1 (
echo Error: Failed to install dependencies
pause
exit /b 1
)
echo Dependencies installed successfully
REM Set environment variable
set ENVIRONMENT=production
REM Create logs directory if not exists
if not exist "logs" mkdir logs
REM Start application in background using start command
start "Xian Algorithm Prod" cmd /c "title Xian Algorithm Prod && .venv\Scripts\python.exe start.py"
echo.
echo Application started in background
echo To view logs, check: logs\app_*.log
echo To stop the application, run: scripts\stop.bat
echo ========================================
-52
View File
@@ -1,52 +0,0 @@
#!/bin/bash
# Linux/Mac startup script - Production environment (background mode)
echo "========================================"
echo " Starting Production Environment"
echo "========================================"
# Change to project root directory
cd "$(dirname "$0")/.." || exit
# Check and create virtual environment if not exists
if [ ! -f ".venv/bin/activate" ]; then
echo "Virtual environment not found, creating..."
python3 -m venv .venv
if [ $? -ne 0 ]; then
echo "Error: Failed to create virtual environment"
echo "Please ensure Python 3 is installed and accessible"
exit 1
fi
echo "Virtual environment created successfully"
fi
# Activate virtual environment
source .venv/bin/activate
echo "Virtual environment activated"
# Upgrade pip and install dependencies
echo "Checking dependencies..."
pip install --upgrade pip -q
pip install -r requirements.txt -q
if [ $? -ne 0 ]; then
echo "Error: Failed to install dependencies"
exit 1
fi
echo "Dependencies installed successfully"
# Set environment variable
export ENVIRONMENT=production
# Create logs directory if not exists
mkdir -p logs
# Start application in background
nohup python start.py > logs/app_prod.log 2>&1 &
APP_PID=$!
echo $APP_PID > scripts/app_prod.pid
echo ""
echo "Application started in background (PID: $APP_PID)"
echo "To view logs: tail -f logs/app_prod.log"
echo "To stop the application, run: bash scripts/stop.sh"
echo "========================================"
-29
View File
@@ -1,29 +0,0 @@
@echo off
REM Windows stop script - Stop the application
echo ========================================
echo Stopping Application
echo ========================================
REM Change to project root directory
cd /d "%~dp0.."
REM Find and kill python processes running start.py
echo Searching for running application...
tasklist /FI "IMAGENAME eq python.exe" /FO CSV /NH 2>nul | findstr /I "python.exe" >nul
if %errorlevel% equ 0 (
echo Found running Python processes, stopping...
taskkill /F /FI "WINDOWTITLE eq Xian Algorithm Dev" 2>nul
taskkill /F /FI "WINDOWTITLE eq Xian Algorithm Prod" 2>nul
REM If title-based kill didn't work, try killing all python.exe running start.py
for /f "tokens=2 delims=," %%a in ('tasklist /FI "IMAGENAME eq python.exe" /FO CSV /NH 2^>nul') do (
set "PID=%%~a"
wmic process where "ProcessId=!PID! and CommandLine like '%%start.py%%'" delete >nul 2>&1
)
echo Application stopped successfully
) else (
echo No running application found
)
echo ========================================
-67
View File
@@ -1,67 +0,0 @@
#!/bin/bash
# Linux/Mac stop script - Stop the application
echo "========================================"
echo " Stopping Application"
echo "========================================"
# Change to project root directory
cd "$(dirname "$0")/.." || exit
# Function to stop a process by PID file
stop_process() {
local pid_file=$1
local env_name=$2
if [ -f "$pid_file" ]; then
local pid=$(cat "$pid_file")
if kill -0 "$pid" 2>/dev/null; then
echo "Stopping $env_name (PID: $pid)..."
kill "$pid"
# Wait for process to stop
local count=0
while kill -0 "$pid" 2>/dev/null && [ $count -lt 10 ]; do
sleep 1
count=$((count + 1))
done
# Force kill if still running
if kill -0 "$pid" 2>/dev/null; then
echo "Force stopping $env_name..."
kill -9 "$pid"
fi
rm -f "$pid_file"
echo "$env_name stopped"
else
echo "$env_name not running (stale PID file removed)"
rm -f "$pid_file"
fi
else
echo "No $env_name PID file found"
fi
}
# Stop development environment
stop_process "scripts/app_dev.pid" "Development environment"
# Stop production environment
stop_process "scripts/app_prod.pid" "Production environment"
# Also try to find any remaining python processes running start.py
REMAINING_PIDS=$(ps aux | grep "[p]ython.*start.py" | awk '{print $2}')
if [ -n "$REMAINING_PIDS" ]; then
echo "Found additional Python processes, stopping..."
echo "$REMAINING_PIDS" | xargs kill 2>/dev/null
sleep 2
# Force kill if still running
REMAINING_PIDS=$(ps aux | grep "[p]ython.*start.py" | awk '{print $2}')
if [ -n "$REMAINING_PIDS" ]; then
echo "$REMAINING_PIDS" | xargs kill -9 2>/dev/null
fi
fi
echo "========================================"
echo "All applications stopped"
+34
View File
@@ -0,0 +1,34 @@
# 公共配置
[default]
APP_NAME = "西安项目算法服务"
LOG_DIR = "logs"
# 开发环境
[development]
DEBUG = true
# 数据库
DB_HOST = "47.92.216.173"
DB_PORT = 7654
DB_USER = "postgres"
DB_PASSWORD = "zhangsan"
DB_NAME = "xian_new"
# FastAPI配置
API_HOST = "127.0.0.1"
API_PORT = 8082
# 日志配置
LOG_LEVEL = "DEBUG"
# 生产环境
[production]
DEBUG = false
# 数据库配置
DB_HOST = "10.22.245.138"
DB_PORT = 54321
DB_USER = "zaihailian"
DB_PASSWORD = "XAYJ@gis2603"
DB_NAME = "xianDC"
# FastAPI配置
API_HOST = "127.0.0.1"
API_PORT = 8081
# 日志配置
LOG_LEVEL = "WARNING"
+25 -207
View File
@@ -1,230 +1,48 @@
"""
项目启动脚本 - 支持多环境和跨平台
使用 Dynaconf 进行环境隔离配置
"""
import sys
import subprocess
import os
import platform
from pathlib import Path
from app.utils.logger import get_logger
# 添加项目根目录到Python路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
logger = get_logger()
def check_platform():
"""测平台信息"""
print("=" * 60)
print(" 系统信息检测")
print("=" * 60)
from app.utils.platform_utils import platform_detector
platform_detector.print_platform_banner()
return platform_detector
def check_environment():
"""查系统和Python版本"""
# 识别操作系统
os_name = platform.system()
print(f"当前操作系统: {os_name}")
if os_name not in ['Windows', 'Linux']:
print(f"警告: 未测试的操作系统 {os_name},可能存在问题")
def check_python_version():
"""检查Python版本是否为3.13或更高"""
print("\n" + "=" * 60)
print(" Python版本检查")
print("=" * 60)
from app.utils.platform_utils import platform_detector
current_version = sys.version_info
print(f"当前Python版本: {current_version.major}.{current_version.minor}.{current_version.micro}")
if not platform_detector.check_python_version((3, 13)):
print("\n❌ 错误: Python版本过低!")
print(f" 当前版本: {current_version.major}.{current_version.minor}.{current_version.micro}")
print(" 要求版本: 3.13 或更高")
print("\n请升级到Python 3.13或更高版本:")
print(" 下载地址: https://www.python.org/downloads/")
print("=" * 60)
sys.exit(1)
print(f"✅ Python版本检查通过: {current_version.major}.{current_version.minor}.{current_version.micro}")
print("=" * 60)
return True
# 检查Python版本
python_version = platform.python_version()
print(f"当前Python版本: {python_version}")
# 解析版本号
major, minor, *_ = map(int, python_version.split('.'))
def get_environment():
"""获取运行环境"""
print("\n" + "=" * 60)
print(" 环境配置")
print("=" * 60)
environment = os.getenv("ENVIRONMENT", "development")
print(f"当前环境: {environment}")
if environment not in ["development", "production"]:
print(f"⚠️ 警告: 未知环境 '{environment}',使用默认开发环境")
environment = "development"
print("=" * 60)
return environment
def install_dependencies():
"""检查并安装依赖包"""
print("\n" + "=" * 60)
print(" 依赖包检查")
print("=" * 60)
requirements_file = "requirements.txt"
if not os.path.exists(requirements_file):
print(f"\n❌ 错误: 找不到依赖文件 {requirements_file}")
sys.exit(1)
# 读取已安装的包
try:
result = subprocess.run(
[sys.executable, "-m", "pip", "list", "--format=freeze"],
capture_output=True,
text=True,
check=True
)
installed_packages = {line.split("==")[0].lower() for line in result.stdout.strip().split("\n") if line}
except subprocess.CalledProcessError as e:
print(f"\n❌ 检查已安装包失败: {e}")
sys.exit(1)
# 读取requirements.txt中的包
with open(requirements_file, "r", encoding="utf-8") as f:
required_packages = []
for line in f:
line = line.strip()
if line and not line.startswith("#"):
package_name = line.split("==")[0].lower()
required_packages.append(line)
# 检查缺失的包
missing_packages = []
for package_line in required_packages:
package_name = package_line.split("==")[0].lower()
if package_name not in installed_packages:
missing_packages.append(package_line)
if not missing_packages:
print("\n✅ 所有依赖包已安装,无需重复安装")
print("=" * 60)
if major == 3 and minor == 13:
print("✓ Python版本符合要求 (3.13)")
return True
print(f"\n发现 {len(missing_packages)} 个未安装的依赖包:")
for package in missing_packages:
print(f" - {package}")
print("\n正在安装依赖包...")
try:
subprocess.run(
[sys.executable, "-m", "pip", "install", "-r", requirements_file],
check=True
)
print("\n✅ 依赖包安装成功")
print("=" * 60)
return True
except subprocess.CalledProcessError as e:
print(f"\n❌ 依赖包安装失败: {e}")
print("=" * 60)
sys.exit(1)
def initialize_database():
"""初始化数据库连接"""
print("\n" + "=" * 60)
print(" 数据库初始化")
print("=" * 60)
try:
from app.core.database import db_manager
from app.config.settings import get_settings
settings = get_settings()
print(f"数据库地址: {settings.DB_HOST}:{settings.DB_PORT}")
print(f"数据库名称: {settings.DB_NAME}")
print(f"连接池大小: {settings.DB_POOL_SIZE}")
# 测试数据库连接
if db_manager.test_connection():
print("✅ 数据库连接成功")
else:
print("⚠️ 数据库连接失败,请检查配置")
return False
print("=" * 60)
return True
except Exception as e:
print(f"\n⚠️ 数据库初始化警告: {e}")
print(" 应用将继续启动,但数据库功能可能不可用")
print("=" * 60)
else:
print(f"✗ Python版本不符合要求!")
print(f" 当前版本: {python_version}")
print(f" 要求版本: 3.13.x")
print(f"\n请使用 Python 3.13 版本运行此项目")
print(f"下载地址: https://www.python.org/downloads/")
return False
def start_application(environment: str):
"""启动FastAPI应用"""
print("\n" + "=" * 60)
print(" 启动FastAPI应用")
print("=" * 60)
try:
from app.config.settings import get_settings
import uvicorn
settings = get_settings()
print(f"应用名称: {settings.APP_NAME}")
print(f"应用版本: {settings.APP_VERSION}")
print(f"运行环境: {settings.ENVIRONMENT.value}")
print(f"监听地址: {settings.API_HOST}:{settings.API_PORT}")
print(f"调试模式: {'开启' if settings.DEBUG else '关闭'}")
print(f"自动重载: {'开启' if hasattr(settings, 'RELOAD') and settings.RELOAD else '关闭'}")
print("\n🚀 应用启动中...\n")
print("=" * 60)
# 启动uvicorn服务器
uvicorn.run(
"app.main:app",
host=settings.API_HOST,
port=settings.API_PORT,
reload=settings.RELOAD if hasattr(settings, 'RELOAD') else settings.DEBUG,
log_level=settings.LOG_LEVEL.lower()
)
except KeyboardInterrupt:
print("\n\n应用已停止")
except Exception as e:
print(f"\n❌ 应用启动失败: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
def main():
"""主函数"""
print("\n" + "=" * 60)
print(" Xian Algorithm New - 应用启动程序")
print("=" * 60)
# 检测平台信息
check_platform()
# 检查Python版本
check_python_version()
# 获取环境配置
environment = get_environment()
# 检查并安装依赖
install_dependencies()
# 初始化数据库
initialize_database()
# 启动应用
start_application(environment)
check_environment()
if __name__ == "__main__":