|
@@ -164,6 +164,17 @@ def _get_paddle_ocr_device() -> str:
|
|
|
"""获取 PaddleOCR Python API 使用的设备字符串(如 'gpu:0' 或 'cpu')"""
|
|
"""获取 PaddleOCR Python API 使用的设备字符串(如 'gpu:0' 或 'cpu')"""
|
|
|
devices = _get_paddle_ocr_devices()
|
|
devices = _get_paddle_ocr_devices()
|
|
|
if not devices:
|
|
if not devices:
|
|
|
|
|
+ # 如果没有配置设备,根据环境自动选择
|
|
|
|
|
+ from .device_env import is_npu
|
|
|
|
|
+ if is_npu():
|
|
|
|
|
+ return "npu:0"
|
|
|
|
|
+ # NVIDIA GPU 环境,默认使用 gpu:0
|
|
|
|
|
+ try:
|
|
|
|
|
+ import torch
|
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
|
+ return "gpu:0"
|
|
|
|
|
+ except ImportError:
|
|
|
|
|
+ pass
|
|
|
return "cpu"
|
|
return "cpu"
|
|
|
global _PADDLE_OCR_DEVICE_INDEX
|
|
global _PADDLE_OCR_DEVICE_INDEX
|
|
|
with _PADDLE_OCR_DEVICE_LOCK:
|
|
with _PADDLE_OCR_DEVICE_LOCK:
|