diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5975cfe --- /dev/null +++ b/.gitignore @@ -0,0 +1,57 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# 虚拟环境 +.venv/ +venv/ +ENV/ +env/ + +# 环境配置文件(包含敏感信息) +.env.development +.env.production + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# 操作系统 +.DS_Store +Thumbs.db + +# 日志 +*.log +logs/ + +# PID文件 +scripts/*.pid + +# 测试和覆盖率 +.pytest_cache/ +.coverage +htmlcov/ + +# Jupyter Notebook +.ipynb_checkpoints diff --git a/README.md b/README.md index abdc6e7..531fa94 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,220 @@ -# xian_algorithm_new -西安项目算法服务器 +# Xian Algorithm New + +基于 FastAPI + PostgreSQL 的现代化 Python Web 应用框架。 + +## 特性 + +- ✅ 模块化架构(API/Config/Core/Utils) +- ✅ 多环境配置(开发/生产) +- ✅ 跨平台支持(Windows/Linux/Mac) +- ✅ 完整的数据库 CRUD 封装 +- ✅ 自动依赖管理 +- ✅ 自动生成 API 文档 +- ✅ 降雨栅格插值服务(IDW算法) + +## 快速开始 + +### 1. 环境要求 + +- Python 3.13+ +- PostgreSQL + +### 2. 配置 + +根据环境选择配置文件(无需复制,直接使用): + +- **开发环境**:`.env.development` +- **生产环境**:`.env.production` + +编辑对应的配置文件,修改数据库信息: + +```env +# .env.development 或 .env.production +DB_HOST=localhost +DB_PORT=5432 +DB_USER=postgres +DB_PASSWORD=your_password +DB_NAME=test_db +API_HOST=127.0.0.1 # 仅监听本地请求 +``` + +### 3. 启动 + +**后台运行(推荐):** + +```bash +# Windows - 开发环境 +scripts\start_dev.bat + +# Windows - 生产环境 +scripts\start_prod.bat + +# Linux/Mac - 开发环境 +bash scripts/start_dev.sh + +# Linux/Mac - 生产环境 +bash scripts/start_prod.sh +``` + +**前台运行(调试用):** + +```bash +python start.py +``` + +### 4. 停止 + +```bash +# Windows +scripts\stop.bat + +# Linux/Mac +bash scripts/stop.sh +``` + +### 5. 访问 + +- API 文档: http://localhost:8000/docs +- 健康检查: http://localhost:8000/health + +## 项目结构 + +``` +xian_algorithm_new/ +├── app/ +│ ├── api/ # API 路由 +│ ├── config/ # 配置管理 +│ ├── core/ # 核心功能 +│ ├── models/ # 数据模型 +│ ├── services/ # 业务逻辑 +│ └── utils/ # 工具函数 +├── scripts/ # 启动脚本 +├── logs/ # 日志目录 +├── tests/ # 测试目录 +├── start.py # 启动入口 +└── requirements.txt # 依赖包 +``` + +## 配置说明 + +配置采用三层结构(优先级从高到低): + +1. **.env 文件** - 用户自定义配置(数据库地址、密码等) +2. **环境配置类** - 开发/生产环境的差异化配置(日志级别、连接池等) +3. **基础配置类** - 通用默认值 + +只需修改 `.env` 文件即可覆盖大部分配置。 + +## 常用操作 + +### API接口 + +#### 1. 获取降雨栅格数据 + +```bash +curl -X POST "http://localhost:8000/rainfall/grid" \ + -H "Content-Type: application/json" \ + -d '{ + "start_time": "2024-01-01T00:00:00", + "end_time": "2024-01-01T12:00:00", + "district_id": 1, + "resolution": 0.01 + }' +``` + +#### 2. 获取雨量站点数据 + +```bash +curl "http://localhost:8000/rainfall/stations?start_time=2024-01-01T00:00:00&end_time=2024-01-01T12:00:00" +``` + +### 切换环境 + +通过启动脚本自动选择对应的配置文件: + +```bash +# Windows - 开发环境(后台) +scripts\start_dev.bat + +# Windows - 生产环境(后台) +scripts\start_prod.bat + +# Linux/Mac - 开发环境(后台) +bash scripts/start_dev.sh + +# Linux/Mac - 生产环境(后台) +bash scripts/start_prod.sh +``` + +**停止应用:** + +```bash +# Windows +scripts\stop.bat + +# Linux/Mac +bash scripts/stop.sh +``` + +或者手动设置环境变量: + +```bash +# Windows PowerShell +$env:ENVIRONMENT="production" +python start.py + +# Windows CMD +set ENVIRONMENT=production +python start.py + +# Linux/Mac +export ENVIRONMENT=production +python start.py +``` + +### 数据库操作 + +```python +from app.core.database import db_manager + +# 插入 +db_manager.insert("users", {"name": "张三", "email": "test@example.com"}) + +# 查询 +users = db_manager.select("users", conditions={"age": 25}) + +# 更新 +db_manager.update("users", {"name": "李四"}, {"id": 1}) + +# 删除 +db_manager.delete("users", {"id": 1}) +``` + +### 添加新 API + +1. 在 `app/api/` 创建路由文件 +2. 在 `app/main.py` 注册路由 + +```python +from app.api import new_module +app.include_router(new_module.router) +``` + +## 技术栈 + +- FastAPI 0.109.0 +- SQLAlchemy 2.0.25 +- PostgreSQL (psycopg2-binary) +- Pydantic 2.5.3 +- Uvicorn 0.27.0 + +## 注意事项 + +- ⚠️ 不要将 `.env.development` 和 `.env.production` 文件提交到 Git +- ⚠️ 生产环境务必修改默认密码 +- ⚠️ 定期清理 `logs/` 目录下的日志文件 +- ⚠️ Linux/Mac 下首次运行需给脚本添加执行权限:`chmod +x scripts/*.sh` + +## 许可证 + +MIT diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..da09317 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,3 @@ +""" +App package +""" diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..5e3fe99 --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1,3 @@ +""" +API routes package +""" diff --git a/app/api/rainfall.py b/app/api/rainfall.py new file mode 100644 index 0000000..7b9dfdf --- /dev/null +++ b/app/api/rainfall.py @@ -0,0 +1,158 @@ +""" +降雨数据API Controller - 路由层 +""" +from fastapi import APIRouter, HTTPException, Query +from datetime import datetime +from typing import Optional + +from app.services.rainfall_service import RainfallService +from app.schemas.rainfall import ( + RainfallGridRequest, + RainfallGridResponse, + StationsResponse +) +from app.utils.logger import setup_logging + +logger = setup_logging() + +router = APIRouter( + prefix="/rainfall", + tags=["降雨数据"], + responses={404: {"description": "Not found"}} +) + +# 创建服务实例 +rainfall_service = RainfallService() + + +@router.post("/grid", response_model=RainfallGridResponse, summary="获取降雨栅格数据") +async def get_rainfall_grid(request: RainfallGridRequest): + """ + 获取指定时间的降雨栅格数据 + + 使用反距离权重插值(IDW)方法,将站点降雨数据插值为连续栅格, + 返回适合Cesium渲染的GeoJSON格式数据。 + + Args: + request: 包含时间和分辨率的请求 + + Returns: + GeoJSON格式的栅格数据 + """ + try: + # 解析时间,如果未提供则使用当前时间 + now = datetime.now() + query_time = datetime.fromisoformat(request.time) if request.time else now + + # 调用服务层生成栅格(自动查询前12小时数据) + geojson_data = rainfall_service.generate_rainfall_grid( + query_time=query_time, + resolution=request.resolution + ) + + if not geojson_data: + return RainfallGridResponse( + code=404, + message="未找到降雨数据", + data=None + ) + + return RainfallGridResponse( + code=200, + message="降雨栅格数据生成成功", + data=geojson_data + ) + + except ValueError as e: + logger.error(f"时间格式错误: {e}") + raise HTTPException(status_code=400, detail=f"时间格式错误: {str(e)}") + except Exception as e: + logger.error(f"生成降雨栅格失败: {e}") + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=f"生成降雨栅格失败: {str(e)}") + + +@router.get("/stations", response_model=StationsResponse, summary="获取雨量站点数据") +async def get_rainfall_stations( + time: str = Query(..., description="查询时间 ISO格式(自动查询前12小时数据)") +): + """ + 获取指定时间的雨量站点原始数据 + + Args: + time: 查询时间 + + Returns: + 站点列表,包含经纬度和降雨量 + """ + try: + query_time = datetime.fromisoformat(time) + + # 调用服务层获取站点数据(自动查询前12小时数据) + stations = rainfall_service.get_stations_data( + query_time=query_time + ) + + return StationsResponse( + code=200, + message="查询成功", + data=stations + ) + + except ValueError as e: + logger.error(f"时间格式错误: {e}") + raise HTTPException(status_code=400, detail=f"时间格式错误: {str(e)}") + except Exception as e: + logger.error(f"查询站点数据失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/point", summary="查询指定点位的降雨量") +async def get_rainfall_at_point( + longitude: float, + latitude: float, + time: Optional[str] = None +): + """ + 查询指定经纬度位置的降雨量 + + Args: + longitude: 经度 + latitude: 纬度 + time: 查询时间(可选,默认当前时间,自动查询前12小时数据) + + Returns: + 该点位的降雨量信息 + """ + try: + from app.services.rainfall_service import RainfallService + + # 解析时间 + now = datetime.now() + query_time = datetime.fromisoformat(time) if time else now + + # 调用服务层查询(自动查询前12小时数据) + service = RainfallService() + rainfall_info = service.get_rainfall_at_point( + longitude=longitude, + latitude=latitude, + query_time=query_time + ) + + if not rainfall_info: + return { + "code": 404, + "message": "未找到该点位的降雨数据", + "data": None + } + + return { + "code": 200, + "message": "查询成功", + "data": rainfall_info + } + + except Exception as e: + logger.error(f"查询点位降雨量失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/app/config/__init__.py b/app/config/__init__.py new file mode 100644 index 0000000..4d059a0 --- /dev/null +++ b/app/config/__init__.py @@ -0,0 +1,3 @@ +""" +Configuration package +""" diff --git a/app/config/base_config.py b/app/config/base_config.py new file mode 100644 index 0000000..dae44e4 --- /dev/null +++ b/app/config/base_config.py @@ -0,0 +1,48 @@ +""" +基础配置类 +""" +from pydantic_settings import BaseSettings +from enum import Enum + + +class EnvironmentEnum(str, Enum): + """环境枚举""" + DEVELOPMENT = "development" + PRODUCTION = "production" + + +class BaseConfig(BaseSettings): + """基础配置类""" + + # 应用基本信息 + APP_NAME: str = "西安项目算法" + APP_VERSION: str = "1.0.0" + ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.DEVELOPMENT + + # 调试模式 + DEBUG: bool = True + + # API配置 + API_HOST: str = "127.0.0.1" # 默认只监听本地 + API_PORT: int = 8000 + + # CORS配置(默认只允许localhost) + CORS_ORIGINS: list = ["http://localhost", "http://127.0.0.1"] + + # 日志配置 + LOG_LEVEL: str = "INFO" + LOG_DIR: str = "logs" + + class Config: + env_file = ".env.development" # 默认使用开发环境配置 + case_sensitive = True + + @property + def is_development(self) -> bool: + """是否为开发环境""" + return self.ENVIRONMENT == EnvironmentEnum.DEVELOPMENT + + @property + def is_production(self) -> bool: + """是否为生产环境""" + return self.ENVIRONMENT == EnvironmentEnum.PRODUCTION diff --git a/app/config/database_config.py b/app/config/database_config.py new file mode 100644 index 0000000..4b2ec1e --- /dev/null +++ b/app/config/database_config.py @@ -0,0 +1,37 @@ +""" +数据库配置 +""" +from .base_config import BaseConfig + + +class DatabaseConfig(BaseConfig): + """数据库配置类""" + + # PostgreSQL配置 + DB_HOST: str = "localhost" + DB_PORT: int = 5432 + DB_USER: str = "postgres" + DB_PASSWORD: str = "postgres" + DB_NAME: str = "test_db" + + # 连接池配置 + DB_POOL_SIZE: int = 10 + DB_MAX_OVERFLOW: int = 20 + DB_POOL_TIMEOUT: int = 30 + DB_POOL_RECYCLE: int = 3600 + + @property + def DATABASE_URL(self) -> str: + """构建数据库连接URL""" + return ( + f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}" + f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" + ) + + @property + def ASYNC_DATABASE_URL(self) -> str: + """构建异步数据库连接URL""" + return ( + f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASSWORD}" + f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" + ) diff --git a/app/config/development.py b/app/config/development.py new file mode 100644 index 0000000..ee7b31e --- /dev/null +++ b/app/config/development.py @@ -0,0 +1,22 @@ +""" +开发环境配置 +""" +from .database_config import DatabaseConfig +from .base_config import EnvironmentEnum + + +class DevelopmentConfig(DatabaseConfig): + """开发环境配置 - 只覆盖与生产不同的配置""" + + ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.DEVELOPMENT + + # 调试和日志 + DEBUG: bool = True + LOG_LEVEL: str = "DEBUG" + + # 开发特性 + RELOAD: bool = True + + class Config: + env_file = ".env.development" + case_sensitive = True diff --git a/app/config/production.py b/app/config/production.py new file mode 100644 index 0000000..f348eb8 --- /dev/null +++ b/app/config/production.py @@ -0,0 +1,26 @@ +""" +生产环境配置 +""" +from .database_config import DatabaseConfig +from .base_config import EnvironmentEnum + + +class ProductionConfig(DatabaseConfig): + """生产环境配置 - 只覆盖性能和安全相关配置""" + + ENVIRONMENT: EnvironmentEnum = EnvironmentEnum.PRODUCTION + + # 调试和日志 + DEBUG: bool = False + LOG_LEVEL: str = "WARNING" + + # 生产特性 + RELOAD: bool = False + + # 数据库连接池(生产环境需要更大) + DB_POOL_SIZE: int = 20 + DB_MAX_OVERFLOW: int = 40 + + class Config: + env_file = ".env.production" + case_sensitive = True diff --git a/app/config/settings.py b/app/config/settings.py new file mode 100644 index 0000000..4b06f26 --- /dev/null +++ b/app/config/settings.py @@ -0,0 +1,65 @@ +""" +配置加载器 - 根据环境自动加载对应配置 +""" +import os +from typing import Type +from .base_config import BaseConfig, EnvironmentEnum +from .development import DevelopmentConfig +from .production import ProductionConfig + + +def get_config_class(environment: str = None) -> Type[BaseConfig]: + """根据环境获取配置类 + + Args: + environment: 环境名称 (development/production) + + Returns: + 对应的配置类 + """ + if environment is None: + environment = os.getenv("ENVIRONMENT", "development") + + config_map = { + EnvironmentEnum.DEVELOPMENT: DevelopmentConfig, + EnvironmentEnum.PRODUCTION: ProductionConfig, + } + + try: + env_enum = EnvironmentEnum(environment) + return config_map[env_enum] + except ValueError: + print(f"警告: 未知环境 '{environment}',使用默认开发环境配置") + return DevelopmentConfig + + +def load_config(environment: str = None) -> BaseConfig: + """加载配置 + + Args: + environment: 环境名称 + + Returns: + 配置实例 + """ + config_class = get_config_class(environment) + return config_class() + + +# 全局配置实例(延迟加载) +_config_instance = None + + +def get_settings() -> BaseConfig: + """获取全局配置实例(单例模式)""" + global _config_instance + if _config_instance is None: + environment = os.getenv("ENVIRONMENT", None) + _config_instance = load_config(environment) + return _config_instance + + +def reload_config(environment: str = None): + """重新加载配置""" + global _config_instance + _config_instance = load_config(environment) diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..5452f80 --- /dev/null +++ b/app/core/__init__.py @@ -0,0 +1,3 @@ +""" +Core functionality package +""" diff --git a/app/core/database.py b/app/core/database.py new file mode 100644 index 0000000..cdb98d0 --- /dev/null +++ b/app/core/database.py @@ -0,0 +1,255 @@ +""" +数据库连接管理 - 使用SQLAlchemy 2.0 +""" +import logging +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker, Session, DeclarativeBase +from sqlalchemy.pool import QueuePool +from typing import List, Dict, Any +from contextlib import contextmanager + +from app.config.settings import get_settings +from app.utils.logger import setup_logging + +# 初始化日志 +logger = setup_logging() + +# 关闭SQLAlchemy引擎日志,只保留关键信息 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) +logging.getLogger('sqlalchemy.pool').setLevel(logging.WARNING) + +# 获取配置 +settings = get_settings() + +# 创建数据库引擎 +engine = create_engine( + settings.DATABASE_URL, + poolclass=QueuePool, + pool_size=settings.DB_POOL_SIZE, + max_overflow=settings.DB_MAX_OVERFLOW, + pool_timeout=settings.DB_POOL_TIMEOUT, + pool_recycle=settings.DB_POOL_RECYCLE, + pool_pre_ping=True, + echo=False, # 关闭SQL语句打印 + connect_args={ + "options": "-c client_encoding=UTF8" + } +) + +# 创建会话工厂 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 声明基类 +Base = DeclarativeBase() + + +class DatabaseManager: + """数据库管理器 - 提供通用的CRUD操作""" + + def __init__(self): + self.engine = engine + self.SessionLocal = SessionLocal + + @contextmanager + def get_session(self) -> Session: + """获取数据库会话的上下文管理器""" + session = self.SessionLocal() + try: + yield session + session.commit() + except Exception as e: + session.rollback() + logger.error(f"数据库操作失败: {e}") + raise + finally: + session.close() + + def init_db(self): + """初始化数据库,创建所有表""" + try: + Base.metadata.create_all(bind=self.engine) + logger.info("数据库表创建成功") + except Exception as e: + logger.error(f"数据库表创建失败: {e}") + raise + + def test_connection(self) -> bool: + """测试数据库连接""" + try: + with self.get_session() as session: + session.execute(text("SELECT 1")) + logger.info("数据库连接测试成功") + return True + except Exception as e: + logger.error(f"数据库连接测试失败: {e}") + return False + + def execute_raw_sql(self, sql: str, params: dict = None) -> List[Dict[str, Any]]: + """执行原生SQL查询 + + Args: + sql: SQL语句 + params: 参数字典 + + Returns: + 查询结果列表 + """ + with self.get_session() as session: + result = session.execute(text(sql), params or {}) + if result.returns_rows: + columns = result.keys() + return [dict(zip(columns, row)) for row in result.fetchall()] + return [] + + def insert(self, table_name: str, data: Dict[str, Any]) -> int: + """插入单条记录 + + Args: + table_name: 表名 + data: 要插入的数据字典 + + Returns: + 插入的行数 + """ + columns = ", ".join(data.keys()) + placeholders = ", ".join([f":{key}" for key in data.keys()]) + sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})" + + with self.get_session() as session: + result = session.execute(text(sql), data) + return result.rowcount + + def insert_many(self, table_name: str, data_list: List[Dict[str, Any]]) -> int: + """批量插入记录 + + Args: + table_name: 表名 + data_list: 数据字典列表 + + Returns: + 插入的行数 + """ + if not data_list: + return 0 + + columns = ", ".join(data_list[0].keys()) + placeholders = ", ".join([f":{key}" for key in data_list[0].keys()]) + sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})" + + with self.get_session() as session: + result = session.execute(text(sql), data_list) + return result.rowcount + + def select( + self, + table_name: str, + conditions: Dict[str, Any] = None, + columns: List[str] = None, + limit: int = None, + offset: int = None, + order_by: str = None + ) -> List[Dict[str, Any]]: + """查询记录 + + Args: + table_name: 表名 + conditions: 查询条件字典 + columns: 要查询的列列表,None表示查询所有列 + limit: 限制返回行数 + offset: 偏移量 + order_by: 排序字段 + + Returns: + 查询结果列表 + """ + col_str = ", ".join(columns) if columns else "*" + sql = f"SELECT {col_str} FROM {table_name}" + + params = {} + if conditions: + where_clauses = [] + for key, value in conditions.items(): + where_clauses.append(f"{key} = :{key}") + params[key] = value + sql += " WHERE " + " AND ".join(where_clauses) + + if order_by: + sql += f" ORDER BY {order_by}" + + if limit: + sql += f" LIMIT :limit" + params["limit"] = limit + + if offset: + sql += f" OFFSET :offset" + params["offset"] = offset + + return self.execute_raw_sql(sql, params) + + def update(self, table_name: str, data: Dict[str, Any], + conditions: Dict[str, Any]) -> int: + """更新记录 + + Args: + table_name: 表名 + data: 要更新的数据字典 + conditions: 更新条件字典 + + Returns: + 更新的行数 + """ + set_clauses = [f"{key} = :{key}" for key in data.keys()] + where_clauses = [f"{key} = :cond_{key}" for key in conditions.keys()] + + sql = f"UPDATE {table_name} SET {', '.join(set_clauses)} WHERE {' AND '.join(where_clauses)}" + + # 合并参数,避免键名冲突 + params = {**data, **{f"cond_{key}": value for key, value in conditions.items()}} + + with self.get_session() as session: + result = session.execute(text(sql), params) + return result.rowcount + + def delete(self, table_name: str, conditions: Dict[str, Any]) -> int: + """删除记录 + + Args: + table_name: 表名 + conditions: 删除条件字典 + + Returns: + 删除的行数 + """ + where_clauses = [f"{key} = :{key}" for key in conditions.keys()] + sql = f"DELETE FROM {table_name} WHERE {' AND '.join(where_clauses)}" + + with self.get_session() as session: + result = session.execute(text(sql), conditions) + return result.rowcount + + def count(self, table_name: str, conditions: Dict[str, Any] = None) -> int: + """统计记录数 + + Args: + table_name: 表名 + conditions: 统计条件字典 + + Returns: + 记录数 + """ + sql = f"SELECT COUNT(*) as count FROM {table_name}" + params = {} + + if conditions: + where_clauses = [] + for key, value in conditions.items(): + where_clauses.append(f"{key} = :{key}") + params[key] = value + sql += " WHERE " + " AND ".join(where_clauses) + + result = self.execute_raw_sql(sql, params) + return result[0]["count"] if result else 0 + + +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..3a7ce3b --- /dev/null +++ b/app/main.py @@ -0,0 +1,110 @@ +""" +FastAPI应用主文件 +""" +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app.config.settings import get_settings +from app.core.database import db_manager +from app.utils.logger import setup_logging +from app.api.rainfall import router as rainfall_router + +# 初始化日志 +logger = setup_logging() + +# 获取配置 +settings = get_settings() + + +def create_application() -> FastAPI: + """创建FastAPI应用实例""" + + application = FastAPI( + title=settings.APP_NAME, + version=settings.APP_VERSION, + debug=settings.DEBUG, + description="基于FastAPI的现代化Web应用框架" + ) + + # 配置CORS + application.add_middleware( + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS if hasattr(settings, 'CORS_ORIGINS') else ["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + return application + + +# 创建应用实例 +app = create_application() + + +# ==================== 启动和关闭事件 ==================== + +@app.on_event("startup") +async def startup_event(): + """应用启动时执行""" + logger.info(f"正在启动 {settings.APP_NAME} v{settings.APP_VERSION}") + logger.info(f"环境: {settings.ENVIRONMENT}") + logger.info(f"数据库连接: {settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}") + + # 测试数据库连接 + if db_manager.test_connection(): + logger.info("数据库连接成功") + else: + logger.warning("数据库连接失败,请检查配置") + + +@app.on_event("shutdown") +async def shutdown_event(): + """应用关闭时执行""" + logger.info("应用正在关闭...") + + +# ==================== 根路径和健康检查 ==================== + +@app.get("/", tags=["基础"]) +async def root(): + """根路径 - 欢迎信息""" + return { + "app": settings.APP_NAME, + "version": settings.APP_VERSION, + "environment": settings.ENVIRONMENT.value, + "status": "running" + } + + +@app.get("/health", tags=["基础"]) +async def health_check(): + """健康检查接口""" + try: + # 检查数据库连接 + db_manager.execute_raw_sql("SELECT 1") + db_status = "connected" + except Exception as e: + db_status = f"error: {str(e)}" + + return { + "status": "healthy", + "database": db_status, + "app_version": settings.APP_VERSION, + "environment": settings.ENVIRONMENT.value + } + + +# ==================== 注册路由 ==================== + +app.include_router(rainfall_router) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app.main:app", + host=settings.API_HOST, + port=settings.API_PORT, + reload=settings.RELOAD if hasattr(settings, 'RELOAD') else settings.DEBUG + ) diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..0e08b02 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,3 @@ +""" +Database models package +""" diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py new file mode 100644 index 0000000..a92695f --- /dev/null +++ b/app/repositories/__init__.py @@ -0,0 +1,3 @@ +""" +Repositories package - 数据访问层 +""" diff --git a/app/repositories/rainfall_repository.py b/app/repositories/rainfall_repository.py new file mode 100644 index 0000000..aabd483 --- /dev/null +++ b/app/repositories/rainfall_repository.py @@ -0,0 +1,54 @@ +""" +降雨数据Repository - 数据访问层 +""" +from typing import List, Dict, Any +from datetime import datetime + +from app.core.database import db_manager +from app.utils.logger import setup_logging + +logger = setup_logging() + + +class RainfallRepository: + """降雨数据仓储类""" + + @staticmethod + def query_stations_rainfall( + query_time: datetime + ) -> List[Dict[str, Any]]: + """ + 查询指定时间的站点降雨数据(自动查询前12小时) + + Args: + query_time: 查询时间 + + Returns: + 站点降雨数据列表 + """ + sql = """ + SELECT + m.lon, + m.lat, + SUM(m.rainfall_1h::numeric) AS rainfall + FROM xian_meteorology m + WHERE m.datetime BETWEEN ( + to_char(timestamp :query_time - interval '12 hours', 'YYYYMMDDHH24MISS') + )::bigint AND ( + to_char(timestamp :query_time, 'YYYYMMDDHH24MISS') + )::bigint + GROUP BY m.lon, m.lat + ORDER BY rainfall DESC + """ + + params = { + "query_time": query_time + } + + try: + result = db_manager.execute_raw_sql(sql, params) + logger.info(f"查询到 {len(result)} 个站点数据") + return result + except Exception as e: + logger.error(f"查询站点降雨数据失败: {e}") + raise diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..7759195 --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1,3 @@ +""" +Schemas package - Pydantic数据模型 +""" diff --git a/app/schemas/rainfall.py b/app/schemas/rainfall.py new file mode 100644 index 0000000..eae9f8f --- /dev/null +++ b/app/schemas/rainfall.py @@ -0,0 +1,57 @@ +""" +降雨数据相关的Pydantic Schemas +""" +from pydantic import BaseModel, Field +from typing import Optional, List + + +class RainfallGridRequest(BaseModel): + """降雨栅格请求模型""" + time: Optional[str] = Field( + None, + alias="time", + description="查询时间 ISO格式,默认为当前时间(自动查询前12小时数据)", + example="2024-01-01T12:00:00" + ) + resolution: float = Field( + 0.01, + alias="resolution", + description="栅格分辨率(度)", + gt=0, + le=0.1 + ) + + class Config: + populate_by_name = True # 允许同时使用字段名和别名 + + +class StationData(BaseModel): + """站点数据模型""" + lon: float + lat: float + rainfall: float + + +class GridMetadata(BaseModel): + """栅格元数据""" + start_time: str + end_time: str + district_id: int + resolution: float + station_count: int + grid_size: List[int] + bounds: dict + + +class RainfallGridResponse(BaseModel): + """降雨栅格响应模型 - 符合前端 ApiResponse 结构""" + code: int = Field(200, description="状态码") + message: str = Field(..., description="响应消息") + data: Optional[dict] = Field(None, description="响应数据") + + +class StationsResponse(BaseModel): + """站点数据响应模型 - 符合前端 ApiResponse 结构""" + code: int = Field(200, description="状态码") + message: str = Field(..., description="响应消息") + data: List[StationData] = Field(default_factory=list, description="站点数据列表") diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..3b44cdf --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,3 @@ +""" +Business logic services package +""" diff --git a/app/services/rainfall_service.py b/app/services/rainfall_service.py new file mode 100644 index 0000000..9b343dc --- /dev/null +++ b/app/services/rainfall_service.py @@ -0,0 +1,381 @@ +""" +降雨数据Service - 业务逻辑层 +""" +import numpy as np +from scipy.spatial import Delaunay, ConvexHull +from typing import List, Dict, Any, Tuple, Optional +from datetime import datetime + +from app.repositories.rainfall_repository import RainfallRepository +from app.utils.logger import setup_logging + +logger = setup_logging() + + +class InterpolationService: + """插值服务类""" + + @staticmethod + def inverse_distance_weighting( + points: List[Tuple[float, float]], + values: List[float], + grid_lon: np.ndarray, + grid_lat: np.ndarray, + power: float = 2.0, + max_distance: float = 0.5, + edge_buffer: float = 0.15 + ) -> np.ndarray: + """ + 反距离权重插值 (IDW) - 向量化优化版本 + + Args: + points: 已知点坐标 [(lon, lat), ...] + values: 已知点的值 [rainfall, ...] + grid_lon: 网格经度数组 + grid_lat: 网格纬度数组 + power: 距离幂次 + max_distance: 最大影响距离(度),超出此距离的点不参与插值 + edge_buffer: 边缘缓冲距离,站点外围扩展此距离再计算凸包 + + Returns: + 插值后的栅格数据,无效区域为 NaN + """ + points_array = np.array(points) + values_array = np.array(values) + + # 创建网格 + lon_grid, lat_grid = np.meshgrid(grid_lon, grid_lat) + result = np.full_like(lon_grid, np.nan) + + # 计算站点的凸包(带边缘缓冲) + hull_mask = None + if len(points_array) >= 3: + try: + # 创建缓冲站点:在原始站点外围添加虚拟点 + buffer_points = InterpolationService._create_buffer_points( + points_array, + buffer_distance=edge_buffer + ) + + # 合并原始站点和缓冲站点 + all_points = np.vstack([points_array, buffer_points]) + + # 计算凸包 + hull = ConvexHull(all_points) + hull_points = all_points[hull.vertices] + tri = Delaunay(hull_points) + + # 向量化判断所有网格点是否在凸包内 + grid_points = np.column_stack([lon_grid.ravel(), lat_grid.ravel()]) + hull_mask = tri.find_simplex(grid_points) >= 0 + hull_mask = hull_mask.reshape(lon_grid.shape) + except: + hull_mask = np.ones_like(lon_grid, dtype=bool) + else: + hull_mask = np.ones_like(lon_grid, dtype=bool) + + # 向量化计算所有网格点到所有站点的距离 + # grid_lon shape: (num_lat, num_lon) + # points_array[:, 0] shape: (num_stations,) + # 使用广播机制 + lon_diff = lon_grid[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 0] + lat_diff = lat_grid[:, :, np.newaxis] - points_array[np.newaxis, np.newaxis, :, 1] + distances = np.sqrt(lon_diff**2 + lat_diff**2) + + # 过滤超出最大距离的站点 + valid_mask = distances <= max_distance + + # 对于每个网格点,检查是否有有效站点 + has_valid_stations = np.any(valid_mask, axis=2) + + # 合并凸包掩码和有效站点掩码 + final_mask = hull_mask & has_valid_stations + + # 避免除零 + distances = np.where(valid_mask, distances, np.inf) + distances = np.maximum(distances, 1e-10) + + # IDW权重计算 + weights = 1.0 / (distances ** power) + weights = np.where(valid_mask, weights, 0) + + # 加权平均 + weighted_sum = np.sum(weights * values_array[np.newaxis, np.newaxis, :], axis=2) + weight_total = np.sum(weights, axis=2) + + # 计算最终结果 + result = np.where( + final_mask & (weight_total > 0), + weighted_sum / weight_total, + np.nan + ) + + return result + + @staticmethod + def get_rainfall_color(rainfall: float) -> str: + """ + 根据降雨量获取颜色(蓝色渐变) + + Args: + rainfall: 降雨量(mm) + + Returns: + 颜色字符串 "rgba(r,g,b,a)" + """ + if rainfall < 0.1: + return "rgba(200,200,200,0)" # 透明 - 无雨 + elif rainfall < 10: + return "rgba(173,216,230,0.5)" # 浅蓝 - 小雨 + elif rainfall < 25: + return "rgba(100,149,237,0.6)" # 矢车菊蓝 - 中雨 + elif rainfall < 50: + return "rgba(30,144,255,0.7)" # 道奇蓝 - 大雨 + elif rainfall < 100: + return "rgba(0,0,205,0.8)" # 中蓝 - 暴雨 + else: + return "rgba(0,0,139,0.9)" # 深蓝 - 大暴雨 + + +class GeoJSONService: + """GeoJSON生成服务""" + + @staticmethod + def create_feature_collection( + grid_metadata: Dict[str, Any], + rainfall_array: np.ndarray, + grid_lon: np.ndarray, + grid_lat: np.ndarray + ) -> Dict[str, Any]: + """ + 创建GeoJSON FeatureCollection用于Cesium渲染 + + Args: + grid_metadata: 栅格元数据 + rainfall_array: 降雨量数组 + grid_lon: 经度网格 + grid_lat: 纬度网格 + + Returns: + GeoJSON格式的FeatureCollection + """ + features = [] + + # 将栅格数据转换为矩形要素 + for i in range(len(grid_lat) - 1): + for j in range(len(grid_lon) - 1): + rainfall_value = float(rainfall_array[i, j]) + + # 跳过无数据的区域 + if np.isnan(rainfall_value) or rainfall_value < 0: + continue + + # 创建矩形多边形 + lon_min = float(grid_lon[j]) + lon_max = float(grid_lon[j + 1]) + lat_min = float(grid_lat[i]) + lat_max = float(grid_lat[i + 1]) + + feature = { + "type": "Feature", + "geometry": { + "type": "Polygon", + "coordinates": [[ + [lon_min, lat_min], + [lon_max, lat_min], + [lon_max, lat_max], + [lon_min, lat_max], + [lon_min, lat_min] + ]] + }, + "properties": { + "rainfall": round(rainfall_value, 2), + "color": InterpolationService.get_rainfall_color(rainfall_value) + } + } + features.append(feature) + + return { + "type": "FeatureCollection", + "features": features, + "metadata": { + "resolution": grid_metadata['resolution'], + "grid_size": [len(grid_lon), len(grid_lat)], + "bounds": { + "min_lon": float(grid_lon.min()), + "max_lon": float(grid_lon.max()), + "min_lat": float(grid_lat.min()), + "max_lat": float(grid_lat.max()) + } + } + } + + +class RainfallService: + """降雨数据业务服务类""" + + def __init__(self): + self.repository = RainfallRepository() + self.interpolation_service = InterpolationService() + self.geojson_service = GeoJSONService() + + def get_stations_data( + self, + query_time: datetime + ) -> List[Dict[str, Any]]: + """ + 获取站点降雨数据 + + Args: + query_time: 查询时间(自动查询前12小时数据) + + Returns: + 站点数据列表 + """ + return self.repository.query_stations_rainfall(query_time) + + def generate_rainfall_grid( + self, + query_time: datetime, + resolution: float = 0.01 + ) -> Dict[str, Any]: + """ + 生成降雨栅格数据 + + Args: + query_time: 查询时间(自动查询前12小时数据) + resolution: 栅格分辨率 + + Returns: + GeoJSON格式的栅格数据 + """ + logger.info(f"查询降雨数据: {query_time}") + + # 查询站点数据(自动查询前12小时数据) + stations_data = self.get_stations_data(query_time) + + if not stations_data: + return None + + # 提取站点坐标和降雨量(过滤空值) + valid_stations = [row for row in stations_data if row['rainfall'] is not None] + + if not valid_stations: + logger.warning("所有站点的降雨量数据均为空") + return None + + points = [(row['lon'], row['lat']) for row in valid_stations] + values = [float(row['rainfall']) for row in valid_stations] + + # 确定栅格范围(西安大致范围) + lon_min, lon_max = 107.5, 109.5 + lat_min, lat_max = 33.5, 34.5 + + # 创建栅格网格 + num_lon = int((lon_max - lon_min) / resolution) + 1 + num_lat = int((lat_max - lat_min) / resolution) + 1 + + grid_lon = np.linspace(lon_min, lon_max, num_lon) + grid_lat = np.linspace(lat_min, lat_max, num_lat) + + logger.info(f"生成栅格: {num_lon}x{num_lat}, 分辨率: {resolution}") + + # 执行IDW插值(带凸包裁剪和距离阈值) + rainfall_grid = self.interpolation_service.inverse_distance_weighting( + points=points, + values=values, + grid_lon=grid_lon, + grid_lat=grid_lat, + power=2.0, + max_distance=0.3 # 最大影响距离0.3度(约30公里) + ) + + # 创建栅格元数据 + grid_metadata = { + "query_time": query_time.isoformat(), + "resolution": resolution, + "station_count": len(stations_data), + "grid_size": [num_lon, num_lat] + } + + # 转换为GeoJSON格式 + geojson_data = self.geojson_service.create_feature_collection( + grid_metadata, rainfall_grid, grid_lon, grid_lat + ) + + logger.info("降雨栅格数据生成成功") + + return geojson_data + + def get_rainfall_at_point( + self, + longitude: float, + latitude: float, + query_time: datetime + ) -> Optional[Dict[str, Any]]: + """ + 查询指定点位的降雨量(使用IDW插值) + + Args: + longitude: 经度 + latitude: 纬度 + query_time: 查询时间(自动查询前12小时数据) + + Returns: + 点位降雨量信息 + """ + # 获取站点数据(自动查询前12小时数据) + stations_data = self.get_stations_data(query_time) + + if not stations_data: + return None + + # 提取站点坐标和降雨量 + points = [(row['lon'], row['lat']) for row in stations_data] + values = [float(row['rainfall']) for row in stations_data] + + # 使用IDW插值计算该点的降雨量 + target_point = np.array([[longitude, latitude]]) + points_array = np.array(points) + + # 计算距离 + distances = np.sqrt( + (points_array[:, 0] - longitude) ** 2 + + (points_array[:, 1] - latitude) ** 2 + ) + + # 避免除零 + min_dist = 1e-10 + distances = np.maximum(distances, min_dist) + + # IDW公式 + power = 2.0 + weights = 1.0 / (distances ** power) + rainfall_value = np.sum(weights * values) / np.sum(weights) + + # 返回结果 + return { + "longitude": longitude, + "latitude": latitude, + "rainfall": round(float(rainfall_value), 2), + "level": self._get_rainfall_level(rainfall_value), + "color": InterpolationService.get_rainfall_color(rainfall_value), + "station_count": len(stations_data), + "query_time": query_time.isoformat() + } + + @staticmethod + def _get_rainfall_level(rainfall: float) -> str: + """获取降雨等级""" + if rainfall < 0.1: + return "无雨" + elif rainfall < 10: + return "小雨" + elif rainfall < 25: + return "中雨" + elif rainfall < 50: + return "大雨" + elif rainfall < 100: + return "暴雨" + else: + return "大暴雨" diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..63b5914 --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,3 @@ +""" +Utility functions package +""" diff --git a/app/utils/logger.py b/app/utils/logger.py new file mode 100644 index 0000000..2143f34 --- /dev/null +++ b/app/utils/logger.py @@ -0,0 +1,104 @@ +""" +日志配置工具 +""" +import logging +import sys +from pathlib import Path +from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler +from datetime import datetime + + +class LoggerConfig: + """日志配置类""" + + def __init__(self, log_dir: str = "logs", log_level: str = "INFO"): + self.log_dir = Path(log_dir) + self.log_level = getattr(logging, log_level.upper(), logging.INFO) + self.ensure_log_directory() + + def ensure_log_directory(self): + """确保日志目录存在""" + if not self.log_dir.exists(): + self.log_dir.mkdir(parents=True, exist_ok=True) + + def setup_logger( + self, + name: str = "app", + log_file: str = None, + max_bytes: int = 10 * 1024 * 1024, # 10MB + backup_count: int = 5, + use_timed_rotation: bool = False + ) -> logging.Logger: + """配置日志记录器 + + Args: + name: 日志记录器名称 + log_file: 日志文件名(None则使用时间戳) + max_bytes: 单个日志文件最大大小 + backup_count: 保留的备份文件数量 + use_timed_rotation: 是否使用按时间轮转 + + Returns: + 配置好的Logger实例 + """ + logger = logging.getLogger(name) + logger.setLevel(self.log_level) + + # 避免重复添加handler + if logger.handlers: + return logger + + # 创建formatter + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # 控制台handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(self.log_level) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # 文件handler + if log_file is None: + timestamp = datetime.now().strftime("%Y%m%d") + log_file = f"app_{timestamp}.log" + + log_path = self.log_dir / log_file + + if use_timed_rotation: + # 按时间轮转(每天) + file_handler = TimedRotatingFileHandler( + log_path, + when='midnight', + interval=1, + backupCount=backup_count, + encoding='utf-8' + ) + else: + # 按大小轮转 + file_handler = RotatingFileHandler( + log_path, + maxBytes=max_bytes, + backupCount=backup_count, + encoding='utf-8' + ) + + file_handler.setLevel(self.log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + @staticmethod + def get_logger(name: str = "app") -> logging.Logger: + """获取logger实例""" + return logging.getLogger(name) + + +# 全局日志配置 +def setup_logging(log_dir: str = "logs", log_level: str = "INFO") -> logging.Logger: + """设置全局日志""" + logger_config = LoggerConfig(log_dir, log_level) + return logger_config.setup_logger() diff --git a/app/utils/platform_utils.py b/app/utils/platform_utils.py new file mode 100644 index 0000000..da0eb49 --- /dev/null +++ b/app/utils/platform_utils.py @@ -0,0 +1,137 @@ +""" +平台检测工具 - 兼容Windows和Linux +""" +import sys +import os +import platform +from pathlib import Path +from typing import Tuple + + +class PlatformDetector: + """平台检测器""" + + def __init__(self): + self.os_name = os.name + self.platform_system = platform.system() + self.platform_release = platform.release() + self.platform_version = platform.version() + self.python_version = sys.version_info + self.is_windows = self.platform_system == "Windows" + self.is_linux = self.platform_system == "Linux" + self.is_macos = self.platform_system == "Darwin" + + def get_platform_info(self) -> dict: + """获取平台信息""" + return { + "os_name": self.os_name, + "platform_system": self.platform_system, + "platform_release": self.platform_release, + "platform_version": self.platform_version, + "python_version": f"{self.python_version.major}.{self.python_version.minor}.{self.python_version.micro}", + "is_windows": self.is_windows, + "is_linux": self.is_linux, + "is_macos": self.is_macos, + } + + def check_python_version(self, min_version: Tuple[int, int] = (3, 13)) -> bool: + """检查Python版本是否满足要求 + + Args: + min_version: 最低版本要求 (major, minor) + + Returns: + 是否满足版本要求 + """ + current = (self.python_version.major, self.python_version.minor) + return current >= min_version + + def get_path_separator(self) -> str: + """获取路径分隔符""" + return "\\" if self.is_windows else "/" + + def normalize_path(self, path: str) -> str: + """标准化路径""" + return Path(path).as_posix() + + def get_project_root(self) -> Path: + """获取项目根目录""" + return Path(__file__).parent.parent.parent + + def ensure_directory(self, path: Path, create: bool = True) -> bool: + """确保目录存在 + + Args: + path: 目录路径 + create: 是否自动创建 + + Returns: + 目录是否存在 + """ + if path.exists(): + return True + + if create: + try: + path.mkdir(parents=True, exist_ok=True) + return True + except Exception as e: + print(f"创建目录失败 {path}: {e}") + return False + + return False + + def get_env_file_path(self) -> Path: + """获取环境配置文件路径""" + return self.get_project_root() / ".env" + + def format_command_for_platform(self, command: str) -> str: + """根据平台格式化命令 + + Args: + command: 原始命令 + + Returns: + 适合当前平台的命令 + """ + if self.is_windows: + # Windows特定命令转换 + if command.startswith("ls "): + return command.replace("ls ", "dir ") + elif command.startswith("cat "): + return command.replace("cat ", "type ") + elif command.startswith("rm "): + return command.replace("rm ", "del ") + elif command.startswith("cp "): + return command.replace("cp ", "copy ") + elif command.startswith("mv "): + return command.replace("mv ", "move ") + else: + # Linux/Mac特定命令转换 + if command.startswith("dir "): + return command.replace("dir ", "ls ") + elif command.startswith("type "): + return command.replace("type ", "cat ") + elif command.startswith("del "): + return command.replace("del ", "rm ") + elif command.startswith("copy "): + return command.replace("copy ", "cp ") + elif command.startswith("move "): + return command.replace("move ", "mv ") + + return command + + def print_platform_banner(self): + """打印平台信息横幅""" + info = self.get_platform_info() + print("=" * 60) + print(" 系统信息") + print("=" * 60) + print(f" 操作系统: {info['platform_system']} {info['platform_release']}") + print(f" Python版本: {info['python_version']}") + print(f" 项目根目录: {self.get_project_root()}") + print("=" * 60) + + +# 全局平台检测器实例 +platform_detector = PlatformDetector() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0052252 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +fastapi==0.136.1 +uvicorn==0.46.0 +sqlalchemy==2.0.49 +psycopg2-binary==2.9.12 +pydantic==2.13.3 +pydantic-settings==2.14.0 +numpy==2.4.4 +scipy==1.17.1 diff --git a/scripts/start_dev.bat b/scripts/start_dev.bat new file mode 100644 index 0000000..8f83293 --- /dev/null +++ b/scripts/start_dev.bat @@ -0,0 +1,74 @@ +@echo off +REM Windows startup script - Development environment (background mode) +echo ======================================== +echo Starting Development Environment +echo ======================================== + +REM Change to project root directory +cd /d "%~dp0.." + +REM Verify we are in the correct directory +if not exist "start.py" ( + echo Error: Cannot find start.py in current directory: %CD% + echo Please ensure you are running this script from the project directory + pause + exit /b 1 +) +echo Current directory: %CD% + +REM Check and create virtual environment if not exists +if not exist ".venv\Scripts\activate.bat" ( + echo Virtual environment not found, creating... + python -m venv .venv + if errorlevel 1 ( + echo Error: Failed to create virtual environment + echo Please ensure Python is installed and accessible + pause + exit /b 1 + ) + echo Virtual environment created successfully +) + +REM Activate virtual environment +call .venv\Scripts\activate.bat +if errorlevel 1 ( + echo Error: Failed to activate virtual environment + pause + exit /b 1 +) +echo Virtual environment activated + +REM Upgrade pip and install dependencies +echo Checking dependencies... +.venv\Scripts\python.exe -m pip install --upgrade pip -q +.venv\Scripts\python.exe -m pip install -r requirements.txt -q +if errorlevel 1 ( + echo Error: Failed to install dependencies + pause + exit /b 1 +) +echo Dependencies installed successfully + +REM Create logs directory if not exists +if not exist "logs" mkdir logs + +REM Start application in background using start command +REM Create a temporary startup script with clean environment +set TEMP_START_SCRIPT=%TEMP%\start_xian_app_%RANDOM%.bat +( + echo @echo off + echo cd /d %CD% + echo set ENVIRONMENT=development + echo .venv\Scripts\python.exe start.py +) > "%TEMP_START_SCRIPT%" + +start "Xian Algorithm Dev" cmd /k "title Xian Algorithm Dev && call "%TEMP_START_SCRIPT%"" + +echo. +echo Application started in background +echo To view logs, check: logs\app_*.log +echo To stop the application, run: scripts\stop.bat +echo ======================================== + +REM Keep the window open briefly to show any immediate errors +timeout /t 3 /nobreak >nul diff --git a/scripts/start_dev.sh b/scripts/start_dev.sh new file mode 100644 index 0000000..8f9c083 --- /dev/null +++ b/scripts/start_dev.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Linux/Mac startup script - Development environment (background mode) + +echo "========================================" +echo " Starting Development Environment" +echo "========================================" + +# Change to project root directory +cd "$(dirname "$0")/.." || exit + +# Check and create virtual environment if not exists +if [ ! -f ".venv/bin/activate" ]; then + echo "Virtual environment not found, creating..." + python3 -m venv .venv + if [ $? -ne 0 ]; then + echo "Error: Failed to create virtual environment" + echo "Please ensure Python 3 is installed and accessible" + exit 1 + fi + echo "Virtual environment created successfully" +fi + +# Activate virtual environment +source .venv/bin/activate +echo "Virtual environment activated" + +# Upgrade pip and install dependencies +echo "Checking dependencies..." +pip install --upgrade pip -q +pip install -r requirements.txt -q +if [ $? -ne 0 ]; then + echo "Error: Failed to install dependencies" + exit 1 +fi +echo "Dependencies installed successfully" + +# Set environment variable +export ENVIRONMENT=development + +# Create logs directory if not exists +mkdir -p logs + +# Start application in background +nohup python start.py > logs/app_dev.log 2>&1 & +APP_PID=$! + +echo $APP_PID > scripts/app_dev.pid +echo "" +echo "Application started in background (PID: $APP_PID)" +echo "To view logs: tail -f logs/app_dev.log" +echo "To stop the application, run: bash scripts/stop.sh" +echo "========================================" diff --git a/scripts/start_prod.bat b/scripts/start_prod.bat new file mode 100644 index 0000000..82875ca --- /dev/null +++ b/scripts/start_prod.bat @@ -0,0 +1,52 @@ +@echo off +REM Windows startup script - Production environment (background mode) +echo ======================================== +echo Starting Production Environment +echo ======================================== + +REM Change to project root directory +cd /d "%~dp0.." + +REM Check and create virtual environment if not exists +if not exist ".venv\Scripts\activate.bat" ( + echo Virtual environment not found, creating... + python -m venv .venv + if errorlevel 1 ( + echo Error: Failed to create virtual environment + echo Please ensure Python is installed and accessible + pause + exit /b 1 + ) + echo Virtual environment created successfully +) + +REM Activate virtual environment +call .venv\Scripts\activate.bat +echo Virtual environment activated + +REM Upgrade pip and install dependencies +echo Checking dependencies... +.venv\Scripts\python.exe -m pip install --upgrade pip -q +.venv\Scripts\python.exe -m pip install -r requirements.txt -q +if errorlevel 1 ( + echo Error: Failed to install dependencies + pause + exit /b 1 +) +echo Dependencies installed successfully + +REM Set environment variable +set ENVIRONMENT=production + +REM Create logs directory if not exists +if not exist "logs" mkdir logs + +REM Start application in background using start command +start "Xian Algorithm Prod" cmd /c "title Xian Algorithm Prod && .venv\Scripts\python.exe start.py" + +echo. +echo Application started in background +echo To view logs, check: logs\app_*.log +echo To stop the application, run: scripts\stop.bat +echo ======================================== + diff --git a/scripts/start_prod.sh b/scripts/start_prod.sh new file mode 100644 index 0000000..943030c --- /dev/null +++ b/scripts/start_prod.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Linux/Mac startup script - Production environment (background mode) + +echo "========================================" +echo " Starting Production Environment" +echo "========================================" + +# Change to project root directory +cd "$(dirname "$0")/.." || exit + +# Check and create virtual environment if not exists +if [ ! -f ".venv/bin/activate" ]; then + echo "Virtual environment not found, creating..." + python3 -m venv .venv + if [ $? -ne 0 ]; then + echo "Error: Failed to create virtual environment" + echo "Please ensure Python 3 is installed and accessible" + exit 1 + fi + echo "Virtual environment created successfully" +fi + +# Activate virtual environment +source .venv/bin/activate +echo "Virtual environment activated" + +# Upgrade pip and install dependencies +echo "Checking dependencies..." +pip install --upgrade pip -q +pip install -r requirements.txt -q +if [ $? -ne 0 ]; then + echo "Error: Failed to install dependencies" + exit 1 +fi +echo "Dependencies installed successfully" + +# Set environment variable +export ENVIRONMENT=production + +# Create logs directory if not exists +mkdir -p logs + +# Start application in background +nohup python start.py > logs/app_prod.log 2>&1 & +APP_PID=$! + +echo $APP_PID > scripts/app_prod.pid +echo "" +echo "Application started in background (PID: $APP_PID)" +echo "To view logs: tail -f logs/app_prod.log" +echo "To stop the application, run: bash scripts/stop.sh" +echo "========================================" diff --git a/scripts/stop.bat b/scripts/stop.bat new file mode 100644 index 0000000..565f965 --- /dev/null +++ b/scripts/stop.bat @@ -0,0 +1,29 @@ +@echo off +REM Windows stop script - Stop the application +echo ======================================== +echo Stopping Application +echo ======================================== + +REM Change to project root directory +cd /d "%~dp0.." + +REM Find and kill python processes running start.py +echo Searching for running application... +tasklist /FI "IMAGENAME eq python.exe" /FO CSV /NH 2>nul | findstr /I "python.exe" >nul +if %errorlevel% equ 0 ( + echo Found running Python processes, stopping... + taskkill /F /FI "WINDOWTITLE eq Xian Algorithm Dev" 2>nul + taskkill /F /FI "WINDOWTITLE eq Xian Algorithm Prod" 2>nul + + REM If title-based kill didn't work, try killing all python.exe running start.py + for /f "tokens=2 delims=," %%a in ('tasklist /FI "IMAGENAME eq python.exe" /FO CSV /NH 2^>nul') do ( + set "PID=%%~a" + wmic process where "ProcessId=!PID! and CommandLine like '%%start.py%%'" delete >nul 2>&1 + ) + + echo Application stopped successfully +) else ( + echo No running application found +) + +echo ======================================== diff --git a/scripts/stop.sh b/scripts/stop.sh new file mode 100644 index 0000000..5bf8e48 --- /dev/null +++ b/scripts/stop.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Linux/Mac stop script - Stop the application + +echo "========================================" +echo " Stopping Application" +echo "========================================" + +# Change to project root directory +cd "$(dirname "$0")/.." || exit + +# Function to stop a process by PID file +stop_process() { + local pid_file=$1 + local env_name=$2 + + if [ -f "$pid_file" ]; then + local pid=$(cat "$pid_file") + if kill -0 "$pid" 2>/dev/null; then + echo "Stopping $env_name (PID: $pid)..." + kill "$pid" + + # Wait for process to stop + local count=0 + while kill -0 "$pid" 2>/dev/null && [ $count -lt 10 ]; do + sleep 1 + count=$((count + 1)) + done + + # Force kill if still running + if kill -0 "$pid" 2>/dev/null; then + echo "Force stopping $env_name..." + kill -9 "$pid" + fi + + rm -f "$pid_file" + echo "$env_name stopped" + else + echo "$env_name not running (stale PID file removed)" + rm -f "$pid_file" + fi + else + echo "No $env_name PID file found" + fi +} + +# Stop development environment +stop_process "scripts/app_dev.pid" "Development environment" + +# Stop production environment +stop_process "scripts/app_prod.pid" "Production environment" + +# Also try to find any remaining python processes running start.py +REMAINING_PIDS=$(ps aux | grep "[p]ython.*start.py" | awk '{print $2}') +if [ -n "$REMAINING_PIDS" ]; then + echo "Found additional Python processes, stopping..." + echo "$REMAINING_PIDS" | xargs kill 2>/dev/null + sleep 2 + + # Force kill if still running + REMAINING_PIDS=$(ps aux | grep "[p]ython.*start.py" | awk '{print $2}') + if [ -n "$REMAINING_PIDS" ]; then + echo "$REMAINING_PIDS" | xargs kill -9 2>/dev/null + fi +fi + +echo "========================================" +echo "All applications stopped" diff --git a/start.py b/start.py new file mode 100644 index 0000000..76d7396 --- /dev/null +++ b/start.py @@ -0,0 +1,231 @@ +""" +项目启动脚本 - 支持多环境和跨平台 +""" +import sys +import subprocess +import os +from pathlib import Path + +# 添加项目根目录到Python路径 +project_root = Path(__file__).parent +sys.path.insert(0, str(project_root)) + + +def check_platform(): + """检测平台信息""" + print("=" * 60) + print(" 系统信息检测") + print("=" * 60) + + from app.utils.platform_utils import platform_detector + platform_detector.print_platform_banner() + + return platform_detector + + +def check_python_version(): + """检查Python版本是否为3.13或更高""" + print("\n" + "=" * 60) + print(" Python版本检查") + print("=" * 60) + + from app.utils.platform_utils import platform_detector + + current_version = sys.version_info + print(f"当前Python版本: {current_version.major}.{current_version.minor}.{current_version.micro}") + + if not platform_detector.check_python_version((3, 13)): + print("\n❌ 错误: Python版本过低!") + print(f" 当前版本: {current_version.major}.{current_version.minor}.{current_version.micro}") + print(" 要求版本: 3.13 或更高") + print("\n请升级到Python 3.13或更高版本:") + print(" 下载地址: https://www.python.org/downloads/") + print("=" * 60) + sys.exit(1) + + print(f"✅ Python版本检查通过: {current_version.major}.{current_version.minor}.{current_version.micro}") + print("=" * 60) + return True + + +def get_environment(): + """获取运行环境""" + print("\n" + "=" * 60) + print(" 环境配置") + print("=" * 60) + + environment = os.getenv("ENVIRONMENT", "development") + print(f"当前环境: {environment}") + + if environment not in ["development", "production"]: + print(f"⚠️ 警告: 未知环境 '{environment}',使用默认开发环境") + environment = "development" + + print("=" * 60) + return environment + + +def install_dependencies(): + """检查并安装依赖包""" + print("\n" + "=" * 60) + print(" 依赖包检查") + print("=" * 60) + + requirements_file = "requirements.txt" + + if not os.path.exists(requirements_file): + print(f"\n❌ 错误: 找不到依赖文件 {requirements_file}") + sys.exit(1) + + # 读取已安装的包 + try: + result = subprocess.run( + [sys.executable, "-m", "pip", "list", "--format=freeze"], + capture_output=True, + text=True, + check=True + ) + installed_packages = {line.split("==")[0].lower() for line in result.stdout.strip().split("\n") if line} + except subprocess.CalledProcessError as e: + print(f"\n❌ 检查已安装包失败: {e}") + sys.exit(1) + + # 读取requirements.txt中的包 + with open(requirements_file, "r", encoding="utf-8") as f: + required_packages = [] + for line in f: + line = line.strip() + if line and not line.startswith("#"): + package_name = line.split("==")[0].lower() + required_packages.append(line) + + # 检查缺失的包 + missing_packages = [] + for package_line in required_packages: + package_name = package_line.split("==")[0].lower() + if package_name not in installed_packages: + missing_packages.append(package_line) + + if not missing_packages: + print("\n✅ 所有依赖包已安装,无需重复安装") + print("=" * 60) + return True + + print(f"\n发现 {len(missing_packages)} 个未安装的依赖包:") + for package in missing_packages: + print(f" - {package}") + + print("\n正在安装依赖包...") + try: + subprocess.run( + [sys.executable, "-m", "pip", "install", "-r", requirements_file], + check=True + ) + print("\n✅ 依赖包安装成功") + print("=" * 60) + return True + except subprocess.CalledProcessError as e: + print(f"\n❌ 依赖包安装失败: {e}") + print("=" * 60) + sys.exit(1) + + +def initialize_database(): + """初始化数据库连接""" + print("\n" + "=" * 60) + print(" 数据库初始化") + print("=" * 60) + + try: + from app.core.database import db_manager + from app.config.settings import get_settings + + settings = get_settings() + + print(f"数据库地址: {settings.DB_HOST}:{settings.DB_PORT}") + print(f"数据库名称: {settings.DB_NAME}") + print(f"连接池大小: {settings.DB_POOL_SIZE}") + + # 测试数据库连接 + if db_manager.test_connection(): + print("✅ 数据库连接成功") + else: + print("⚠️ 数据库连接失败,请检查配置") + return False + + print("=" * 60) + return True + + except Exception as e: + print(f"\n⚠️ 数据库初始化警告: {e}") + print(" 应用将继续启动,但数据库功能可能不可用") + print("=" * 60) + return False + + +def start_application(environment: str): + """启动FastAPI应用""" + print("\n" + "=" * 60) + print(" 启动FastAPI应用") + print("=" * 60) + + try: + from app.config.settings import get_settings + import uvicorn + + settings = get_settings() + + print(f"应用名称: {settings.APP_NAME}") + print(f"应用版本: {settings.APP_VERSION}") + print(f"运行环境: {settings.ENVIRONMENT.value}") + print(f"监听地址: {settings.API_HOST}:{settings.API_PORT}") + print(f"调试模式: {'开启' if settings.DEBUG else '关闭'}") + print(f"自动重载: {'开启' if hasattr(settings, 'RELOAD') and settings.RELOAD else '关闭'}") + print("\n🚀 应用启动中...\n") + print("=" * 60) + + # 启动uvicorn服务器 + uvicorn.run( + "app.main:app", + host=settings.API_HOST, + port=settings.API_PORT, + reload=settings.RELOAD if hasattr(settings, 'RELOAD') else settings.DEBUG, + log_level=settings.LOG_LEVEL.lower() + ) + + except KeyboardInterrupt: + print("\n\n应用已停止") + except Exception as e: + print(f"\n❌ 应用启动失败: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + +def main(): + """主函数""" + print("\n" + "=" * 60) + print(" Xian Algorithm New - 应用启动程序") + print("=" * 60) + + # 检测平台信息 + check_platform() + + # 检查Python版本 + check_python_version() + + # 获取环境配置 + environment = get_environment() + + # 检查并安装依赖 + install_dependencies() + + # 初始化数据库 + initialize_database() + + # 启动应用 + start_application(environment) + + +if __name__ == "__main__": + main()