初始化代码
This commit is contained in:
@@ -1,22 +0,0 @@
|
||||
# 开发环境配置
|
||||
ENVIRONMENT=development
|
||||
|
||||
# 数据库配置
|
||||
DB_HOST=47.92.216.173
|
||||
DB_PORT=7654
|
||||
DB_USER=postgres
|
||||
DB_PASSWORD=zhangsan
|
||||
DB_NAME=xian_new
|
||||
|
||||
# FastAPI配置
|
||||
API_HOST=127.0.0.1
|
||||
API_PORT=8082
|
||||
|
||||
# 应用配置
|
||||
APP_NAME=西安项目算法
|
||||
APP_VERSION=1.0.0
|
||||
DEBUG=True
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL=DEBUG
|
||||
LOG_DIR=logs
|
||||
@@ -1,22 +0,0 @@
|
||||
# 生产环境配置
|
||||
ENVIRONMENT=production
|
||||
|
||||
# 数据库配置
|
||||
DB_HOST=10.22.245.138
|
||||
DB_PORT=54321
|
||||
DB_USER=zaihailian
|
||||
DB_PASSWORD=XAYJ@gis2603
|
||||
DB_NAME=xianDC
|
||||
|
||||
# FastAPI配置
|
||||
API_HOST=127.0.0.1
|
||||
API_PORT=8081
|
||||
|
||||
# 应用配置
|
||||
APP_NAME=西安项目算法
|
||||
APP_VERSION=1.0.0
|
||||
DEBUG=False
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL=WARNING
|
||||
LOG_DIR=logs
|
||||
@@ -51,3 +51,6 @@ htmlcov/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# Ignore dynaconf secret files
|
||||
.secrets.*
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
"""
|
||||
降雨数据API Controller - 路由层
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.services.rainfall_service import RainfallService
|
||||
from app.schemas.rainfall import (
|
||||
RainfallGridRequest,
|
||||
RainfallGridResponse,
|
||||
StationsResponse
|
||||
)
|
||||
from app.utils.logger import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/rainfall",
|
||||
tags=["降雨数据"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
# 创建服务实例
|
||||
rainfall_service = RainfallService()
|
||||
|
||||
|
||||
@router.post("/grid", response_model=RainfallGridResponse, summary="获取降雨栅格数据")
|
||||
async def get_rainfall_grid(request: RainfallGridRequest):
|
||||
"""
|
||||
获取指定时间的降雨栅格数据
|
||||
|
||||
使用反距离权重插值(IDW)方法,将站点降雨数据插值为连续栅格,
|
||||
返回适合Cesium渲染的GeoJSON格式数据。
|
||||
|
||||
Args:
|
||||
request: 包含时间、分辨率和持续时间的请求
|
||||
|
||||
Returns:
|
||||
GeoJSON格式的栅格数据
|
||||
"""
|
||||
try:
|
||||
# 解析时间,如果未提供则使用当前时间
|
||||
now = datetime.now()
|
||||
query_time = datetime.fromisoformat(request.time) if request.time else now
|
||||
|
||||
# 验证duration参数
|
||||
if request.duration not in [12, 24]:
|
||||
raise ValueError("duration参数必须为12或24")
|
||||
|
||||
# 调用服务层生成栅格(自动查询前12小时或24小时数据)
|
||||
geojson_data = rainfall_service.generate_rainfall_grid(
|
||||
query_time=query_time,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration
|
||||
)
|
||||
|
||||
if not geojson_data:
|
||||
return RainfallGridResponse(
|
||||
code=404,
|
||||
message="未找到降雨数据",
|
||||
data=None
|
||||
)
|
||||
|
||||
return RainfallGridResponse(
|
||||
code=200,
|
||||
message="降雨栅格数据生成成功",
|
||||
data=geojson_data
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"时间格式错误: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"时间格式错误: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"生成降雨栅格失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"生成降雨栅格失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/stations", response_model=StationsResponse, summary="获取雨量站点数据")
|
||||
async def get_rainfall_stations(
|
||||
time: str = Query(..., description="查询时间 ISO格式(自动查询前12小时或24小时数据)"),
|
||||
duration: int = Query(12, description="持续时间(小时),可选12或24", ge=12, le=24)
|
||||
):
|
||||
"""
|
||||
获取指定时间的雨量站点原始数据
|
||||
|
||||
Args:
|
||||
time: 查询时间
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
站点列表,包含经纬度和降雨量
|
||||
"""
|
||||
try:
|
||||
query_time = datetime.fromisoformat(time)
|
||||
|
||||
# 调用服务层获取站点数据(自动查询前12小时或24小时数据)
|
||||
stations = rainfall_service.get_stations_data(
|
||||
query_time=query_time,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
return StationsResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=stations
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"时间格式错误: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"时间格式错误: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/point", summary="查询指定点位的降雨量")
|
||||
async def get_rainfall_at_point(
|
||||
longitude: float,
|
||||
latitude: float,
|
||||
time: Optional[str] = None,
|
||||
duration: int = Query(12, description="持续时间(小时),可选12或24", ge=12, le=24)
|
||||
):
|
||||
"""
|
||||
查询指定经纬度位置的降雨量
|
||||
|
||||
Args:
|
||||
longitude: 经度
|
||||
latitude: 纬度
|
||||
time: 查询时间(可选,默认当前时间,自动查询前12小时或24小时数据)
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
该点位的降雨量信息
|
||||
"""
|
||||
try:
|
||||
from app.services.rainfall_service import RainfallService
|
||||
|
||||
# 解析时间
|
||||
now = datetime.now()
|
||||
query_time = datetime.fromisoformat(time) if time else now
|
||||
|
||||
# 验证duration参数
|
||||
if duration not in [12, 24]:
|
||||
raise ValueError("duration参数必须为12或24")
|
||||
|
||||
# 调用服务层查询(自动查询前12小时或24小时数据)
|
||||
service = RainfallService()
|
||||
rainfall_info = service.get_rainfall_at_point(
|
||||
longitude=longitude,
|
||||
latitude=latitude,
|
||||
query_time=query_time,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
if not rainfall_info:
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "未找到该点位的降雨数据",
|
||||
"data": None
|
||||
}
|
||||
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "查询成功",
|
||||
"data": rainfall_info
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询点位降雨量失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,48 +0,0 @@
|
||||
"""
|
||||
基础配置类
|
||||
"""
|
||||
from pydantic_settings import BaseSettings
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EnvironmentEnum(str, Enum):
|
||||
"""环境枚举"""
|
||||
DEVELOPMENT = "development"
|
||||
PRODUCTION = "production"
|
||||
|
||||
|
||||
class BaseConfig(BaseSettings):
|
||||
"""基础配置类"""
|
||||
|
||||
# 应用基本信息
|
||||
APP_NAME: str = "西安项目算法"
|
||||
APP_VERSION: str = "1.0.0"
|
||||
ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.DEVELOPMENT
|
||||
|
||||
# 调试模式
|
||||
DEBUG: bool = True
|
||||
|
||||
# API配置
|
||||
API_HOST: str = "127.0.0.1" # 默认只监听本地
|
||||
API_PORT: int = 8000
|
||||
|
||||
# CORS配置(默认只允许localhost)
|
||||
CORS_ORIGINS: list = ["http://localhost", "http://127.0.0.1"]
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = "INFO"
|
||||
LOG_DIR: str = "logs"
|
||||
|
||||
class Config:
|
||||
env_file = ".env.development" # 默认使用开发环境配置
|
||||
case_sensitive = True
|
||||
|
||||
@property
|
||||
def is_development(self) -> bool:
|
||||
"""是否为开发环境"""
|
||||
return self.ENVIRONMENT == EnvironmentEnum.DEVELOPMENT
|
||||
|
||||
@property
|
||||
def is_production(self) -> bool:
|
||||
"""是否为生产环境"""
|
||||
return self.ENVIRONMENT == EnvironmentEnum.PRODUCTION
|
||||
@@ -1,37 +0,0 @@
|
||||
"""
|
||||
数据库配置
|
||||
"""
|
||||
from .base_config import BaseConfig
|
||||
|
||||
|
||||
class DatabaseConfig(BaseConfig):
|
||||
"""数据库配置类"""
|
||||
|
||||
# PostgreSQL配置
|
||||
DB_HOST: str = "localhost"
|
||||
DB_PORT: int = 5432
|
||||
DB_USER: str = "postgres"
|
||||
DB_PASSWORD: str = "postgres"
|
||||
DB_NAME: str = "test_db"
|
||||
|
||||
# 连接池配置
|
||||
DB_POOL_SIZE: int = 10
|
||||
DB_MAX_OVERFLOW: int = 20
|
||||
DB_POOL_TIMEOUT: int = 30
|
||||
DB_POOL_RECYCLE: int = 3600
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
"""构建数据库连接URL"""
|
||||
return (
|
||||
f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}"
|
||||
f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
)
|
||||
|
||||
@property
|
||||
def ASYNC_DATABASE_URL(self) -> str:
|
||||
"""构建异步数据库连接URL"""
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASSWORD}"
|
||||
f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
)
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
开发环境配置
|
||||
"""
|
||||
from .database_config import DatabaseConfig
|
||||
from .base_config import EnvironmentEnum
|
||||
|
||||
|
||||
class DevelopmentConfig(DatabaseConfig):
|
||||
"""开发环境配置 - 只覆盖与生产不同的配置"""
|
||||
|
||||
ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.DEVELOPMENT
|
||||
|
||||
# 调试和日志
|
||||
DEBUG: bool = True
|
||||
LOG_LEVEL: str = "DEBUG"
|
||||
|
||||
# 开发特性
|
||||
RELOAD: bool = True
|
||||
|
||||
class Config:
|
||||
env_file = ".env.development"
|
||||
case_sensitive = True
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
生产环境配置
|
||||
"""
|
||||
from .database_config import DatabaseConfig
|
||||
from .base_config import EnvironmentEnum
|
||||
|
||||
|
||||
class ProductionConfig(DatabaseConfig):
|
||||
"""生产环境配置 - 只覆盖性能和安全相关配置"""
|
||||
|
||||
ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.PRODUCTION
|
||||
|
||||
# 调试和日志
|
||||
DEBUG: bool = False
|
||||
LOG_LEVEL: str = "WARNING"
|
||||
|
||||
# 生产特性
|
||||
RELOAD: bool = False
|
||||
|
||||
# 数据库连接池(生产环境需要更大)
|
||||
DB_POOL_SIZE: int = 20
|
||||
DB_MAX_OVERFLOW: int = 40
|
||||
|
||||
class Config:
|
||||
env_file = ".env.production"
|
||||
case_sensitive = True
|
||||
@@ -1,65 +0,0 @@
|
||||
"""
|
||||
配置加载器 - 根据环境自动加载对应配置
|
||||
"""
|
||||
import os
|
||||
from typing import Type
|
||||
from .base_config import BaseConfig, EnvironmentEnum
|
||||
from .development import DevelopmentConfig
|
||||
from .production import ProductionConfig
|
||||
|
||||
|
||||
def get_config_class(environment: str = None) -> Type[BaseConfig]:
|
||||
"""根据环境获取配置类
|
||||
|
||||
Args:
|
||||
environment: 环境名称 (development/production)
|
||||
|
||||
Returns:
|
||||
对应的配置类
|
||||
"""
|
||||
if environment is None:
|
||||
environment = os.getenv("ENVIRONMENT", "development")
|
||||
|
||||
config_map = {
|
||||
EnvironmentEnum.DEVELOPMENT: DevelopmentConfig,
|
||||
EnvironmentEnum.PRODUCTION: ProductionConfig,
|
||||
}
|
||||
|
||||
try:
|
||||
env_enum = EnvironmentEnum(environment)
|
||||
return config_map[env_enum]
|
||||
except ValueError:
|
||||
print(f"警告: 未知环境 '{environment}',使用默认开发环境配置")
|
||||
return DevelopmentConfig
|
||||
|
||||
|
||||
def load_config(environment: str = None) -> BaseConfig:
|
||||
"""加载配置
|
||||
|
||||
Args:
|
||||
environment: 环境名称
|
||||
|
||||
Returns:
|
||||
配置实例
|
||||
"""
|
||||
config_class = get_config_class(environment)
|
||||
return config_class()
|
||||
|
||||
|
||||
# 全局配置实例(延迟加载)
|
||||
_config_instance = None
|
||||
|
||||
|
||||
def get_settings() -> BaseConfig:
|
||||
"""获取全局配置实例(单例模式)"""
|
||||
global _config_instance
|
||||
if _config_instance is None:
|
||||
environment = os.getenv("ENVIRONMENT", None)
|
||||
_config_instance = load_config(environment)
|
||||
return _config_instance
|
||||
|
||||
|
||||
def reload_config(environment: str = None):
|
||||
"""重新加载配置"""
|
||||
global _config_instance
|
||||
_config_instance = load_config(environment)
|
||||
@@ -1,255 +0,0 @@
|
||||
"""
|
||||
数据库连接管理 - 使用SQLAlchemy 2.0
|
||||
"""
|
||||
import logging
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker, Session, DeclarativeBase
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from typing import List, Dict, Any
|
||||
from contextlib import contextmanager
|
||||
|
||||
from app.config.settings import get_settings
|
||||
from app.utils.logger import setup_logging
|
||||
|
||||
# 初始化日志
|
||||
logger = setup_logging()
|
||||
|
||||
# 关闭SQLAlchemy引擎日志,只保留关键信息
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.pool').setLevel(logging.WARNING)
|
||||
|
||||
# 获取配置
|
||||
settings = get_settings()
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
poolclass=QueuePool,
|
||||
pool_size=settings.DB_POOL_SIZE,
|
||||
max_overflow=settings.DB_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
echo=False, # 关闭SQL语句打印
|
||||
connect_args={
|
||||
"options": "-c client_encoding=UTF8"
|
||||
}
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 声明基类
|
||||
Base = DeclarativeBase()
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库管理器 - 提供通用的CRUD操作"""
|
||||
|
||||
def __init__(self):
|
||||
self.engine = engine
|
||||
self.SessionLocal = SessionLocal
|
||||
|
||||
@contextmanager
|
||||
def get_session(self) -> Session:
|
||||
"""获取数据库会话的上下文管理器"""
|
||||
session = self.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"数据库操作失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def init_db(self):
|
||||
"""初始化数据库,创建所有表"""
|
||||
try:
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
logger.info("数据库表创建成功")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库表创建失败: {e}")
|
||||
raise
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""测试数据库连接"""
|
||||
try:
|
||||
with self.get_session() as session:
|
||||
session.execute(text("SELECT 1"))
|
||||
logger.info("数据库连接测试成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接测试失败: {e}")
|
||||
return False
|
||||
|
||||
def execute_raw_sql(self, sql: str, params: dict = None) -> List[Dict[str, Any]]:
|
||||
"""执行原生SQL查询
|
||||
|
||||
Args:
|
||||
sql: SQL语句
|
||||
params: 参数字典
|
||||
|
||||
Returns:
|
||||
查询结果列表
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
result = session.execute(text(sql), params or {})
|
||||
if result.returns_rows:
|
||||
columns = result.keys()
|
||||
return [dict(zip(columns, row)) for row in result.fetchall()]
|
||||
return []
|
||||
|
||||
def insert(self, table_name: str, data: Dict[str, Any]) -> int:
|
||||
"""插入单条记录
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
data: 要插入的数据字典
|
||||
|
||||
Returns:
|
||||
插入的行数
|
||||
"""
|
||||
columns = ", ".join(data.keys())
|
||||
placeholders = ", ".join([f":{key}" for key in data.keys()])
|
||||
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
|
||||
|
||||
with self.get_session() as session:
|
||||
result = session.execute(text(sql), data)
|
||||
return result.rowcount
|
||||
|
||||
def insert_many(self, table_name: str, data_list: List[Dict[str, Any]]) -> int:
|
||||
"""批量插入记录
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
data_list: 数据字典列表
|
||||
|
||||
Returns:
|
||||
插入的行数
|
||||
"""
|
||||
if not data_list:
|
||||
return 0
|
||||
|
||||
columns = ", ".join(data_list[0].keys())
|
||||
placeholders = ", ".join([f":{key}" for key in data_list[0].keys()])
|
||||
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
|
||||
|
||||
with self.get_session() as session:
|
||||
result = session.execute(text(sql), data_list)
|
||||
return result.rowcount
|
||||
|
||||
def select(
|
||||
self,
|
||||
table_name: str,
|
||||
conditions: Dict[str, Any] = None,
|
||||
columns: List[str] = None,
|
||||
limit: int = None,
|
||||
offset: int = None,
|
||||
order_by: str = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""查询记录
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
conditions: 查询条件字典
|
||||
columns: 要查询的列列表,None表示查询所有列
|
||||
limit: 限制返回行数
|
||||
offset: 偏移量
|
||||
order_by: 排序字段
|
||||
|
||||
Returns:
|
||||
查询结果列表
|
||||
"""
|
||||
col_str = ", ".join(columns) if columns else "*"
|
||||
sql = f"SELECT {col_str} FROM {table_name}"
|
||||
|
||||
params = {}
|
||||
if conditions:
|
||||
where_clauses = []
|
||||
for key, value in conditions.items():
|
||||
where_clauses.append(f"{key} = :{key}")
|
||||
params[key] = value
|
||||
sql += " WHERE " + " AND ".join(where_clauses)
|
||||
|
||||
if order_by:
|
||||
sql += f" ORDER BY {order_by}"
|
||||
|
||||
if limit:
|
||||
sql += f" LIMIT :limit"
|
||||
params["limit"] = limit
|
||||
|
||||
if offset:
|
||||
sql += f" OFFSET :offset"
|
||||
params["offset"] = offset
|
||||
|
||||
return self.execute_raw_sql(sql, params)
|
||||
|
||||
def update(self, table_name: str, data: Dict[str, Any],
|
||||
conditions: Dict[str, Any]) -> int:
|
||||
"""更新记录
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
data: 要更新的数据字典
|
||||
conditions: 更新条件字典
|
||||
|
||||
Returns:
|
||||
更新的行数
|
||||
"""
|
||||
set_clauses = [f"{key} = :{key}" for key in data.keys()]
|
||||
where_clauses = [f"{key} = :cond_{key}" for key in conditions.keys()]
|
||||
|
||||
sql = f"UPDATE {table_name} SET {', '.join(set_clauses)} WHERE {' AND '.join(where_clauses)}"
|
||||
|
||||
# 合并参数,避免键名冲突
|
||||
params = {**data, **{f"cond_{key}": value for key, value in conditions.items()}}
|
||||
|
||||
with self.get_session() as session:
|
||||
result = session.execute(text(sql), params)
|
||||
return result.rowcount
|
||||
|
||||
def delete(self, table_name: str, conditions: Dict[str, Any]) -> int:
|
||||
"""删除记录
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
conditions: 删除条件字典
|
||||
|
||||
Returns:
|
||||
删除的行数
|
||||
"""
|
||||
where_clauses = [f"{key} = :{key}" for key in conditions.keys()]
|
||||
sql = f"DELETE FROM {table_name} WHERE {' AND '.join(where_clauses)}"
|
||||
|
||||
with self.get_session() as session:
|
||||
result = session.execute(text(sql), conditions)
|
||||
return result.rowcount
|
||||
|
||||
def count(self, table_name: str, conditions: Dict[str, Any] = None) -> int:
|
||||
"""统计记录数
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
conditions: 统计条件字典
|
||||
|
||||
Returns:
|
||||
记录数
|
||||
"""
|
||||
sql = f"SELECT COUNT(*) as count FROM {table_name}"
|
||||
params = {}
|
||||
|
||||
if conditions:
|
||||
where_clauses = []
|
||||
for key, value in conditions.items():
|
||||
where_clauses.append(f"{key} = :{key}")
|
||||
params[key] = value
|
||||
sql += " WHERE " + " AND ".join(where_clauses)
|
||||
|
||||
result = self.execute_raw_sql(sql, params)
|
||||
return result[0]["count"] if result else 0
|
||||
|
||||
|
||||
# 创建全局数据库管理器实例
|
||||
db_manager = DatabaseManager()
|
||||
-110
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
FastAPI应用主文件
|
||||
"""
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config.settings import get_settings
|
||||
from app.core.database import db_manager
|
||||
from app.utils.logger import setup_logging
|
||||
from app.api.rainfall import router as rainfall_router
|
||||
|
||||
# 初始化日志
|
||||
logger = setup_logging()
|
||||
|
||||
# 获取配置
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def create_application() -> FastAPI:
|
||||
"""创建FastAPI应用实例"""
|
||||
|
||||
application = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
version=settings.APP_VERSION,
|
||||
debug=settings.DEBUG,
|
||||
description="基于FastAPI的现代化Web应用框架"
|
||||
)
|
||||
|
||||
# 配置CORS
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS if hasattr(settings, 'CORS_ORIGINS') else ["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
# 创建应用实例
|
||||
app = create_application()
|
||||
|
||||
|
||||
# ==================== 启动和关闭事件 ====================
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""应用启动时执行"""
|
||||
logger.info(f"正在启动 {settings.APP_NAME} v{settings.APP_VERSION}")
|
||||
logger.info(f"环境: {settings.ENVIRONMENT}")
|
||||
logger.info(f"数据库连接: {settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}")
|
||||
|
||||
# 测试数据库连接
|
||||
if db_manager.test_connection():
|
||||
logger.info("数据库连接成功")
|
||||
else:
|
||||
logger.warning("数据库连接失败,请检查配置")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""应用关闭时执行"""
|
||||
logger.info("应用正在关闭...")
|
||||
|
||||
|
||||
# ==================== 根路径和健康检查 ====================
|
||||
|
||||
@app.get("/", tags=["基础"])
|
||||
async def root():
|
||||
"""根路径 - 欢迎信息"""
|
||||
return {
|
||||
"app": settings.APP_NAME,
|
||||
"version": settings.APP_VERSION,
|
||||
"environment": settings.ENVIRONMENT.value,
|
||||
"status": "running"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health", tags=["基础"])
|
||||
async def health_check():
|
||||
"""健康检查接口"""
|
||||
try:
|
||||
# 检查数据库连接
|
||||
db_manager.execute_raw_sql("SELECT 1")
|
||||
db_status = "connected"
|
||||
except Exception as e:
|
||||
db_status = f"error: {str(e)}"
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"database": db_status,
|
||||
"app_version": settings.APP_VERSION,
|
||||
"environment": settings.ENVIRONMENT.value
|
||||
}
|
||||
|
||||
|
||||
# ==================== 注册路由 ====================
|
||||
|
||||
app.include_router(rainfall_router)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.API_HOST,
|
||||
port=settings.API_PORT,
|
||||
reload=settings.RELOAD if hasattr(settings, 'RELOAD') else settings.DEBUG
|
||||
)
|
||||
@@ -1,55 +0,0 @@
|
||||
"""
|
||||
降雨数据Repository - 数据访问层
|
||||
"""
|
||||
from typing import List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import db_manager
|
||||
from app.utils.logger import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
class RainfallRepository:
|
||||
"""降雨数据仓储类"""
|
||||
|
||||
@staticmethod
|
||||
def query_stations_rainfall(
|
||||
query_time: datetime,
|
||||
duration: int = 12
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
查询指定时间的站点降雨数据(自动查询前12小时或24小时)
|
||||
|
||||
Args:
|
||||
query_time: 查询时间
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
站点降雨数据列表
|
||||
"""
|
||||
sql = f"""
|
||||
SELECT
|
||||
lon,
|
||||
lat,
|
||||
SUM(rainfall_1h::numeric) AS rainfall
|
||||
FROM xian_meteorology
|
||||
WHERE datetime BETWEEN (
|
||||
to_char(timestamp :query_time - interval '{duration} hours', 'YYYYMMDDHH24MISS')
|
||||
)::bigint AND (
|
||||
to_char(timestamp :query_time, 'YYYYMMDDHH24MISS')
|
||||
)::bigint
|
||||
GROUP BY lon, lat
|
||||
"""
|
||||
|
||||
params = {
|
||||
"query_time": query_time
|
||||
}
|
||||
|
||||
try:
|
||||
result = db_manager.execute_raw_sql(sql, params)
|
||||
logger.info(f"查询到 {len(result)} 个站点数据({duration}小时)")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点降雨数据失败: {e}")
|
||||
raise
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
降雨数据相关的Pydantic Schemas
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class RainfallGridRequest(BaseModel):
|
||||
"""降雨栅格请求模型"""
|
||||
time: Optional[str] = Field(
|
||||
None,
|
||||
alias="time",
|
||||
description="查询时间 ISO格式,默认为当前时间(自动查询前12小时或24小时数据)",
|
||||
example="2024-01-01T12:00:00"
|
||||
)
|
||||
resolution: float = Field(
|
||||
0.01,
|
||||
alias="resolution",
|
||||
description="栅格分辨率(度)",
|
||||
gt=0,
|
||||
le=0.1
|
||||
)
|
||||
duration: int = Field(
|
||||
12,
|
||||
alias="duration",
|
||||
description="持续时间(小时),可选12或24",
|
||||
ge=12,
|
||||
le=24
|
||||
)
|
||||
|
||||
class Config:
|
||||
populate_by_name = True # 允许同时使用字段名和别名
|
||||
|
||||
|
||||
class StationData(BaseModel):
|
||||
"""站点数据模型"""
|
||||
lon: float
|
||||
lat: float
|
||||
rainfall: float
|
||||
|
||||
|
||||
class GridMetadata(BaseModel):
|
||||
"""栅格元数据"""
|
||||
start_time: str
|
||||
end_time: str
|
||||
district_id: int
|
||||
resolution: float
|
||||
station_count: int
|
||||
grid_size: List[int]
|
||||
bounds: dict
|
||||
|
||||
|
||||
class RainfallGridResponse(BaseModel):
|
||||
"""降雨栅格响应模型 - 符合前端 ApiResponse 结构"""
|
||||
code: int = Field(200, description="状态码")
|
||||
message: str = Field(..., description="响应消息")
|
||||
data: Optional[dict] = Field(None, description="响应数据")
|
||||
|
||||
|
||||
class StationsResponse(BaseModel):
|
||||
"""站点数据响应模型 - 符合前端 ApiResponse 结构"""
|
||||
code: int = Field(200, description="状态码")
|
||||
message: str = Field(..., description="响应消息")
|
||||
data: List[StationData] = Field(default_factory=list, description="站点数据列表")
|
||||
@@ -1,605 +0,0 @@
|
||||
"""
|
||||
降雨数据Service - 业务逻辑层
|
||||
"""
|
||||
import numpy as np
|
||||
from scipy.spatial import Delaunay, ConvexHull
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.rainfall_repository import RainfallRepository
|
||||
from app.utils.logger import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
class InterpolationService:
|
||||
"""插值服务类"""
|
||||
|
||||
@staticmethod
|
||||
def _create_buffer_points(
|
||||
points_array: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
创建缓冲点:在原始站点外围生成虚拟点以扩展插值区域
|
||||
|
||||
Args:
|
||||
points_array: 原始站点坐标数组
|
||||
|
||||
Returns:
|
||||
缓冲点坐标数组
|
||||
"""
|
||||
# 计算站点分布的中心
|
||||
center = np.mean(points_array, axis=0)
|
||||
|
||||
# 计算站点到中心的最大距离
|
||||
distances_from_center = np.sqrt(np.sum((points_array - center) ** 2, axis=1))
|
||||
np.max(distances_from_center)
|
||||
|
||||
# 在站点外围生成缓冲点(沿着各个方向扩展)
|
||||
buffer_points = []
|
||||
num_angles = 360 # 每隔1度生成一个缓冲点
|
||||
|
||||
for angle_deg in range(0, 360, 360 // num_angles):
|
||||
angle_rad = np.radians(angle_deg)
|
||||
# 在凸包边界外扩展
|
||||
for scale in [1.05, 1.1, 1.15]:
|
||||
# 找到该方向上最远的站点
|
||||
direction = np.array([np.cos(angle_rad), np.sin(angle_rad)])
|
||||
projections = points_array @ direction
|
||||
max_idx = np.argmax(projections)
|
||||
|
||||
# 在该方向上扩展
|
||||
base_point = points_array[max_idx]
|
||||
buffer_point = center + (base_point - center) * scale
|
||||
buffer_points.append(buffer_point)
|
||||
|
||||
return np.array(buffer_points)
|
||||
|
||||
@staticmethod
|
||||
def gaussian_smoothing(
|
||||
grid_data: np.ndarray,
|
||||
sigma: float = 1.5
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
高斯平滑滤波,减少边缘突变
|
||||
|
||||
Args:
|
||||
grid_data: 栅格数据
|
||||
sigma: 高斯核标准差
|
||||
|
||||
Returns:
|
||||
平滑后的栅格数据
|
||||
"""
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
# 只对有效数据进行平滑
|
||||
valid_mask = ~np.isnan(grid_data)
|
||||
if not np.any(valid_mask):
|
||||
return grid_data
|
||||
|
||||
# 填充NaN值以便平滑
|
||||
filled_data = grid_data.copy()
|
||||
mean_val = np.nanmean(grid_data)
|
||||
filled_data[~valid_mask] = mean_val
|
||||
|
||||
# 应用高斯滤波
|
||||
smoothed = gaussian_filter(filled_data, sigma=sigma)
|
||||
|
||||
# 恢复原始NaN区域
|
||||
result = np.where(valid_mask, smoothed, np.nan)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def calculate_adaptive_max_distance(
|
||||
points_array: np.ndarray,
|
||||
base_distance: float = 0.3,
|
||||
min_distance: float = 0.15,
|
||||
max_distance: float = 0.5
|
||||
) -> float:
|
||||
"""
|
||||
根据站点密度自适应计算最大影响距离
|
||||
|
||||
Args:
|
||||
points_array: 站点坐标数组
|
||||
base_distance: 基础距离
|
||||
min_distance: 最小距离
|
||||
max_distance: 最大距离
|
||||
|
||||
Returns:
|
||||
自适应的最大影响距离
|
||||
"""
|
||||
if len(points_array) < 3:
|
||||
return base_distance
|
||||
|
||||
# 计算站点间的平均距离
|
||||
from scipy.spatial import distance_matrix
|
||||
dist_matrix = distance_matrix(points_array, points_array)
|
||||
|
||||
# 排除对角线(自身距离为0)
|
||||
np.fill_diagonal(dist_matrix, np.inf)
|
||||
avg_distance = np.mean(np.min(dist_matrix, axis=1))
|
||||
|
||||
# 根据平均距离调整max_distance
|
||||
adaptive_distance = avg_distance * 3 # 约3倍平均站点间距
|
||||
|
||||
# 限制在合理范围内
|
||||
return np.clip(adaptive_distance, min_distance, max_distance)
|
||||
|
||||
@staticmethod
|
||||
def inverse_distance_weighting(
|
||||
points: List[Tuple[float, float]],
|
||||
values: List[float],
|
||||
grid_lon: np.ndarray,
|
||||
grid_lat: np.ndarray,
|
||||
power: float = 2.0,
|
||||
max_distance: float = None,
|
||||
use_adaptive_distance: bool = True,
|
||||
apply_smoothing: bool = True,
|
||||
smoothing_sigma: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
反距离权重插值 (IDW) - 优化版本
|
||||
改进:
|
||||
1. 高斯核衰减替代简单幂律
|
||||
2. 自适应距离阈值
|
||||
3. 边缘渐变处理
|
||||
4. 高斯平滑减少突变
|
||||
|
||||
Args:
|
||||
points: 已知点坐标 [(lon, lat), ...]
|
||||
values: 已知点的值 [rainfall, ...]
|
||||
grid_lon: 网格经度数组
|
||||
grid_lat: 网格纬度数组
|
||||
power: 距离幂次(基础值)
|
||||
max_distance: 最大影响距离(度),None则自适应计算
|
||||
use_adaptive_distance: 是否使用自适应距离
|
||||
apply_smoothing: 是否应用平滑
|
||||
smoothing_sigma: 平滑强度
|
||||
|
||||
Returns:
|
||||
插值后的栅格数据,无效区域为 NaN
|
||||
"""
|
||||
points_array = np.array(points)
|
||||
values_array = np.array(values)
|
||||
|
||||
# 创建网格
|
||||
lon_grid, lat_grid = np.meshgrid(grid_lon, grid_lat)
|
||||
result = np.full_like(lon_grid, np.nan)
|
||||
|
||||
# 自适应计算最大距离
|
||||
if use_adaptive_distance or max_distance is None:
|
||||
actual_max_distance = InterpolationService.calculate_adaptive_max_distance(
|
||||
points_array
|
||||
)
|
||||
if max_distance is not None:
|
||||
actual_max_distance = min(actual_max_distance, max_distance)
|
||||
else:
|
||||
actual_max_distance = max_distance
|
||||
|
||||
logger.info(f"使用最大影响距离: {actual_max_distance:.3f} 度")
|
||||
|
||||
# 计算站点的凸包(带边缘缓冲)
|
||||
hull_mask = None
|
||||
confidence_mask = None # 置信度掩码
|
||||
if len(points_array) >= 3:
|
||||
try:
|
||||
# 创建缓冲站点:在原始站点外围添加虚拟点
|
||||
buffer_points = InterpolationService._create_buffer_points(
|
||||
points_array
|
||||
)
|
||||
|
||||
# 合并原始站点和缓冲站点
|
||||
all_points = np.vstack([points_array, buffer_points])
|
||||
|
||||
# 计算凸包
|
||||
hull = ConvexHull(all_points)
|
||||
hull_points = all_points[hull.vertices]
|
||||
tri = Delaunay(hull_points)
|
||||
|
||||
# 向量化判断所有网格点是否在凸包内
|
||||
grid_points = np.column_stack([lon_grid.ravel(), lat_grid.ravel()])
|
||||
hull_indices = tri.find_simplex(grid_points)
|
||||
hull_mask = hull_indices >= 0
|
||||
hull_mask = hull_mask.reshape(lon_grid.shape)
|
||||
|
||||
# 计算置信度:基于到最近站点的距离
|
||||
# 在凸包内但远离站点的区域降低置信度
|
||||
from scipy.spatial import distance_matrix
|
||||
grid_valid = grid_points[hull_mask.ravel()]
|
||||
if len(grid_valid) > 0:
|
||||
dist_to_stations = distance_matrix(grid_valid, points_array)
|
||||
min_distances = np.min(dist_to_stations, axis=1)
|
||||
|
||||
# 创建置信度掩码(距离越远,置信度越低)
|
||||
confidence = np.ones(len(grid_points))
|
||||
confidence[hull_mask.ravel()] = np.exp(-min_distances / actual_max_distance)
|
||||
confidence_mask = confidence.reshape(lon_grid.shape)
|
||||
else:
|
||||
confidence_mask = np.ones_like(lon_grid)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"凸包计算失败: {e},使用全区域插值")
|
||||
hull_mask = np.ones_like(lon_grid, dtype=bool)
|
||||
confidence_mask = np.ones_like(lon_grid)
|
||||
else:
|
||||
hull_mask = np.ones_like(lon_grid, dtype=bool)
|
||||
confidence_mask = np.ones_like(lon_grid)
|
||||
|
||||
# 向量化计算所有网格点到所有站点的距离
|
||||
lon_diff = lon_grid[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 0]
|
||||
lat_diff = lat_grid[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 1]
|
||||
distances = np.sqrt(lon_diff**2 + lat_diff**2)
|
||||
|
||||
# 过滤超出最大距离的站点
|
||||
valid_mask = distances <= actual_max_distance
|
||||
|
||||
# 对于每个网格点,检查是否有有效站点
|
||||
has_valid_stations = np.any(valid_mask, axis=2)
|
||||
|
||||
# 合并凸包掩码和有效站点掩码
|
||||
final_mask = hull_mask & has_valid_stations
|
||||
|
||||
# 避免除零
|
||||
distances = np.where(valid_mask, distances, np.inf)
|
||||
distances = np.maximum(distances, 1e-10)
|
||||
|
||||
# 优化的权重计算:结合幂律和高斯衰减
|
||||
# 近处使用幂律,远处使用高斯衰减使过渡更平滑
|
||||
power_weights = 1.0 / (distances ** power)
|
||||
gaussian_weights = np.exp(-0.5 * (distances / (actual_max_distance * 0.5)) ** 2)
|
||||
|
||||
# 混合权重:距离越远,高斯权重占比越大
|
||||
distance_ratio = distances / actual_max_distance
|
||||
mix_factor = np.clip(distance_ratio, 0, 1)
|
||||
weights = (1 - mix_factor) * power_weights + mix_factor * gaussian_weights
|
||||
|
||||
weights = np.where(valid_mask, weights, 0)
|
||||
|
||||
# 加权平均
|
||||
weighted_sum = np.sum(weights * values_array[np.newaxis, np.newaxis, :], axis=2)
|
||||
weight_total = np.sum(weights, axis=2)
|
||||
|
||||
# 计算基础插值结果(使用 errstate 忽略预期的除零警告,np.where 已安全过滤)
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
result = np.where(
|
||||
final_mask & (weight_total > 0),
|
||||
weighted_sum / weight_total,
|
||||
np.nan
|
||||
)
|
||||
|
||||
# 应用置信度调整:边缘区域向邻近值渐变
|
||||
if confidence_mask is not None:
|
||||
# 计算全局平均降雨量作为边缘区域的基准
|
||||
valid_rainfall = result[final_mask]
|
||||
if len(valid_rainfall) > 0:
|
||||
mean_rainfall = np.mean(valid_rainfall)
|
||||
# 边缘区域向平均值渐变
|
||||
result = np.where(
|
||||
final_mask,
|
||||
result,
|
||||
np.nan
|
||||
)
|
||||
# 根据置信度调整结果,低置信度区域向均值靠拢
|
||||
adjusted_result = result * confidence_mask + mean_rainfall * (1 - confidence_mask)
|
||||
result = np.where(final_mask, adjusted_result, np.nan)
|
||||
|
||||
# 应用高斯平滑减少边缘突变
|
||||
if apply_smoothing:
|
||||
result = InterpolationService.gaussian_smoothing(result, sigma=smoothing_sigma)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_rainfall_color(rainfall: float, duration: int = 12) -> str:
|
||||
"""
|
||||
根据降雨量获取颜色(按照国标)
|
||||
|
||||
Args:
|
||||
rainfall: 降雨量(mm)
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
颜色字符串 "rgba(r,g,b,a)"
|
||||
"""
|
||||
# 国标降雨等级颜色映射
|
||||
if rainfall < 0.1:
|
||||
return "rgba(200,200,200,0)" # 透明 - 微量降雨(零星小雨)
|
||||
elif rainfall < 5 if duration == 12 else 9.9:
|
||||
return "rgba(0,0,255,0.4)" # 浅蓝 - 小雨
|
||||
elif rainfall < 15 if duration == 12 else 25:
|
||||
return "rgba(0,255,255,0.5)" # 青色 - 中雨
|
||||
elif rainfall < 30 if duration == 12 else 50:
|
||||
return "rgba(0,255,0,0.6)" # 绿色 - 大雨
|
||||
elif rainfall < 70 if duration == 12 else 100:
|
||||
return "rgba(255,255,0,0.7)" # 黄色 - 暴雨
|
||||
elif rainfall < 140 if duration == 12 else 250:
|
||||
return "rgba(255,165,0,0.8)" # 橙色 - 大暴雨
|
||||
else:
|
||||
return "rgba(255,0,0,0.9)" # 红色 - 特大暴雨
|
||||
|
||||
|
||||
class GeoJSONService:
|
||||
"""GeoJSON生成服务"""
|
||||
|
||||
@staticmethod
|
||||
def create_feature_collection(
|
||||
grid_metadata: Dict[str, Any],
|
||||
rainfall_array: np.ndarray,
|
||||
grid_lon: np.ndarray,
|
||||
grid_lat: np.ndarray,
|
||||
duration: int = 12
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建GeoJSON FeatureCollection用于Cesium渲染
|
||||
|
||||
Args:
|
||||
grid_metadata: 栅格元数据
|
||||
rainfall_array: 降雨量数组
|
||||
grid_lon: 经度网格
|
||||
grid_lat: 纬度网格
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
GeoJSON格式的FeatureCollection
|
||||
"""
|
||||
features = []
|
||||
|
||||
# 将栅格数据转换为矩形要素
|
||||
for i in range(len(grid_lat) - 1):
|
||||
for j in range(len(grid_lon) - 1):
|
||||
rainfall_value = float(rainfall_array[i, j])
|
||||
|
||||
# 跳过无数据的区域
|
||||
if np.isnan(rainfall_value) or rainfall_value < 0:
|
||||
continue
|
||||
|
||||
# 创建矩形多边形
|
||||
lon_min = float(grid_lon[j])
|
||||
lon_max = float(grid_lon[j + 1])
|
||||
lat_min = float(grid_lat[i])
|
||||
lat_max = float(grid_lat[i + 1])
|
||||
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": {
|
||||
"type": "Polygon",
|
||||
"coordinates": [[
|
||||
[lon_min, lat_min],
|
||||
[lon_max, lat_min],
|
||||
[lon_max, lat_max],
|
||||
[lon_min, lat_max],
|
||||
[lon_min, lat_min]
|
||||
]]
|
||||
},
|
||||
"properties": {
|
||||
"rainfall": round(rainfall_value, 2),
|
||||
"level": RainfallService._get_rainfall_level(rainfall_value, duration),
|
||||
"color": InterpolationService.get_rainfall_color(rainfall_value, duration)
|
||||
}
|
||||
}
|
||||
features.append(feature)
|
||||
|
||||
return {
|
||||
"type": "FeatureCollection",
|
||||
"features": features,
|
||||
"metadata": {
|
||||
"resolution": grid_metadata['resolution'],
|
||||
"grid_size": [len(grid_lon), len(grid_lat)],
|
||||
"bounds": {
|
||||
"min_lon": float(grid_lon.min()),
|
||||
"max_lon": float(grid_lon.max()),
|
||||
"min_lat": float(grid_lat.min()),
|
||||
"max_lat": float(grid_lat.max())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RainfallService:
|
||||
"""降雨数据业务服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.repository = RainfallRepository()
|
||||
self.interpolation_service = InterpolationService()
|
||||
self.geojson_service = GeoJSONService()
|
||||
|
||||
def get_stations_data(
|
||||
self,
|
||||
query_time: datetime,
|
||||
duration: int = 12
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取站点降雨数据
|
||||
|
||||
Args:
|
||||
query_time: 查询时间(自动查询前12小时或24小时数据)
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
站点数据列表
|
||||
"""
|
||||
return self.repository.query_stations_rainfall(query_time, duration)
|
||||
|
||||
def generate_rainfall_grid(
|
||||
self,
|
||||
query_time: datetime,
|
||||
resolution: float = 0.01,
|
||||
duration: int = 12
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成降雨栅格数据
|
||||
|
||||
Args:
|
||||
query_time: 查询时间(自动查询前12小时或24小时数据)
|
||||
resolution: 栅格分辨率
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
GeoJSON格式的栅格数据
|
||||
"""
|
||||
logger.info(f"查询降雨数据: {query_time}, 持续时间: {duration}小时")
|
||||
|
||||
# 查询站点数据(自动查询前12小时或24小时数据)
|
||||
stations_data = self.get_stations_data(query_time, duration)
|
||||
|
||||
if not stations_data:
|
||||
return None
|
||||
|
||||
# 提取站点坐标和降雨量(过滤空值)
|
||||
valid_stations = [row for row in stations_data if row['rainfall'] is not None]
|
||||
|
||||
if not valid_stations:
|
||||
logger.warning("所有站点的降雨量数据均为空")
|
||||
return None
|
||||
|
||||
points = [(row['lon'], row['lat']) for row in valid_stations]
|
||||
values = [float(row['rainfall']) for row in valid_stations]
|
||||
|
||||
# 确定栅格范围
|
||||
lon_min, lon_max = 107, 110
|
||||
lat_min, lat_max = 33, 35
|
||||
|
||||
# 创建栅格网格
|
||||
num_lon = int((lon_max - lon_min) / resolution) + 1
|
||||
num_lat = int((lat_max - lat_min) / resolution) + 1
|
||||
|
||||
grid_lon = np.linspace(lon_min, lon_max, num_lon)
|
||||
grid_lat = np.linspace(lat_min, lat_max, num_lat)
|
||||
|
||||
logger.info(f"生成栅格: {num_lon}x{num_lat}, 分辨率: {resolution}")
|
||||
|
||||
# 执行IDW插值(优化版本:自适应距离、混合权重、平滑处理)
|
||||
rainfall_grid = self.interpolation_service.inverse_distance_weighting(
|
||||
points=points,
|
||||
values=values,
|
||||
grid_lon=grid_lon,
|
||||
grid_lat=grid_lat,
|
||||
power=2.0,
|
||||
max_distance=0.35, # 最大影响距离0.35度(约35公里)
|
||||
use_adaptive_distance=True, # 启用自适应距离
|
||||
apply_smoothing=True, # 启用平滑处理
|
||||
smoothing_sigma=1.2 # 平滑强度
|
||||
)
|
||||
|
||||
# 创建栅格元数据
|
||||
grid_metadata = {
|
||||
"query_time": query_time.isoformat(),
|
||||
"resolution": resolution,
|
||||
"station_count": len(stations_data),
|
||||
"grid_size": [num_lon, num_lat]
|
||||
}
|
||||
|
||||
# 转换为GeoJSON格式
|
||||
geojson_data = self.geojson_service.create_feature_collection(
|
||||
grid_metadata, rainfall_grid, grid_lon, grid_lat, duration
|
||||
)
|
||||
|
||||
logger.info("降雨栅格数据生成成功")
|
||||
|
||||
return geojson_data
|
||||
|
||||
def get_rainfall_at_point(
|
||||
self,
|
||||
longitude: float,
|
||||
latitude: float,
|
||||
query_time: datetime,
|
||||
duration: int = 12
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
查询指定点位的降雨量(使用IDW插值)
|
||||
|
||||
Args:
|
||||
longitude: 经度
|
||||
latitude: 纬度
|
||||
query_time: 查询时间(自动查询前12小时或24小时数据)
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
点位降雨量信息
|
||||
"""
|
||||
# 获取站点数据(自动查询前12小时或24小时数据)
|
||||
stations_data = self.get_stations_data(query_time, duration)
|
||||
|
||||
if not stations_data:
|
||||
return None
|
||||
|
||||
# 提取站点坐标和降雨量
|
||||
points = [(row['lon'], row['lat']) for row in stations_data]
|
||||
values = [float(row['rainfall']) for row in stations_data]
|
||||
|
||||
# 使用IDW插值计算该点的降雨量
|
||||
target_point = np.array([[longitude, latitude]])
|
||||
points_array = np.array(points)
|
||||
|
||||
# 计算距离
|
||||
distances = np.sqrt(
|
||||
(points_array[:, 0] - longitude) ** 2 +
|
||||
(points_array[:, 1] - latitude) ** 2
|
||||
)
|
||||
|
||||
# 避免除零
|
||||
min_dist = 1e-10
|
||||
distances = np.maximum(distances, min_dist)
|
||||
|
||||
# IDW公式
|
||||
power = 2.0
|
||||
weights = 1.0 / (distances ** power)
|
||||
rainfall_value = np.sum(weights * values) / np.sum(weights)
|
||||
|
||||
# 返回结果
|
||||
return {
|
||||
"longitude": longitude,
|
||||
"latitude": latitude,
|
||||
"rainfall": round(float(rainfall_value), 2),
|
||||
"level": self._get_rainfall_level(rainfall_value, duration),
|
||||
"color": InterpolationService.get_rainfall_color(rainfall_value, duration),
|
||||
"station_count": len(stations_data),
|
||||
"query_time": query_time.isoformat(),
|
||||
"duration": duration
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_rainfall_level(rainfall: float, duration: int = 12) -> str:
|
||||
"""
|
||||
获取降雨等级(按照国标)
|
||||
|
||||
Args:
|
||||
rainfall: 降雨量(mm)
|
||||
duration: 持续时间(12或24小时)
|
||||
|
||||
Returns:
|
||||
降雨等级字符串
|
||||
"""
|
||||
if duration == 12:
|
||||
# 12小时降雨等级标准
|
||||
if rainfall < 0.1:
|
||||
return "微量降雨"
|
||||
elif rainfall < 5.0:
|
||||
return "小雨"
|
||||
elif rainfall < 15.0:
|
||||
return "中雨"
|
||||
elif rainfall < 30.0:
|
||||
return "大雨"
|
||||
elif rainfall < 70.0:
|
||||
return "暴雨"
|
||||
elif rainfall < 140.0:
|
||||
return "大暴雨"
|
||||
else:
|
||||
return "特大暴雨"
|
||||
else: # 24小时
|
||||
# 24小时降雨等级标准
|
||||
if rainfall < 0.1:
|
||||
return "微量降雨"
|
||||
elif rainfall < 10.0:
|
||||
return "小雨"
|
||||
elif rainfall < 25.0:
|
||||
return "中雨"
|
||||
elif rainfall < 50.0:
|
||||
return "大雨"
|
||||
elif rainfall < 100.0:
|
||||
return "暴雨"
|
||||
elif rainfall < 250.0:
|
||||
return "大暴雨"
|
||||
else:
|
||||
return "特大暴雨"
|
||||
+66
-81
@@ -1,104 +1,89 @@
|
||||
"""
|
||||
日志配置工具
|
||||
日志工具类
|
||||
支持按天分割、自动清理过期日志
|
||||
"""
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler
|
||||
from datetime import datetime
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class LoggerConfig:
|
||||
"""日志配置类"""
|
||||
|
||||
def __init__(self, log_dir: str = "logs", log_level: str = "INFO"):
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_level = getattr(logging, log_level.upper(), logging.INFO)
|
||||
self.ensure_log_directory()
|
||||
|
||||
def ensure_log_directory(self):
|
||||
"""确保日志目录存在"""
|
||||
if not self.log_dir.exists():
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def setup_logger(
|
||||
self,
|
||||
name: str = "app",
|
||||
log_file: str = None,
|
||||
max_bytes: int = 10 * 1024 * 1024, # 10MB
|
||||
backup_count: int = 5,
|
||||
use_timed_rotation: bool = False
|
||||
) -> logging.Logger:
|
||||
"""配置日志记录器
|
||||
class LoggerManager:
|
||||
"""日志管理器"""
|
||||
|
||||
_loggers = {}
|
||||
|
||||
@classmethod
|
||||
def get_logger(cls, name: str = "algorithm", log_dir: str = "logs") -> logging.Logger:
|
||||
"""
|
||||
获取日志记录器
|
||||
|
||||
Args:
|
||||
name: 日志记录器名称
|
||||
log_file: 日志文件名(None则使用时间戳)
|
||||
max_bytes: 单个日志文件最大大小
|
||||
backup_count: 保留的备份文件数量
|
||||
use_timed_rotation: 是否使用按时间轮转
|
||||
name: 日志名称
|
||||
log_dir: 日志目录
|
||||
|
||||
Returns:
|
||||
配置好的Logger实例
|
||||
logging.Logger 实例
|
||||
"""
|
||||
if name in cls._loggers:
|
||||
return cls._loggers[name]
|
||||
|
||||
# 创建日志目录
|
||||
log_path = Path(log_dir)
|
||||
log_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建 logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(self.log_level)
|
||||
|
||||
# 避免重复添加handler
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# 避免重复添加 handler
|
||||
if logger.handlers:
|
||||
cls._loggers[name] = logger
|
||||
return logger
|
||||
|
||||
# 创建formatter
|
||||
|
||||
# 日志格式
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
'%(asctime)s [%(threadName)s] %(levelname)-5s %(name)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
# 控制台handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(self.log_level)
|
||||
|
||||
# 控制台 Handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# 文件handler
|
||||
if log_file is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d")
|
||||
log_file = f"app_{timestamp}.log"
|
||||
|
||||
log_path = self.log_dir / log_file
|
||||
|
||||
if use_timed_rotation:
|
||||
# 按时间轮转(每天)
|
||||
file_handler = TimedRotatingFileHandler(
|
||||
log_path,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
else:
|
||||
# 按大小轮转
|
||||
file_handler = RotatingFileHandler(
|
||||
log_path,
|
||||
maxBytes=max_bytes,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
|
||||
file_handler.setLevel(self.log_level)
|
||||
|
||||
# 文件 Handler - 按天分割
|
||||
log_file = log_path / f"{name}.log"
|
||||
file_handler = TimedRotatingFileHandler(
|
||||
filename=str(log_file),
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=7,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
# 设置日志文件命名格式
|
||||
file_handler.suffix = "%Y-%m-%d.log"
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
|
||||
cls._loggers[name] = logger
|
||||
return logger
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def get_logger(name: str = "algorithm", log_dir: str = "logs") -> logging.Logger:
|
||||
"""
|
||||
获取日志记录器的便捷函数
|
||||
|
||||
@staticmethod
|
||||
def get_logger(name: str = "app") -> logging.Logger:
|
||||
"""获取logger实例"""
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
# 全局日志配置
|
||||
def setup_logging(log_dir: str = "logs", log_level: str = "INFO") -> logging.Logger:
|
||||
"""设置全局日志"""
|
||||
logger_config = LoggerConfig(log_dir, log_level)
|
||||
return logger_config.setup_logger()
|
||||
Args:
|
||||
name: 日志名称
|
||||
log_dir: 日志目录
|
||||
|
||||
Returns:
|
||||
logging.Logger 实例
|
||||
"""
|
||||
return LoggerManager.get_logger(name, log_dir)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
"""
|
||||
平台检测工具 - 兼容Windows和Linux
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class PlatformDetector:
|
||||
"""平台检测器"""
|
||||
|
||||
def __init__(self):
|
||||
self.os_name = os.name
|
||||
self.platform_system = platform.system()
|
||||
self.platform_release = platform.release()
|
||||
self.platform_version = platform.version()
|
||||
self.python_version = sys.version_info
|
||||
self.is_windows = self.platform_system == "Windows"
|
||||
self.is_linux = self.platform_system == "Linux"
|
||||
self.is_macos = self.platform_system == "Darwin"
|
||||
|
||||
def get_platform_info(self) -> dict:
|
||||
"""获取平台信息"""
|
||||
return {
|
||||
"os_name": self.os_name,
|
||||
"platform_system": self.platform_system,
|
||||
"platform_release": self.platform_release,
|
||||
"platform_version": self.platform_version,
|
||||
"python_version": f"{self.python_version.major}.{self.python_version.minor}.{self.python_version.micro}",
|
||||
"is_windows": self.is_windows,
|
||||
"is_linux": self.is_linux,
|
||||
"is_macos": self.is_macos,
|
||||
}
|
||||
|
||||
def check_python_version(self, min_version: Tuple[int, int] = (3, 13)) -> bool:
|
||||
"""检查Python版本是否满足要求
|
||||
|
||||
Args:
|
||||
min_version: 最低版本要求 (major, minor)
|
||||
|
||||
Returns:
|
||||
是否满足版本要求
|
||||
"""
|
||||
current = (self.python_version.major, self.python_version.minor)
|
||||
return current >= min_version
|
||||
|
||||
def get_path_separator(self) -> str:
|
||||
"""获取路径分隔符"""
|
||||
return "\\" if self.is_windows else "/"
|
||||
|
||||
def normalize_path(self, path: str) -> str:
|
||||
"""标准化路径"""
|
||||
return Path(path).as_posix()
|
||||
|
||||
def get_project_root(self) -> Path:
|
||||
"""获取项目根目录"""
|
||||
return Path(__file__).parent.parent.parent
|
||||
|
||||
def ensure_directory(self, path: Path, create: bool = True) -> bool:
|
||||
"""确保目录存在
|
||||
|
||||
Args:
|
||||
path: 目录路径
|
||||
create: 是否自动创建
|
||||
|
||||
Returns:
|
||||
目录是否存在
|
||||
"""
|
||||
if path.exists():
|
||||
return True
|
||||
|
||||
if create:
|
||||
try:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"创建目录失败 {path}: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def get_env_file_path(self) -> Path:
|
||||
"""获取环境配置文件路径"""
|
||||
return self.get_project_root() / ".env"
|
||||
|
||||
def format_command_for_platform(self, command: str) -> str:
|
||||
"""根据平台格式化命令
|
||||
|
||||
Args:
|
||||
command: 原始命令
|
||||
|
||||
Returns:
|
||||
适合当前平台的命令
|
||||
"""
|
||||
if self.is_windows:
|
||||
# Windows特定命令转换
|
||||
if command.startswith("ls "):
|
||||
return command.replace("ls ", "dir ")
|
||||
elif command.startswith("cat "):
|
||||
return command.replace("cat ", "type ")
|
||||
elif command.startswith("rm "):
|
||||
return command.replace("rm ", "del ")
|
||||
elif command.startswith("cp "):
|
||||
return command.replace("cp ", "copy ")
|
||||
elif command.startswith("mv "):
|
||||
return command.replace("mv ", "move ")
|
||||
else:
|
||||
# Linux/Mac特定命令转换
|
||||
if command.startswith("dir "):
|
||||
return command.replace("dir ", "ls ")
|
||||
elif command.startswith("type "):
|
||||
return command.replace("type ", "cat ")
|
||||
elif command.startswith("del "):
|
||||
return command.replace("del ", "rm ")
|
||||
elif command.startswith("copy "):
|
||||
return command.replace("copy ", "cp ")
|
||||
elif command.startswith("move "):
|
||||
return command.replace("move ", "mv ")
|
||||
|
||||
return command
|
||||
|
||||
def print_platform_banner(self):
|
||||
"""打印平台信息横幅"""
|
||||
info = self.get_platform_info()
|
||||
print("=" * 60)
|
||||
print(" 系统信息")
|
||||
print("=" * 60)
|
||||
print(f" 操作系统: {info['platform_system']} {info['platform_release']}")
|
||||
print(f" Python版本: {info['python_version']}")
|
||||
print(f" 项目根目录: {self.get_project_root()}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 全局平台检测器实例
|
||||
platform_detector = PlatformDetector()
|
||||
@@ -0,0 +1,11 @@
|
||||
|
||||
from dynaconf import Dynaconf
|
||||
|
||||
settings = Dynaconf(
|
||||
# 配置文件
|
||||
settings_files=['settings.toml', '.secrets.toml'],
|
||||
# 环境变量
|
||||
environments=True,
|
||||
# 环境切换变量
|
||||
env_switcher="ENV_FOR_DYNACONF",
|
||||
)
|
||||
+1
-8
@@ -1,8 +1 @@
|
||||
fastapi==0.136.1
|
||||
uvicorn==0.46.0
|
||||
sqlalchemy==2.0.49
|
||||
psycopg2-binary==2.9.12
|
||||
pydantic==2.13.3
|
||||
pydantic-settings==2.14.0
|
||||
numpy==2.4.4
|
||||
scipy==1.17.1
|
||||
dynaconf == 3.2.13
|
||||
@@ -1,74 +0,0 @@
|
||||
@echo off
|
||||
REM Windows startup script - Development environment (background mode)
|
||||
echo ========================================
|
||||
echo Starting Development Environment
|
||||
echo ========================================
|
||||
|
||||
REM Change to project root directory
|
||||
cd /d "%~dp0.."
|
||||
|
||||
REM Verify we are in the correct directory
|
||||
if not exist "start.py" (
|
||||
echo Error: Cannot find start.py in current directory: %CD%
|
||||
echo Please ensure you are running this script from the project directory
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo Current directory: %CD%
|
||||
|
||||
REM Check and create virtual environment if not exists
|
||||
if not exist ".venv\Scripts\activate.bat" (
|
||||
echo Virtual environment not found, creating...
|
||||
python -m venv .venv
|
||||
if errorlevel 1 (
|
||||
echo Error: Failed to create virtual environment
|
||||
echo Please ensure Python is installed and accessible
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo Virtual environment created successfully
|
||||
)
|
||||
|
||||
REM Activate virtual environment
|
||||
call .venv\Scripts\activate.bat
|
||||
if errorlevel 1 (
|
||||
echo Error: Failed to activate virtual environment
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo Virtual environment activated
|
||||
|
||||
REM Upgrade pip and install dependencies
|
||||
echo Checking dependencies...
|
||||
.venv\Scripts\python.exe -m pip install --upgrade pip -q
|
||||
.venv\Scripts\python.exe -m pip install -r requirements.txt -q
|
||||
if errorlevel 1 (
|
||||
echo Error: Failed to install dependencies
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo Dependencies installed successfully
|
||||
|
||||
REM Create logs directory if not exists
|
||||
if not exist "logs" mkdir logs
|
||||
|
||||
REM Start application in background using start command
|
||||
REM Create a temporary startup script with clean environment
|
||||
set TEMP_START_SCRIPT=%TEMP%\start_xian_app_%RANDOM%.bat
|
||||
(
|
||||
echo @echo off
|
||||
echo cd /d %CD%
|
||||
echo set ENVIRONMENT=development
|
||||
echo .venv\Scripts\python.exe start.py
|
||||
) > "%TEMP_START_SCRIPT%"
|
||||
|
||||
start "Xian Algorithm Dev" cmd /k "title Xian Algorithm Dev && call "%TEMP_START_SCRIPT%""
|
||||
|
||||
echo.
|
||||
echo Application started in background
|
||||
echo To view logs, check: logs\app_*.log
|
||||
echo To stop the application, run: scripts\stop.bat
|
||||
echo ========================================
|
||||
|
||||
REM Keep the window open briefly to show any immediate errors
|
||||
timeout /t 3 /nobreak >nul
|
||||
@@ -1,52 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Linux/Mac startup script - Development environment (background mode)
|
||||
|
||||
echo "========================================"
|
||||
echo " Starting Development Environment"
|
||||
echo "========================================"
|
||||
|
||||
# Change to project root directory
|
||||
cd "$(dirname "$0")/.." || exit
|
||||
|
||||
# Check and create virtual environment if not exists
|
||||
if [ ! -f ".venv/bin/activate" ]; then
|
||||
echo "Virtual environment not found, creating..."
|
||||
python3 -m venv .venv
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to create virtual environment"
|
||||
echo "Please ensure Python 3 is installed and accessible"
|
||||
exit 1
|
||||
fi
|
||||
echo "Virtual environment created successfully"
|
||||
fi
|
||||
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
echo "Virtual environment activated"
|
||||
|
||||
# Upgrade pip and install dependencies
|
||||
echo "Checking dependencies..."
|
||||
pip install --upgrade pip -q
|
||||
pip install -r requirements.txt -q
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to install dependencies"
|
||||
exit 1
|
||||
fi
|
||||
echo "Dependencies installed successfully"
|
||||
|
||||
# Set environment variable
|
||||
export ENVIRONMENT=development
|
||||
|
||||
# Create logs directory if not exists
|
||||
mkdir -p logs
|
||||
|
||||
# Start application in background
|
||||
nohup python start.py > logs/app_dev.log 2>&1 &
|
||||
APP_PID=$!
|
||||
|
||||
echo $APP_PID > scripts/app_dev.pid
|
||||
echo ""
|
||||
echo "Application started in background (PID: $APP_PID)"
|
||||
echo "To view logs: tail -f logs/app_dev.log"
|
||||
echo "To stop the application, run: bash scripts/stop.sh"
|
||||
echo "========================================"
|
||||
@@ -1,52 +0,0 @@
|
||||
@echo off
|
||||
REM Windows startup script - Production environment (background mode)
|
||||
echo ========================================
|
||||
echo Starting Production Environment
|
||||
echo ========================================
|
||||
|
||||
REM Change to project root directory
|
||||
cd /d "%~dp0.."
|
||||
|
||||
REM Check and create virtual environment if not exists
|
||||
if not exist ".venv\Scripts\activate.bat" (
|
||||
echo Virtual environment not found, creating...
|
||||
python -m venv .venv
|
||||
if errorlevel 1 (
|
||||
echo Error: Failed to create virtual environment
|
||||
echo Please ensure Python is installed and accessible
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo Virtual environment created successfully
|
||||
)
|
||||
|
||||
REM Activate virtual environment
|
||||
call .venv\Scripts\activate.bat
|
||||
echo Virtual environment activated
|
||||
|
||||
REM Upgrade pip and install dependencies
|
||||
echo Checking dependencies...
|
||||
.venv\Scripts\python.exe -m pip install --upgrade pip -q
|
||||
.venv\Scripts\python.exe -m pip install -r requirements.txt -q
|
||||
if errorlevel 1 (
|
||||
echo Error: Failed to install dependencies
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo Dependencies installed successfully
|
||||
|
||||
REM Set environment variable
|
||||
set ENVIRONMENT=production
|
||||
|
||||
REM Create logs directory if not exists
|
||||
if not exist "logs" mkdir logs
|
||||
|
||||
REM Start application in background using start command
|
||||
start "Xian Algorithm Prod" cmd /c "title Xian Algorithm Prod && .venv\Scripts\python.exe start.py"
|
||||
|
||||
echo.
|
||||
echo Application started in background
|
||||
echo To view logs, check: logs\app_*.log
|
||||
echo To stop the application, run: scripts\stop.bat
|
||||
echo ========================================
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Linux/Mac startup script - Production environment (background mode)
|
||||
|
||||
echo "========================================"
|
||||
echo " Starting Production Environment"
|
||||
echo "========================================"
|
||||
|
||||
# Change to project root directory
|
||||
cd "$(dirname "$0")/.." || exit
|
||||
|
||||
# Check and create virtual environment if not exists
|
||||
if [ ! -f ".venv/bin/activate" ]; then
|
||||
echo "Virtual environment not found, creating..."
|
||||
python3 -m venv .venv
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to create virtual environment"
|
||||
echo "Please ensure Python 3 is installed and accessible"
|
||||
exit 1
|
||||
fi
|
||||
echo "Virtual environment created successfully"
|
||||
fi
|
||||
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
echo "Virtual environment activated"
|
||||
|
||||
# Upgrade pip and install dependencies
|
||||
echo "Checking dependencies..."
|
||||
pip install --upgrade pip -q
|
||||
pip install -r requirements.txt -q
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to install dependencies"
|
||||
exit 1
|
||||
fi
|
||||
echo "Dependencies installed successfully"
|
||||
|
||||
# Set environment variable
|
||||
export ENVIRONMENT=production
|
||||
|
||||
# Create logs directory if not exists
|
||||
mkdir -p logs
|
||||
|
||||
# Start application in background
|
||||
nohup python start.py > logs/app_prod.log 2>&1 &
|
||||
APP_PID=$!
|
||||
|
||||
echo $APP_PID > scripts/app_prod.pid
|
||||
echo ""
|
||||
echo "Application started in background (PID: $APP_PID)"
|
||||
echo "To view logs: tail -f logs/app_prod.log"
|
||||
echo "To stop the application, run: bash scripts/stop.sh"
|
||||
echo "========================================"
|
||||
@@ -1,29 +0,0 @@
|
||||
@echo off
|
||||
REM Windows stop script - Stop the application
|
||||
echo ========================================
|
||||
echo Stopping Application
|
||||
echo ========================================
|
||||
|
||||
REM Change to project root directory
|
||||
cd /d "%~dp0.."
|
||||
|
||||
REM Find and kill python processes running start.py
|
||||
echo Searching for running application...
|
||||
tasklist /FI "IMAGENAME eq python.exe" /FO CSV /NH 2>nul | findstr /I "python.exe" >nul
|
||||
if %errorlevel% equ 0 (
|
||||
echo Found running Python processes, stopping...
|
||||
taskkill /F /FI "WINDOWTITLE eq Xian Algorithm Dev" 2>nul
|
||||
taskkill /F /FI "WINDOWTITLE eq Xian Algorithm Prod" 2>nul
|
||||
|
||||
REM If title-based kill didn't work, try killing all python.exe running start.py
|
||||
for /f "tokens=2 delims=," %%a in ('tasklist /FI "IMAGENAME eq python.exe" /FO CSV /NH 2^>nul') do (
|
||||
set "PID=%%~a"
|
||||
wmic process where "ProcessId=!PID! and CommandLine like '%%start.py%%'" delete >nul 2>&1
|
||||
)
|
||||
|
||||
echo Application stopped successfully
|
||||
) else (
|
||||
echo No running application found
|
||||
)
|
||||
|
||||
echo ========================================
|
||||
@@ -1,67 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Linux/Mac stop script - Stop the application
|
||||
|
||||
echo "========================================"
|
||||
echo " Stopping Application"
|
||||
echo "========================================"
|
||||
|
||||
# Change to project root directory
|
||||
cd "$(dirname "$0")/.." || exit
|
||||
|
||||
# Function to stop a process by PID file
|
||||
stop_process() {
|
||||
local pid_file=$1
|
||||
local env_name=$2
|
||||
|
||||
if [ -f "$pid_file" ]; then
|
||||
local pid=$(cat "$pid_file")
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Stopping $env_name (PID: $pid)..."
|
||||
kill "$pid"
|
||||
|
||||
# Wait for process to stop
|
||||
local count=0
|
||||
while kill -0 "$pid" 2>/dev/null && [ $count -lt 10 ]; do
|
||||
sleep 1
|
||||
count=$((count + 1))
|
||||
done
|
||||
|
||||
# Force kill if still running
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Force stopping $env_name..."
|
||||
kill -9 "$pid"
|
||||
fi
|
||||
|
||||
rm -f "$pid_file"
|
||||
echo "$env_name stopped"
|
||||
else
|
||||
echo "$env_name not running (stale PID file removed)"
|
||||
rm -f "$pid_file"
|
||||
fi
|
||||
else
|
||||
echo "No $env_name PID file found"
|
||||
fi
|
||||
}
|
||||
|
||||
# Stop development environment
|
||||
stop_process "scripts/app_dev.pid" "Development environment"
|
||||
|
||||
# Stop production environment
|
||||
stop_process "scripts/app_prod.pid" "Production environment"
|
||||
|
||||
# Also try to find any remaining python processes running start.py
|
||||
REMAINING_PIDS=$(ps aux | grep "[p]ython.*start.py" | awk '{print $2}')
|
||||
if [ -n "$REMAINING_PIDS" ]; then
|
||||
echo "Found additional Python processes, stopping..."
|
||||
echo "$REMAINING_PIDS" | xargs kill 2>/dev/null
|
||||
sleep 2
|
||||
|
||||
# Force kill if still running
|
||||
REMAINING_PIDS=$(ps aux | grep "[p]ython.*start.py" | awk '{print $2}')
|
||||
if [ -n "$REMAINING_PIDS" ]; then
|
||||
echo "$REMAINING_PIDS" | xargs kill -9 2>/dev/null
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "========================================"
|
||||
echo "All applications stopped"
|
||||
@@ -0,0 +1,34 @@
|
||||
# 公共配置
|
||||
[default]
|
||||
APP_NAME = "西安项目算法服务"
|
||||
LOG_DIR = "logs"
|
||||
|
||||
# 开发环境
|
||||
[development]
|
||||
DEBUG = true
|
||||
# 数据库
|
||||
DB_HOST = "47.92.216.173"
|
||||
DB_PORT = 7654
|
||||
DB_USER = "postgres"
|
||||
DB_PASSWORD = "zhangsan"
|
||||
DB_NAME = "xian_new"
|
||||
# FastAPI配置
|
||||
API_HOST = "127.0.0.1"
|
||||
API_PORT = 8082
|
||||
# 日志配置
|
||||
LOG_LEVEL = "DEBUG"
|
||||
|
||||
# 生产环境
|
||||
[production]
|
||||
DEBUG = false
|
||||
# 数据库配置
|
||||
DB_HOST = "10.22.245.138"
|
||||
DB_PORT = 54321
|
||||
DB_USER = "zaihailian"
|
||||
DB_PASSWORD = "XAYJ@gis2603"
|
||||
DB_NAME = "xianDC"
|
||||
# FastAPI配置
|
||||
API_HOST = "127.0.0.1"
|
||||
API_PORT = 8081
|
||||
# 日志配置
|
||||
LOG_LEVEL = "WARNING"
|
||||
@@ -1,230 +1,48 @@
|
||||
"""
|
||||
项目启动脚本 - 支持多环境和跨平台
|
||||
使用 Dynaconf 进行环境隔离配置
|
||||
"""
|
||||
import sys
|
||||
import subprocess
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def check_platform():
|
||||
"""检测平台信息"""
|
||||
print("=" * 60)
|
||||
print(" 系统信息检测")
|
||||
print("=" * 60)
|
||||
|
||||
from app.utils.platform_utils import platform_detector
|
||||
platform_detector.print_platform_banner()
|
||||
|
||||
return platform_detector
|
||||
def check_environment():
|
||||
"""检查系统和Python版本"""
|
||||
# 识别操作系统
|
||||
os_name = platform.system()
|
||||
print(f"当前操作系统: {os_name}")
|
||||
|
||||
if os_name not in ['Windows', 'Linux']:
|
||||
print(f"警告: 未测试的操作系统 {os_name},可能存在问题")
|
||||
|
||||
def check_python_version():
|
||||
"""检查Python版本是否为3.13或更高"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" Python版本检查")
|
||||
print("=" * 60)
|
||||
|
||||
from app.utils.platform_utils import platform_detector
|
||||
|
||||
current_version = sys.version_info
|
||||
print(f"当前Python版本: {current_version.major}.{current_version.minor}.{current_version.micro}")
|
||||
|
||||
if not platform_detector.check_python_version((3, 13)):
|
||||
print("\n❌ 错误: Python版本过低!")
|
||||
print(f" 当前版本: {current_version.major}.{current_version.minor}.{current_version.micro}")
|
||||
print(" 要求版本: 3.13 或更高")
|
||||
print("\n请升级到Python 3.13或更高版本:")
|
||||
print(" 下载地址: https://www.python.org/downloads/")
|
||||
print("=" * 60)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"✅ Python版本检查通过: {current_version.major}.{current_version.minor}.{current_version.micro}")
|
||||
print("=" * 60)
|
||||
return True
|
||||
# 检查Python版本
|
||||
python_version = platform.python_version()
|
||||
print(f"当前Python版本: {python_version}")
|
||||
|
||||
# 解析版本号
|
||||
major, minor, *_ = map(int, python_version.split('.'))
|
||||
|
||||
def get_environment():
|
||||
"""获取运行环境"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 环境配置")
|
||||
print("=" * 60)
|
||||
|
||||
environment = os.getenv("ENVIRONMENT", "development")
|
||||
print(f"当前环境: {environment}")
|
||||
|
||||
if environment not in ["development", "production"]:
|
||||
print(f"⚠️ 警告: 未知环境 '{environment}',使用默认开发环境")
|
||||
environment = "development"
|
||||
|
||||
print("=" * 60)
|
||||
return environment
|
||||
|
||||
|
||||
def install_dependencies():
|
||||
"""检查并安装依赖包"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 依赖包检查")
|
||||
print("=" * 60)
|
||||
|
||||
requirements_file = "requirements.txt"
|
||||
|
||||
if not os.path.exists(requirements_file):
|
||||
print(f"\n❌ 错误: 找不到依赖文件 {requirements_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# 读取已安装的包
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "list", "--format=freeze"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
installed_packages = {line.split("==")[0].lower() for line in result.stdout.strip().split("\n") if line}
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"\n❌ 检查已安装包失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 读取requirements.txt中的包
|
||||
with open(requirements_file, "r", encoding="utf-8") as f:
|
||||
required_packages = []
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
package_name = line.split("==")[0].lower()
|
||||
required_packages.append(line)
|
||||
|
||||
# 检查缺失的包
|
||||
missing_packages = []
|
||||
for package_line in required_packages:
|
||||
package_name = package_line.split("==")[0].lower()
|
||||
if package_name not in installed_packages:
|
||||
missing_packages.append(package_line)
|
||||
|
||||
if not missing_packages:
|
||||
print("\n✅ 所有依赖包已安装,无需重复安装")
|
||||
print("=" * 60)
|
||||
if major == 3 and minor == 13:
|
||||
print("✓ Python版本符合要求 (3.13)")
|
||||
return True
|
||||
|
||||
print(f"\n发现 {len(missing_packages)} 个未安装的依赖包:")
|
||||
for package in missing_packages:
|
||||
print(f" - {package}")
|
||||
|
||||
print("\n正在安装依赖包...")
|
||||
try:
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "-r", requirements_file],
|
||||
check=True
|
||||
)
|
||||
print("\n✅ 依赖包安装成功")
|
||||
print("=" * 60)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"\n❌ 依赖包安装失败: {e}")
|
||||
print("=" * 60)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""初始化数据库连接"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 数据库初始化")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from app.core.database import db_manager
|
||||
from app.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
print(f"数据库地址: {settings.DB_HOST}:{settings.DB_PORT}")
|
||||
print(f"数据库名称: {settings.DB_NAME}")
|
||||
print(f"连接池大小: {settings.DB_POOL_SIZE}")
|
||||
|
||||
# 测试数据库连接
|
||||
if db_manager.test_connection():
|
||||
print("✅ 数据库连接成功")
|
||||
else:
|
||||
print("⚠️ 数据库连接失败,请检查配置")
|
||||
return False
|
||||
|
||||
print("=" * 60)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ 数据库初始化警告: {e}")
|
||||
print(" 应用将继续启动,但数据库功能可能不可用")
|
||||
print("=" * 60)
|
||||
else:
|
||||
print(f"✗ Python版本不符合要求!")
|
||||
print(f" 当前版本: {python_version}")
|
||||
print(f" 要求版本: 3.13.x")
|
||||
print(f"\n请使用 Python 3.13 版本运行此项目")
|
||||
print(f"下载地址: https://www.python.org/downloads/")
|
||||
return False
|
||||
|
||||
|
||||
def start_application(environment: str):
|
||||
"""启动FastAPI应用"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 启动FastAPI应用")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from app.config.settings import get_settings
|
||||
import uvicorn
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
print(f"应用名称: {settings.APP_NAME}")
|
||||
print(f"应用版本: {settings.APP_VERSION}")
|
||||
print(f"运行环境: {settings.ENVIRONMENT.value}")
|
||||
print(f"监听地址: {settings.API_HOST}:{settings.API_PORT}")
|
||||
print(f"调试模式: {'开启' if settings.DEBUG else '关闭'}")
|
||||
print(f"自动重载: {'开启' if hasattr(settings, 'RELOAD') and settings.RELOAD else '关闭'}")
|
||||
print("\n🚀 应用启动中...\n")
|
||||
print("=" * 60)
|
||||
|
||||
# 启动uvicorn服务器
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.API_HOST,
|
||||
port=settings.API_PORT,
|
||||
reload=settings.RELOAD if hasattr(settings, 'RELOAD') else settings.DEBUG,
|
||||
log_level=settings.LOG_LEVEL.lower()
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n应用已停止")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 应用启动失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" Xian Algorithm New - 应用启动程序")
|
||||
print("=" * 60)
|
||||
|
||||
# 检测平台信息
|
||||
check_platform()
|
||||
|
||||
# 检查Python版本
|
||||
check_python_version()
|
||||
|
||||
# 获取环境配置
|
||||
environment = get_environment()
|
||||
|
||||
# 检查并安装依赖
|
||||
install_dependencies()
|
||||
|
||||
# 初始化数据库
|
||||
initialize_database()
|
||||
|
||||
# 启动应用
|
||||
start_application(environment)
|
||||
check_environment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user