config_loader.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. """
  3. 配置文件加载器
  4. 支持从 YAML 或 JSON 配置文件读取配置
  5. """
  6. import os
  7. import json
  8. from pathlib import Path
  9. from typing import Any, Dict, Optional
  10. class ConfigLoader:
  11. """配置文件加载器"""
  12. def __init__(self, config_file: Optional[str] = None):
  13. """
  14. 初始化配置加载器
  15. Args:
  16. config_file: 配置文件路径,支持 .yaml / .yml / .json 格式
  17. 如果为 None,将按以下顺序查找:
  18. 1. 当前目录下的 config.yaml
  19. 2. 当前目录下的 config.yml
  20. 3. 当前目录下的 config.json
  21. """
  22. self.config_file = config_file
  23. self.config: Dict[str, Any] = {}
  24. self._load_config()
  25. def _find_config_file(self) -> Optional[str]:
  26. """查找配置文件"""
  27. # 获取当前模块所在目录
  28. current_dir = Path(__file__).parent
  29. # 按优先级查找配置文件
  30. candidates = [
  31. current_dir / "config.yaml",
  32. current_dir / "config.yml",
  33. current_dir / "config.json",
  34. ]
  35. for candidate in candidates:
  36. if candidate.exists():
  37. return str(candidate)
  38. return None
  39. def _load_yaml(self, file_path: str) -> Dict[str, Any]:
  40. """加载 YAML 配置文件"""
  41. try:
  42. import yaml
  43. with open(file_path, 'r', encoding='utf-8') as f:
  44. return yaml.safe_load(f) or {}
  45. except ImportError:
  46. raise ImportError(
  47. "需要安装 PyYAML 才能读取 YAML 配置文件。"
  48. "请运行: pip install pyyaml"
  49. )
  50. except Exception as e:
  51. raise RuntimeError(f"加载 YAML 配置文件失败: {e}")
  52. def _load_json(self, file_path: str) -> Dict[str, Any]:
  53. """加载 JSON 配置文件"""
  54. try:
  55. with open(file_path, 'r', encoding='utf-8') as f:
  56. return json.load(f)
  57. except Exception as e:
  58. raise RuntimeError(f"加载 JSON 配置文件失败: {e}")
  59. def _load_config(self):
  60. """加载配置文件"""
  61. # 确定配置文件路径
  62. config_path = self.config_file
  63. if not config_path:
  64. config_path = self._find_config_file()
  65. if not config_path:
  66. # 没有找到配置文件,使用空配置(将使用默认值)
  67. return
  68. if not os.path.exists(config_path):
  69. raise FileNotFoundError(f"配置文件不存在: {config_path}")
  70. # 根据文件扩展名加载配置
  71. ext = Path(config_path).suffix.lower()
  72. if ext in ['.yaml', '.yml']:
  73. self.config = self._load_yaml(config_path)
  74. elif ext == '.json':
  75. self.config = self._load_json(config_path)
  76. else:
  77. raise ValueError(f"不支持的配置文件格式: {ext},仅支持 .yaml, .yml, .json")
  78. def get(self, key: str, default: Any = None) -> Any:
  79. """
  80. 获取配置项
  81. Args:
  82. key: 配置项的键名
  83. default: 默认值(如果配置项不存在)
  84. Returns:
  85. 配置项的值,如果不存在则返回默认值
  86. """
  87. return self.config.get(key, default)
  88. def get_int(self, key: str, default: int = 0) -> int:
  89. """获取整数类型的配置项"""
  90. value = self.get(key, default)
  91. try:
  92. return int(value)
  93. except (ValueError, TypeError):
  94. return default
  95. def get_float(self, key: str, default: float = 0.0) -> float:
  96. """获取浮点数类型的配置项"""
  97. value = self.get(key, default)
  98. try:
  99. return float(value)
  100. except (ValueError, TypeError):
  101. return default
  102. def get_bool(self, key: str, default: bool = False) -> bool:
  103. """获取布尔类型的配置项"""
  104. value = self.get(key, default)
  105. if isinstance(value, bool):
  106. return value
  107. if isinstance(value, str):
  108. return value.lower() in ['true', 'yes', '1', 'on']
  109. return bool(value)
  110. def get_str(self, key: str, default: str = "") -> str:
  111. """获取字符串类型的配置项"""
  112. value = self.get(key, default)
  113. return str(value) if value is not None else default
  114. # 创建全局配置加载器实例
  115. _config_loader: Optional[ConfigLoader] = None
  116. def get_config_loader(config_file: Optional[str] = None) -> ConfigLoader:
  117. """
  118. 获取全局配置加载器实例
  119. Args:
  120. config_file: 配置文件路径(可选)
  121. Returns:
  122. ConfigLoader 实例
  123. """
  124. global _config_loader
  125. if _config_loader is None:
  126. _config_loader = ConfigLoader(config_file)
  127. return _config_loader
  128. def reload_config(config_file: Optional[str] = None):
  129. """
  130. 重新加载配置文件
  131. Args:
  132. config_file: 配置文件路径(可选)
  133. """
  134. global _config_loader
  135. _config_loader = ConfigLoader(config_file)