# Copyright (c) Opendatalab. All rights reserved. """ 配置文件加载器 支持从 YAML 或 JSON 配置文件读取配置 """ import os import json from pathlib import Path from typing import Any, Dict, Optional class ConfigLoader: """配置文件加载器""" def __init__(self, config_file: Optional[str] = None): """ 初始化配置加载器 Args: config_file: 配置文件路径,支持 .yaml / .yml / .json 格式 如果为 None,将按以下顺序查找: 1. 当前目录下的 config.yaml 2. 当前目录下的 config.yml 3. 当前目录下的 config.json """ self.config_file = config_file self.config: Dict[str, Any] = {} self._load_config() def _find_config_file(self) -> Optional[str]: """查找配置文件""" # 获取当前模块所在目录 current_dir = Path(__file__).parent # 按优先级查找配置文件 candidates = [ current_dir / "config.yaml", current_dir / "config.yml", current_dir / "config.json", ] for candidate in candidates: if candidate.exists(): return str(candidate) return None def _load_yaml(self, file_path: str) -> Dict[str, Any]: """加载 YAML 配置文件""" try: import yaml with open(file_path, 'r', encoding='utf-8') as f: return yaml.safe_load(f) or {} except ImportError: raise ImportError( "需要安装 PyYAML 才能读取 YAML 配置文件。" "请运行: pip install pyyaml" ) except Exception as e: raise RuntimeError(f"加载 YAML 配置文件失败: {e}") def _load_json(self, file_path: str) -> Dict[str, Any]: """加载 JSON 配置文件""" try: with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) except Exception as e: raise RuntimeError(f"加载 JSON 配置文件失败: {e}") def _load_config(self): """加载配置文件""" # 确定配置文件路径 config_path = self.config_file if not config_path: config_path = self._find_config_file() if not config_path: # 没有找到配置文件,使用空配置(将使用默认值) return if not os.path.exists(config_path): raise FileNotFoundError(f"配置文件不存在: {config_path}") # 根据文件扩展名加载配置 ext = Path(config_path).suffix.lower() if ext in ['.yaml', '.yml']: self.config = self._load_yaml(config_path) elif ext == '.json': self.config = self._load_json(config_path) else: raise ValueError(f"不支持的配置文件格式: {ext},仅支持 .yaml, .yml, .json") def get(self, key: str, default: Any = None) -> Any: """ 获取配置项 Args: key: 配置项的键名 default: 默认值(如果配置项不存在) Returns: 配置项的值,如果不存在则返回默认值 """ return self.config.get(key, default) def get_int(self, key: str, default: int = 0) -> int: """获取整数类型的配置项""" value = self.get(key, default) try: return int(value) except (ValueError, TypeError): return default def get_float(self, key: str, default: float = 0.0) -> float: """获取浮点数类型的配置项""" value = self.get(key, default) try: return float(value) except (ValueError, TypeError): return default def get_bool(self, key: str, default: bool = False) -> bool: """获取布尔类型的配置项""" value = self.get(key, default) if isinstance(value, bool): return value if isinstance(value, str): return value.lower() in ['true', 'yes', '1', 'on'] return bool(value) def get_str(self, key: str, default: str = "") -> str: """获取字符串类型的配置项""" value = self.get(key, default) return str(value) if value is not None else default # 创建全局配置加载器实例 _config_loader: Optional[ConfigLoader] = None def get_config_loader(config_file: Optional[str] = None) -> ConfigLoader: """ 获取全局配置加载器实例 Args: config_file: 配置文件路径(可选) Returns: ConfigLoader 实例 """ global _config_loader if _config_loader is None: _config_loader = ConfigLoader(config_file) return _config_loader def reload_config(config_file: Optional[str] = None): """ 重新加载配置文件 Args: config_file: 配置文件路径(可选) """ global _config_loader _config_loader = ConfigLoader(config_file)