| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- # Copyright (c) Opendatalab. All rights reserved.
- """
- 配置文件加载器
- 支持从 YAML 或 JSON 配置文件读取配置
- """
- import os
- import json
- from pathlib import Path
- from typing import Any, Dict, Optional
- class ConfigLoader:
- """配置文件加载器"""
-
- def __init__(self, config_file: Optional[str] = None):
- """
- 初始化配置加载器
-
- Args:
- config_file: 配置文件路径,支持 .yaml / .yml / .json 格式
- 如果为 None,将按以下顺序查找:
- 1. 当前目录下的 config.yaml
- 2. 当前目录下的 config.yml
- 3. 当前目录下的 config.json
- """
- self.config_file = config_file
- self.config: Dict[str, Any] = {}
- self._load_config()
-
- def _find_config_file(self) -> Optional[str]:
- """查找配置文件"""
- # 获取当前模块所在目录
- current_dir = Path(__file__).parent
-
- # 按优先级查找配置文件
- candidates = [
- current_dir / "config.yaml",
- current_dir / "config.yml",
- current_dir / "config.json",
- ]
-
- for candidate in candidates:
- if candidate.exists():
- return str(candidate)
-
- return None
-
- def _load_yaml(self, file_path: str) -> Dict[str, Any]:
- """加载 YAML 配置文件"""
- try:
- import yaml
- with open(file_path, 'r', encoding='utf-8') as f:
- return yaml.safe_load(f) or {}
- except ImportError:
- raise ImportError(
- "需要安装 PyYAML 才能读取 YAML 配置文件。"
- "请运行: pip install pyyaml"
- )
- except Exception as e:
- raise RuntimeError(f"加载 YAML 配置文件失败: {e}")
-
- def _load_json(self, file_path: str) -> Dict[str, Any]:
- """加载 JSON 配置文件"""
- try:
- with open(file_path, 'r', encoding='utf-8') as f:
- return json.load(f)
- except Exception as e:
- raise RuntimeError(f"加载 JSON 配置文件失败: {e}")
-
- def _load_config(self):
- """加载配置文件"""
- # 确定配置文件路径
- config_path = self.config_file
- if not config_path:
- config_path = self._find_config_file()
-
- if not config_path:
- # 没有找到配置文件,使用空配置(将使用默认值)
- return
-
- if not os.path.exists(config_path):
- raise FileNotFoundError(f"配置文件不存在: {config_path}")
-
- # 根据文件扩展名加载配置
- ext = Path(config_path).suffix.lower()
- if ext in ['.yaml', '.yml']:
- self.config = self._load_yaml(config_path)
- elif ext == '.json':
- self.config = self._load_json(config_path)
- else:
- raise ValueError(f"不支持的配置文件格式: {ext},仅支持 .yaml, .yml, .json")
-
- def get(self, key: str, default: Any = None) -> Any:
- """
- 获取配置项
-
- Args:
- key: 配置项的键名
- default: 默认值(如果配置项不存在)
-
- Returns:
- 配置项的值,如果不存在则返回默认值
- """
- return self.config.get(key, default)
-
- def get_int(self, key: str, default: int = 0) -> int:
- """获取整数类型的配置项"""
- value = self.get(key, default)
- try:
- return int(value)
- except (ValueError, TypeError):
- return default
-
- def get_float(self, key: str, default: float = 0.0) -> float:
- """获取浮点数类型的配置项"""
- value = self.get(key, default)
- try:
- return float(value)
- except (ValueError, TypeError):
- return default
-
- def get_bool(self, key: str, default: bool = False) -> bool:
- """获取布尔类型的配置项"""
- value = self.get(key, default)
- if isinstance(value, bool):
- return value
- if isinstance(value, str):
- return value.lower() in ['true', 'yes', '1', 'on']
- return bool(value)
-
- def get_str(self, key: str, default: str = "") -> str:
- """获取字符串类型的配置项"""
- value = self.get(key, default)
- return str(value) if value is not None else default
- # 创建全局配置加载器实例
- _config_loader: Optional[ConfigLoader] = None
- def get_config_loader(config_file: Optional[str] = None) -> ConfigLoader:
- """
- 获取全局配置加载器实例
-
- Args:
- config_file: 配置文件路径(可选)
-
- Returns:
- ConfigLoader 实例
- """
- global _config_loader
- if _config_loader is None:
- _config_loader = ConfigLoader(config_file)
- return _config_loader
- def reload_config(config_file: Optional[str] = None):
- """
- 重新加载配置文件
-
- Args:
- config_file: 配置文件路径(可选)
- """
- global _config_loader
- _config_loader = ConfigLoader(config_file)
|