diff --git a/app/utils/redis_helper.py b/app/utils/redis_helper.py index 83d17ca..7ef7581 100644 --- a/app/utils/redis_helper.py +++ b/app/utils/redis_helper.py @@ -2,6 +2,7 @@ Redis 数据库工具类 提供常用的 Redis 操作方法 """ +import json import redis from typing import Any, Optional, List, Dict, Union from config import settings @@ -45,29 +46,44 @@ class RedisHelper: def set(self, key: str, value: Any, ex: Optional[int] = None) -> bool: """ - 设置键值对 + 设置键值对(自动处理JSON序列化) Args: key: 键名 - value: 值 + value: 值(可以是字符串、数字、字典、列表等) ex: 过期时间(秒),None 表示不过期 Returns: 是否设置成功 """ + # 如果不是字符串,自动序列化为JSON + if not isinstance(value, str): + value = json.dumps(value, ensure_ascii=False) return self.client.set(key, value, ex=ex) - def get(self, key: str) -> Optional[str]: + def get(self, key: str, parse_json: bool = True) -> Optional[Any]: """ - 获取键对应的值 + 获取键对应的值(自动处理JSON反序列化) Args: key: 键名 + parse_json: 是否尝试解析为JSON对象 Returns: 键对应的值,如果键不存在则返回 None """ - return self.client.get(key) + value = self.client.get(key) + if value is None: + return None + + # 尝试解析JSON + if parse_json: + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + pass + + return value def delete(self, *keys: str) -> int: """ @@ -148,42 +164,70 @@ class RedisHelper: def hset(self, name: str, key: str, value: Any) -> int: """ - 设置哈希表字段的值 + 设置哈希表字段的值(自动处理JSON序列化) Args: name: 哈希表名 key: 字段名 - value: 字段值 + value: 字段值(可以是字符串、数字、字典、列表等) Returns: 1 表示新增字段,0 表示更新字段 """ + # 如果不是字符串,自动序列化为JSON + if not isinstance(value, str): + value = json.dumps(value, ensure_ascii=False) return self.client.hset(name, key, value) - def hget(self, name: str, key: str) -> Optional[str]: + def hget(self, name: str, key: str, parse_json: bool = True) -> Optional[Any]: """ - 获取哈希表字段的值 + 获取哈希表字段的值(自动处理JSON反序列化) Args: name: 哈希表名 key: 字段名 + parse_json: 是否尝试解析为JSON对象 Returns: 字段值,如果字段不存在则返回 None """ - return self.client.hget(name, key) + value = self.client.hget(name, key) + if value is None: + return None + + # 尝试解析JSON + if parse_json: + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + pass + + return value - def hgetall(self, name: str) -> Dict[str, str]: + def hgetall(self, name: str, parse_json: bool = True) -> Dict[str, Any]: """ - 获取哈希表所有字段和值 + 获取哈希表所有字段和值(自动处理JSON反序列化) Args: name: 哈希表名 + parse_json: 是否尝试解析为JSON对象 Returns: 包含所有字段和值的字典 """ - return self.client.hgetall(name) + data = self.client.hgetall(name) + + if parse_json: + # 尝试解析每个值为JSON + result = {} + for k, v in data.items(): + try: + result[k] = json.loads(v) + except (json.JSONDecodeError, TypeError): + result[k] = v + return result + + return data def hdel(self, name: str, *keys: str) -> int: """ diff --git a/app/utils/thread_pool_manager.py b/app/utils/thread_pool_manager.py new file mode 100644 index 0000000..a1fdc30 --- /dev/null +++ b/app/utils/thread_pool_manager.py @@ -0,0 +1,220 @@ +""" +线程池管理工具类 +提供线程池的创建、管理和任务提交功能 +""" +import threading +import time +from concurrent.futures import ThreadPoolExecutor, Future +from typing import Callable, Any, Optional +from app.utils.logger import get_logger + + +class ThreadPoolManager: + """线程池管理器""" + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """单例模式确保只有一个线程池管理器实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, max_workers: int = 10, thread_name_prefix: str = "Worker"): + """ + 初始化线程池管理器 + + Args: + max_workers: 线程池最大工作线程数 + thread_name_prefix: 线程名称前缀 + """ + # 防止重复初始化 + if hasattr(self, '_initialized') and self._initialized: + return + + self.max_workers = max_workers + self.thread_name_prefix = thread_name_prefix + self.executor = ThreadPoolExecutor( + max_workers=max_workers, + thread_name_prefix=thread_name_prefix + ) + self.futures = {} # 存储提交的任务future对象 + self.tasks_count = 0 # 任务计数器 + self.completed_tasks = 0 # 已完成任务计数 + self._shutdown = False + self.logger = get_logger() + self._initialized = True + + self.logger.info(f"线程池管理器已初始化,最大工作线程数: {max_workers}") + + @classmethod + def get_instance(cls, max_workers: int = 10, thread_name_prefix: str = "Worker") -> 'ThreadPoolManager': + """ + 获取线程池管理器单例实例 + + Args: + max_workers: 线程池最大工作线程数 + thread_name_prefix: 线程名称前缀 + + Returns: + ThreadPoolManager: 线程池管理器实例 + """ + if cls._instance is None: + cls._instance = cls(max_workers, thread_name_prefix) + return cls._instance + + def submit_task(self, func: Callable, *args, task_name: Optional[str] = None, **kwargs) -> Future: + """ + 提交任务到线程池 + + Args: + func: 要执行的函数 + *args: 函数的位置参数 + task_name: 任务名称(可选) + **kwargs: 函数的关键字参数 + + Returns: + Future: 代表异步执行结果的Future对象 + """ + if self._shutdown: + raise RuntimeError("线程池已关闭,无法提交新任务") + + self.tasks_count += 1 + if not task_name: + task_name = f"Task_{self.tasks_count}" + + future = self.executor.submit(func, *args, **kwargs) + self.futures[task_name] = future + + # 添加回调函数,在任务完成时更新计数 + future.add_done_callback(lambda f: self._on_task_complete(task_name)) + + self.logger.debug(f"任务 '{task_name}' 已提交到线程池") + return future + + def _on_task_complete(self, task_name: str): + """ + 任务完成时的回调处理 + + Args: + task_name: 完成的任务名称 + """ + self.completed_tasks += 1 + if task_name in self.futures: + del self.futures[task_name] + self.logger.debug(f"任务 '{task_name}' 已完成,当前完成任务数: {self.completed_tasks}") + + def submit_and_wait(self, func: Callable, *args, timeout: Optional[float] = None, **kwargs) -> Any: + """ + 提交任务并等待结果 + + Args: + func: 要执行的函数 + *args: 函数的位置参数 + timeout: 超时时间(秒),None表示无限等待 + **kwargs: 函数的关键字参数 + + Returns: + Any: 任务执行结果 + """ + future = self.submit_task(func, *args, **kwargs) + try: + result = future.result(timeout=timeout) + return result + except Exception as e: + self.logger.error(f"任务执行出错: {e}") + raise + + def shutdown(self, wait: bool = True): + """ + 关闭线程池 + + Args: + wait: 是否等待所有任务完成后再关闭 + """ + if not self._shutdown: + self._shutdown = True + self.logger.info("正在关闭线程池...") + + # 等待所有任务完成 + if wait: + self.executor.shutdown(wait=True) + self.logger.info("线程池已关闭,所有任务已完成") + else: + self.executor.shutdown(wait=False) + self.logger.info("线程池已关闭,未完成任务将被取消") + + # 清理future引用 + self.futures.clear() + + def get_active_threads_count(self) -> int: + """ + 获取活跃线程数量 + + Returns: + int: 活跃线程数 + """ + # 通过统计未完成的future数量来估算活跃线程数 + active_futures = [f for f in self.futures.values() if not f.done()] + return len(active_futures) + + def get_pool_status(self) -> dict: + """ + 获取线程池状态信息 + + Returns: + dict: 包含线程池状态信息的字典 + """ + return { + "max_workers": self.max_workers, + "active_threads": self.get_active_threads_count(), + "total_tasks_submitted": self.tasks_count, + "completed_tasks": self.completed_tasks, + "pending_tasks": len([f for f in self.futures.values() if not f.done()]), + "is_shutdown": self._shutdown + } + + def wait_for_completion(self, timeout: Optional[float] = None): + """ + 等待所有任务完成 + + Args: + timeout: 超时时间(秒),None表示无限等待 + """ + start_time = time.time() + while True: + pending_futures = [f for f in self.futures.values() if not f.done()] + if not pending_futures: + break + + if timeout and (time.time() - start_time) > timeout: + self.logger.warning(f"等待任务完成超时 ({timeout}s)") + break + + time.sleep(0.1) # 短暂休眠以减少CPU占用 + + +def block_main_thread(): + """ + 阻塞主线程,防止程序立即退出 + 通常用于保持服务持续运行 + """ + logger = get_logger() + logger.info("主线程进入阻塞状态,按 Ctrl+C 退出...") + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("收到中断信号,准备退出...") + # 获取线程池管理器实例并关闭 + pool_manager = ThreadPoolManager.get_instance() + pool_manager.shutdown(wait=True) + logger.info("程序正常退出") + + +# 全局线程池管理器实例 +thread_pool_manager = ThreadPoolManager.get_instance() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b3ccf6e..ee4674e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,7 @@ dynaconf == 3.2.13 psycopg2-binary == 2.9.12 -redis == 7.4.0 \ No newline at end of file +redis == 7.4.0 +numpy == 2.4.4 +scipy == 1.17.1 +matplotlib == 3.10.0 +Pillow == 11.0.0 \ No newline at end of file diff --git a/settings.toml b/settings.toml index 3fedcbd..f1d8fef 100644 --- a/settings.toml +++ b/settings.toml @@ -2,6 +2,10 @@ [default] APP_NAME = "西安项目算法服务" LOG_DIR = "logs" +# 雨量站栅格存储位置,:id会被替换成数据id +RAIN_STATION_GRID_DIR = "/xian/rainfall/grid/images/:id" +# 雨量站栅格存储redis的key +REDIS_RAIN_STATION_GRID_KEY = "xian:rainfall:rain_station_grid" # 开发环境 [development] @@ -22,6 +26,9 @@ REDIS_HOST = "47.92.216.173" REDIS_PORT = 7655 REDIS_PASSWORD = "zhangsan" REDIS_DB = 0 +# 文件存储 +FILE_STORE_DIR = "G:/files" + # 生产环境 [production] @@ -41,4 +48,6 @@ LOG_LEVEL = "WARNING" REDIS_HOST = "localhost" REDIS_PORT = 6379 REDIS_PASSWORD = "XAYJ@gis2603" -REDIS_DB = 0 \ No newline at end of file +REDIS_DB = 0 +# 文件存储 +FILE_STORE_DIR = "D:/files" \ No newline at end of file