Explorar el Código

feat: 添加 GPU 显存检测和 MinerU 服务自动控制

- 添加 _get_gpu_memory_total() 检测 GPU 总显存
- 添加 _stop_mineru_service() 和 _start_mineru_service()
- 16GB 显存时自动停止 MinerU 释放显存
- PaddleOCR 处理完成后自动重启 MinerU
- 使用 finally 确保服务一定会重启
何文松 hace 1 día
padre
commit
386d5b0359
Se han modificado 1 ficheros con 79 adiciones y 0 borrados
  1. 79 0
      pdf_converter_v2/utils/paddleocr_fallback.py

+ 79 - 0
pdf_converter_v2/utils/paddleocr_fallback.py

@@ -209,6 +209,65 @@ def _get_paddleocr_pipeline(use_chart_recognition: bool = False, use_layout_dete
     return pipeline
 
 
+def _get_gpu_memory_total() -> Optional[int]:
+    """Return the total memory of the first listed GPU, in whole GB.
+
+    Runs ``nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits``
+    with a 5-second timeout, parses the first output line as an integer
+    number of megabytes, and floor-divides it by 1024.
+
+    This is a best-effort probe: it never raises. Every failure path is
+    logged at debug level so that a skipped detection can be diagnosed.
+
+    Returns:
+        The memory size in GB (rounded down), or None when ``nvidia-smi``
+        cannot be run, exits with a non-zero status, or prints output whose
+        first line is not an integer.
+    """
+    try:
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            capture_output=True,
+            text=True,
+            timeout=5,
+        )
+    except Exception as e:
+        # Typically nvidia-smi is missing or the call timed out.
+        logger.debug(f"[PaddleOCR Wrapper] 无法执行 nvidia-smi: {e}")
+        return None
+
+    if result.returncode != 0:
+        logger.debug(
+            f"[PaddleOCR Wrapper] nvidia-smi 返回非零状态 {result.returncode}: {result.stderr}"
+        )
+        return None
+
+    # With several GPUs nvidia-smi prints one line per device; only the
+    # first one is considered.
+    first_line = result.stdout.strip().split('\n')[0]
+    try:
+        memory_mb = int(first_line)
+    except ValueError:
+        logger.debug(f"[PaddleOCR Wrapper] 无法解析 nvidia-smi 输出: {first_line!r}")
+        return None
+    return memory_mb // 1024  # MB -> GB, rounded down
+
+
+def _stop_mineru_service() -> bool:
+    """Stop the MinerU API service through systemd.
+
+    Runs ``sudo systemctl stop mineru-api`` with a 30-second timeout and
+    logs the outcome. Failures are logged and reported through the return
+    value; nothing is raised to the caller.
+
+    Returns:
+        True when the command exits with status 0; False when it exits with
+        a non-zero status or when running it raises an exception.
+    """
+    command = ["sudo", "systemctl", "stop", "mineru-api"]
+    try:
+        logger.info("[PaddleOCR Wrapper] 停止 MinerU API 服务...")
+        proc = subprocess.run(command, capture_output=True, text=True, timeout=30)
+        # Guard clause: report the failure first, success falls through.
+        if proc.returncode != 0:
+            logger.error(f"[PaddleOCR Wrapper] 停止 MinerU API 失败: {proc.stderr}")
+            return False
+        logger.info("[PaddleOCR Wrapper] MinerU API 服务已停止")
+        return True
+    except Exception as exc:
+        logger.error(f"[PaddleOCR Wrapper] 停止 MinerU API 异常: {exc}")
+        return False
+
+
+def _start_mineru_service() -> bool:
+    """Start the MinerU API service through systemd.
+
+    Runs ``sudo systemctl start mineru-api`` with a 30-second timeout and
+    logs the outcome. Failures are logged and reported through the return
+    value; nothing is raised to the caller.
+
+    Returns:
+        True when the command exits with status 0; False when it exits with
+        a non-zero status or when running it raises an exception.
+    """
+    command = ["sudo", "systemctl", "start", "mineru-api"]
+    try:
+        logger.info("[PaddleOCR Wrapper] 启动 MinerU API 服务...")
+        proc = subprocess.run(command, capture_output=True, text=True, timeout=30)
+        # Guard clause: report the failure first, success falls through.
+        if proc.returncode != 0:
+            logger.error(f"[PaddleOCR Wrapper] 启动 MinerU API 失败: {proc.stderr}")
+            return False
+        logger.info("[PaddleOCR Wrapper] MinerU API 服务已启动")
+        return True
+    except Exception as exc:
+        logger.error(f"[PaddleOCR Wrapper] 启动 MinerU API 异常: {exc}")
+        return False
+
+
 def _call_paddleocr_api(
     image_path: str,
     save_path: str,
@@ -217,6 +276,9 @@ def _call_paddleocr_api(
 ) -> tuple[bool, Optional[str]]:
     """通过独立脚本调用 PaddleOCR VL(避免显存共享问题)
     
+    如果 GPU 显存为 16GB,会自动停止 MinerU 服务以释放显存,
+    PaddleOCR 处理完成后再重启 MinerU 服务。
+    
     Args:
         image_path: 输入图片路径
         save_path: 输出保存路径(目录)
@@ -226,6 +288,8 @@ def _call_paddleocr_api(
     Returns:
         (是否成功, markdown 文件路径)
     """
+    mineru_stopped = False
+    
     try:
         if not os.path.exists(image_path):
             logger.error(f"[PaddleOCR Wrapper] 图片文件不存在: {image_path}")
@@ -233,6 +297,16 @@ def _call_paddleocr_api(
         
         os.makedirs(save_path, exist_ok=True)
         
+        # 检测 GPU 显存,如果是 16GB 则停止 MinerU 服务
+        gpu_memory = _get_gpu_memory_total()
+        if gpu_memory and gpu_memory <= 16:
+            logger.info(f"[PaddleOCR Wrapper] 检测到 GPU 显存为 {gpu_memory}GB,停止 MinerU 服务以释放显存")
+            if _stop_mineru_service():
+                mineru_stopped = True
+                # 等待服务完全停止
+                import time
+                time.sleep(3)
+        
         # 获取 wrapper 脚本路径
         current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
         wrapper_script = os.path.join(current_dir, "paddleocr_wrapper.py")
@@ -295,6 +369,11 @@ def _call_paddleocr_api(
         import traceback
         logger.error(traceback.format_exc())
         return False, None
+    finally:
+        # 如果停止了 MinerU 服务,需要重新启动
+        if mineru_stopped:
+            logger.info("[PaddleOCR Wrapper] PaddleOCR 处理完成,重启 MinerU 服务")
+            _start_mineru_service()
 
 
 def has_recognition_garbage(text: str, min_repeat: int = 10) -> bool: