test_pdf_to_markdown.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  """
  Test the /pdf_to_markdown endpoint.

  Usage:
      python3 test_pdf_to_markdown.py                       # test every file
      python3 test_pdf_to_markdown.py --file some_file.pdf  # test a single file
      python3 test_pdf_to_markdown.py --backend paddle      # choose the backend
      python3 test_pdf_to_markdown.py --remove-watermark    # enable watermark removal
  """
  11. import os
  12. import sys
  13. import json
  14. import time
  15. import argparse
  16. import requests
  17. from pathlib import Path
  # API configuration: base URL of the converter service. It can be overridden
  # with the PDF_CONVERTER_API_URL environment variable.
  API_BASE_URL = os.getenv("PDF_CONVERTER_API_URL", "http://localhost:4214")

  # Directory that holds the PDF files to test. It can be overridden with the
  # PDF_CONVERTER_TEST_DIR environment variable; the default is unchanged.
  TEST_DIR = Path(os.getenv("PDF_CONVERTER_TEST_DIR", "/root/test/test"))

  # Polling configuration.
  POLL_INTERVAL = 3  # seconds to sleep between two status polls
  POLL_TIMEOUT = 600  # maximum number of seconds to wait for a task

  # Directory where downloaded results are saved. It can be overridden with the
  # PDF_CONVERTER_RESULT_DIR environment variable; the default is unchanged.
  RESULT_DIR = Path(os.getenv("PDF_CONVERTER_RESULT_DIR", "/root/test/results"))
  def upload_and_convert(
      file_path: Path,
      backend: str = "mineru",
      remove_watermark: bool = False,
      crop_header_footer: bool = False,
      return_images: bool = False,
  ) -> dict:
      """Upload a PDF to /pdf_to_markdown and return the decoded JSON reply.

      The file is sent as the multipart field ``file`` with content type
      ``application/pdf``. The boolean options are sent as the lower-case
      strings ``"true"`` / ``"false"`` next to the ``backend`` form field.

      Raises whatever ``open`` or ``requests`` raise, including the HTTP error
      from ``raise_for_status`` on a non-success status code.
      """
      switches = {
          "remove_watermark": remove_watermark,
          "crop_header_footer": crop_header_footer,
          "return_images": return_images,
      }
      form_fields = {"backend": backend}
      for field_name, enabled in switches.items():
          form_fields[field_name] = str(enabled).lower()

      endpoint = f"{API_BASE_URL}/pdf_to_markdown"
      with open(file_path, "rb") as pdf_handle:
          multipart = {"file": (file_path.name, pdf_handle, "application/pdf")}
          print(f" 上传文件: {file_path.name} (backend={backend})")
          # The request is sent while the handle is still open so the body
          # can be read from it.
          response = requests.post(
              endpoint, files=multipart, data=form_fields, timeout=60
          )
      response.raise_for_status()
      return response.json()
  def poll_task(task_id: str) -> dict:
      """Poll /task/{task_id} until the task ends or POLL_TIMEOUT elapses.

      Returns the task's JSON status document as soon as its ``status`` is
      ``completed``, ``failed`` or ``error``. If POLL_TIMEOUT seconds pass
      first, returns ``{"status": "timeout", "message": ...}``.

      A failed request or a response body that is not valid JSON is reported
      and retried after POLL_INTERVAL seconds instead of raising, so a single
      transient error does not abort the caller.
      """
      url = f"{API_BASE_URL}/task/{task_id}"
      # A monotonic clock keeps the elapsed time immune to wall-clock changes.
      start = time.monotonic()
      while True:
          elapsed = time.monotonic() - start
          if elapsed > POLL_TIMEOUT:
              timeout_message = f"超时({POLL_TIMEOUT}s),任务未完成"
              print(f" ⏰ {timeout_message}")
              return {"status": "timeout", "message": timeout_message}

          try:
              resp = requests.get(url, timeout=10)
              result = resp.json()
          except (requests.RequestException, ValueError) as exc:
              # Network failure or non-JSON body: retry until the deadline.
              print(f" [{int(elapsed):>3d}s] 轮询异常,稍后重试: {exc}")
              time.sleep(POLL_INTERVAL)
              continue

          if not isinstance(result, dict):
              # Anything but a JSON object carries no status; keep polling.
              result = {}

          status = result.get("status", "unknown")
          try:
              # The value is rendered with a percent format, which requires a
              # number; fall back to 0 for null or non-numeric values.
              progress = float(result.get("progress") or 0)
          except (TypeError, ValueError):
              progress = 0.0
          message = result.get("message") or ""

          # "\r" keeps the progress report on a single, refreshed line.
          print(
              f" [{int(elapsed):>3d}s] 状态: {status} 进度: {progress:.0%} {message}",
              end="\r",
          )
          if status in ("completed", "failed", "error"):
              print()  # terminate the refreshed progress line
              return result
          time.sleep(POLL_INTERVAL)
  def download_markdown(task_id: str, save_path: Path) -> bool:
      """Fetch the Markdown result of a task and write it to ``save_path``.

      Returns True when the server answered HTTP 200 and the file was written.
      Returns False on any other status code or on any exception; a message
      is printed in both failure cases.
      """
      endpoint = f"{API_BASE_URL}/download/{task_id}/markdown"
      try:
          response = requests.get(endpoint, timeout=30)
          if response.status_code != 200:
              print(f" ❌ 下载 Markdown 失败: HTTP {response.status_code}")
              return False
          # Create the target directory on demand before writing.
          save_path.parent.mkdir(parents=True, exist_ok=True)
          payload = response.content
          save_path.write_bytes(payload)
          print(f" 📄 Markdown 已保存: {save_path} ({len(payload)} bytes)")
          return True
      except Exception as exc:
          print(f" ❌ 下载 Markdown 异常: {exc}")
          return False
  def download_json(task_id: str, save_path: Path) -> bool:
      """Fetch the JSON result of a task and write it to ``save_path``.

      The response body is decoded and re-serialised with two-space
      indentation and non-ASCII characters kept as-is, then written as UTF-8.

      Returns True when the server answered HTTP 200 and the file was written.
      Returns False on any other status code or on any exception; a message
      is printed in both failure cases.
      """
      endpoint = f"{API_BASE_URL}/task/{task_id}/json"
      try:
          response = requests.get(endpoint, timeout=30)
          if response.status_code != 200:
              print(f" ❌ 下载 JSON 失败: HTTP {response.status_code}")
              return False
          # Create the target directory on demand before writing.
          save_path.parent.mkdir(parents=True, exist_ok=True)
          document = response.json()
          rendered = json.dumps(document, ensure_ascii=False, indent=2)
          save_path.write_text(rendered, encoding="utf-8")
          print(f" 📋 JSON 已保存: {save_path}")
          return True
      except Exception as exc:
          print(f" ❌ 下载 JSON 异常: {exc}")
          return False
  def download_zip(task_id: str, save_path: Path) -> bool:
      """Fetch the ZIP archive of a task and write it to ``save_path``.

      Returns True when the server answered HTTP 200 and the file was written.
      Returns False on any other status code or on any exception; both cases
      are reported as warnings.
      """
      endpoint = f"{API_BASE_URL}/download/{task_id}/zip"
      try:
          response = requests.get(endpoint, timeout=60)
          if response.status_code != 200:
              print(f" ⚠️ ZIP 不可用: HTTP {response.status_code}")
              return False
          # Create the target directory on demand before writing.
          save_path.parent.mkdir(parents=True, exist_ok=True)
          archive = response.content
          save_path.write_bytes(archive)
          print(f" 📦 ZIP 已保存: {save_path} ({len(archive)} bytes)")
          return True
      except Exception as exc:
          print(f" ⚠️ 下载 ZIP 异常: {exc}")
          return False
  def test_one_file(
      file_path: Path,
      backend: str = "mineru",
      remove_watermark: bool = False,
      crop_header_footer: bool = False,
      return_images: bool = False,
  ) -> bool:
      """Run the full upload, poll and download flow for a single file.

      Steps:
        1. Upload ``file_path`` with the given options.
        2. Poll the returned task until it ends.
        3. Download the Markdown and JSON results (plus the ZIP archive when
           ``return_images`` is set) into ``RESULT_DIR/<file stem>/``.

      Returns True only when the task completed and both the Markdown and the
      JSON result were downloaded. A ZIP download failure is reported by
      ``download_zip`` as a warning and does not fail the file. Every failure
      is printed and none is raised, so a batch run continues with the next
      file.
      """
      banner = "=" * 60
      print(f"\n{banner}")
      print(f" 测试: {file_path.name}")
      print(banner)

      if not file_path.exists():
          print(f" ❌ 文件不存在: {file_path}")
          return False

      # 1. Upload
      try:
          result = upload_and_convert(
              file_path,
              backend=backend,
              remove_watermark=remove_watermark,
              crop_header_footer=crop_header_footer,
              return_images=return_images,
          )
      except Exception as e:
          print(f" ❌ 上传失败: {e}")
          return False

      # A reply that is not a JSON object cannot carry a task id.
      task_id = result.get("task_id") if isinstance(result, dict) else None
      if not task_id:
          print(f" ❌ 未返回 task_id: {result}")
          return False
      print(f" 任务ID: {task_id}")

      # 2. Poll
      try:
          task_result = poll_task(task_id)
      except Exception as e:
          # A polling error must fail this file only, not the whole run.
          print(f" ❌ 轮询失败: {e}")
          return False

      status = task_result.get("status")
      if status != "completed":
          print(f" ❌ 任务失败: {task_result.get('error') or task_result.get('message')}")
          return False
      print(" ✅ 任务完成")

      # 3. Download the results
      stem = file_path.stem
      result_subdir = RESULT_DIR / stem
      markdown_ok = download_markdown(task_id, result_subdir / f"{stem}.md")
      json_ok = download_json(task_id, result_subdir / f"{stem}.json")
      if return_images:
          # Its result is deliberately not part of the verdict: the helper
          # reports a missing archive as a warning only.
          download_zip(task_id, result_subdir / f"{stem}.zip")

      # A file passes only if its results could actually be retrieved.
      return markdown_ok and json_ok
  def check_health():
      """Query the /health endpoint and report whether the API is usable.

      Returns True when the endpoint answers HTTP 200 and its body decodes as
      JSON. Returns False on any other status code or on any exception. The
      outcome is printed in every case.
      """
      try:
          response = requests.get(f"{API_BASE_URL}/health", timeout=10)
          healthy = response.status_code == 200
          if healthy:
              print(f"✅ API 正常: {response.json()}")
          else:
              print(f"❌ API 异常: HTTP {response.status_code}")
          return healthy
      except Exception as exc:
          print(f"❌ API 连接失败: {exc}")
          return False
  def main():
      """Parse the command line and test one file or every PDF in TEST_DIR.

      With ``--file`` only that file (resolved against TEST_DIR) is tested and
      the process exits with 0 on success or 1 on failure. Without it, every
      ``*.pdf`` in TEST_DIR is tested and a summary is printed.

      The process exits with status 1 when the health check fails, when no PDF
      is found, or when at least one file fails. Batch mode therefore reports
      failures through the exit status just like single-file mode does.
      """
      global API_BASE_URL

      parser = argparse.ArgumentParser(description="测试 /pdf_to_markdown 接口")
      parser.add_argument("--file", type=str, help="指定单个测试文件名(在 TEST_DIR 下)")
      parser.add_argument(
          "--backend",
          type=str,
          default="mineru",
          choices=["mineru", "paddle"],
          help="识别后端",
      )
      parser.add_argument("--remove-watermark", action="store_true", help="开启去水印")
      parser.add_argument("--crop-header-footer", action="store_true", help="裁剪页眉页脚")
      parser.add_argument("--return-images", action="store_true", help="同时返回图片(ZIP)")
      parser.add_argument("--api-url", type=str, help="API 地址,覆盖默认值")
      args = parser.parse_args()

      # --api-url takes precedence over the environment/default value.
      if args.api_url:
          API_BASE_URL = args.api_url

      print(f"API 地址: {API_BASE_URL}")
      print(f"测试目录: {TEST_DIR}")
      print(f"结果目录: {RESULT_DIR}")

      if not check_health():
          sys.exit(1)
      RESULT_DIR.mkdir(parents=True, exist_ok=True)

      # The same conversion options apply to every tested file.
      options = {
          "backend": args.backend,
          "remove_watermark": args.remove_watermark,
          "crop_header_footer": args.crop_header_footer,
          "return_images": args.return_images,
      }

      if args.file:
          # Test a single file.
          success = test_one_file(TEST_DIR / args.file, **options)
          sys.exit(0 if success else 1)

      # Test every PDF in the test directory.
      pdf_files = sorted(TEST_DIR.glob("*.pdf"))
      if not pdf_files:
          print(f"❌ 测试目录下没有 PDF 文件: {TEST_DIR}")
          sys.exit(1)
      print(f"\n共找到 {len(pdf_files)} 个 PDF 文件")

      results = {}
      for pdf_file in pdf_files:
          results[pdf_file.name] = test_one_file(pdf_file, **options)

      # Summary
      banner = "=" * 60
      print(f"\n{banner}")
      print(" 测试汇总")
      print(banner)
      for name, ok in results.items():
          status = "✅" if ok else "❌"
          print(f" {status} {name}")
      total = len(results)
      passed = sum(1 for ok in results.values() if ok)
      print(f"\n 通过: {passed}/{total}")

      # Signal failures to the caller (shell, CI) through the exit status.
      if passed != total:
          sys.exit(1)
  # Run the command-line entry point only when this file is executed as a
  # script, not when it is imported as a module.
  if __name__ == "__main__":
      main()