|
|
@@ -0,0 +1,276 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+测试 /pdf_to_markdown 接口
|
|
|
+
|
|
|
+用法:
|
|
|
+ python3 test_pdf_to_markdown.py # 测试所有文件
|
|
|
+ python3 test_pdf_to_markdown.py --file 某个文件.pdf # 测试单个文件
|
|
|
+ python3 test_pdf_to_markdown.py --backend paddle # 指定后端
|
|
|
+ python3 test_pdf_to_markdown.py --remove-watermark # 开启去水印
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import json
|
|
|
+import time
|
|
|
+import argparse
|
|
|
+import requests
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
# Base URL of the converter API; override with PDF_CONVERTER_API_URL.
# A trailing slash is stripped so f"{API_BASE_URL}/path" never yields "//path".
API_BASE_URL = os.getenv("PDF_CONVERTER_API_URL", "http://localhost:4214").rstrip("/")

# Directory holding the PDF files to test; override with PDF_CONVERTER_TEST_DIR.
# An unset or empty variable falls back to the original hard-coded location.
TEST_DIR = Path(os.getenv("PDF_CONVERTER_TEST_DIR") or "/root/test/test")

# Polling configuration.
POLL_INTERVAL = 3    # seconds between two status requests
POLL_TIMEOUT = 600   # maximum number of seconds to wait for one task

# Directory where downloaded results are written; override with
# PDF_CONVERTER_RESULT_DIR. Same fallback rule as TEST_DIR.
RESULT_DIR = Path(os.getenv("PDF_CONVERTER_RESULT_DIR") or "/root/test/results")
|
|
|
+
|
|
|
+
|
|
|
def upload_and_convert(
    file_path: Path,
    backend: str = "mineru",
    remove_watermark: bool = False,
    crop_header_footer: bool = False,
    return_images: bool = False,
) -> dict:
    """Upload one PDF to the /pdf_to_markdown endpoint and return the JSON reply.

    The backend name is sent as a form field together with the three boolean
    options, each serialised as lowercase "true"/"false". The file is sent as
    the multipart field "file" with content type application/pdf.

    Raises whatever ``open`` or ``requests`` raise, including the HTTP error
    produced by ``raise_for_status`` for a 4xx/5xx answer.
    """
    endpoint = f"{API_BASE_URL}/pdf_to_markdown"

    switches = {
        "remove_watermark": remove_watermark,
        "crop_header_footer": crop_header_footer,
        "return_images": return_images,
    }
    form_fields = {"backend": backend}
    for field_name, enabled in switches.items():
        form_fields[field_name] = str(enabled).lower()

    with open(file_path, "rb") as pdf_stream:
        multipart = {"file": (file_path.name, pdf_stream, "application/pdf")}
        print(f"  上传文件: {file_path.name} (backend={backend})")
        response = requests.post(endpoint, files=multipart, data=form_fields, timeout=60)

    # Status is checked only after the file handle has been released.
    response.raise_for_status()
    return response.json()
|
|
|
+
|
|
|
+
|
|
|
def poll_task(task_id: str) -> dict:
    """Poll GET /task/{task_id} until the task finishes or POLL_TIMEOUT elapses.

    Returns the last status payload as soon as its "status" field is
    "completed", "failed" or "error". If more than POLL_TIMEOUT seconds pass
    first, returns ``{"status": "timeout", "message": ...}``.

    A request that raises, or a body that is not valid JSON, is reported and
    retried on the next tick instead of propagating. One transient network
    failure therefore no longer aborts the whole test run.
    """
    url = f"{API_BASE_URL}/task/{task_id}"
    start = time.time()

    while True:
        elapsed = time.time() - start
        if elapsed > POLL_TIMEOUT:
            print(f"    ⏰ 超时（{POLL_TIMEOUT}s），任务未完成")
            # The message lets callers print a reason instead of "None".
            return {"status": "timeout", "message": f"timed out after {POLL_TIMEOUT}s"}

        try:
            result = requests.get(url, timeout=10).json()
        except (requests.RequestException, ValueError) as e:
            # Connection problems, request timeouts and non-JSON bodies are
            # treated as transient; the overall deadline still bounds the loop.
            print(f"\n    ⚠️ 轮询请求失败，稍后重试: {e}")
            time.sleep(POLL_INTERVAL)
            continue

        if not isinstance(result, dict):
            # A JSON list/scalar carries no status; handle it like an empty payload.
            result = {}

        status = result.get("status", "unknown")
        progress = result.get("progress", 0)
        if not isinstance(progress, (int, float)):
            # None or text would make the percent format below raise.
            progress = 0
        message = result.get("message", "")

        # "\r" rewrites the same console line; flush so it shows up immediately.
        print(
            f"    [{int(elapsed):>3d}s] 状态: {status}  进度: {progress:.0%}  {message}",
            end="\r",
            flush=True,
        )

        if status in ("completed", "failed", "error"):
            print()  # move past the in-place progress line
            return result

        time.sleep(POLL_INTERVAL)
|
|
|
+
|
|
|
+
|
|
|
def download_markdown(task_id: str, save_path: Path) -> bool:
    """Fetch the task's Markdown output and write it to ``save_path``.

    Returns True when the server answered HTTP 200 and the bytes were written
    (parent directories are created as needed). Any other status code, or any
    exception along the way, is printed and reported as False.
    """
    endpoint = f"{API_BASE_URL}/download/{task_id}/markdown"
    try:
        response = requests.get(endpoint, timeout=30)
        if response.status_code != 200:
            print(f"    ❌ 下载 Markdown 失败: HTTP {response.status_code}")
            return False

        payload = response.content
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(payload)
        print(f"    📄 Markdown 已保存: {save_path} ({len(payload)} bytes)")
        return True
    except Exception as exc:
        print(f"    ❌ 下载 Markdown 异常: {exc}")
        return False
|
|
|
+
|
|
|
+
|
|
|
def download_json(task_id: str, save_path: Path) -> bool:
    """Fetch the task's JSON result and save it pretty-printed as UTF-8.

    Returns True when the server answered HTTP 200 and the re-serialised
    document (2-space indent, non-ASCII characters kept as-is) was written to
    ``save_path``. Any other status code, or any exception (including a body
    that is not valid JSON), is printed and reported as False.
    """
    endpoint = f"{API_BASE_URL}/task/{task_id}/json"
    try:
        response = requests.get(endpoint, timeout=30)
        if response.status_code != 200:
            print(f"    ❌ 下载 JSON 失败: HTTP {response.status_code}")
            return False

        # The target directory is created before the body is parsed.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        document = response.json()
        rendered = json.dumps(document, ensure_ascii=False, indent=2)
        save_path.write_text(rendered, encoding="utf-8")
        print(f"    📋 JSON 已保存: {save_path}")
        return True
    except Exception as exc:
        print(f"    ❌ 下载 JSON 异常: {exc}")
        return False
|
|
|
+
|
|
|
+
|
|
|
def download_zip(task_id: str, save_path: Path) -> bool:
    """Fetch the task's ZIP archive and write it to ``save_path``.

    Returns True when the server answered HTTP 200 and the archive was
    written (parent directories are created as needed). A different status
    code or any exception is printed as a warning and reported as False.
    """
    endpoint = f"{API_BASE_URL}/download/{task_id}/zip"
    try:
        response = requests.get(endpoint, timeout=60)
        available = response.status_code == 200
        if available:
            archive = response.content
            save_path.parent.mkdir(parents=True, exist_ok=True)
            save_path.write_bytes(archive)
            print(f"    📦 ZIP 已保存: {save_path} ({len(archive)} bytes)")
        else:
            print(f"    ⚠️ ZIP 不可用: HTTP {response.status_code}")
        return available
    except Exception as exc:
        print(f"    ⚠️ 下载 ZIP 异常: {exc}")
        return False
|
|
|
+
|
|
|
+
|
|
|
def test_one_file(
    file_path: Path,
    backend: str = "mineru",
    remove_watermark: bool = False,
    crop_header_footer: bool = False,
    return_images: bool = False,
) -> bool:
    """Run the full upload -> poll -> download flow for a single PDF.

    Steps:
      1. Upload the file via ``upload_and_convert`` and read the task id.
      2. Poll the task with ``poll_task`` until it reaches a final state.
      3. Download the Markdown and JSON results (and the ZIP archive when
         ``return_images`` is set) into ``RESULT_DIR/<file stem>/``.

    Returns True only when the task completed AND both the Markdown and the
    JSON result were saved. The ZIP download is best-effort and does not
    affect the outcome. Every failure is printed and reported as False;
    no exception from the upload or polling step escapes.
    """
    print(f"\n{'='*60}")
    print(f"  测试: {file_path.name}")
    print(f"{'='*60}")

    if not file_path.exists():
        print(f"  ❌ 文件不存在: {file_path}")
        return False

    # 1. Upload
    try:
        result = upload_and_convert(
            file_path,
            backend=backend,
            remove_watermark=remove_watermark,
            crop_header_footer=crop_header_footer,
            return_images=return_images,
        )
    except Exception as e:
        print(f"  ❌ 上传失败: {e}")
        return False

    # A non-object JSON reply has no task id; report it instead of crashing.
    task_id = result.get("task_id") if isinstance(result, dict) else None
    if not task_id:
        print(f"  ❌ 未返回 task_id: {result}")
        return False
    print(f"  任务ID: {task_id}")

    # 2. Poll
    try:
        task_result = poll_task(task_id)
    except Exception as e:
        # Keep a batch run alive when polling itself blows up.
        print(f"  ❌ 轮询失败: {e}")
        return False

    status = task_result.get("status")
    if status != "completed":
        # Fall back to the raw status (e.g. "timeout") so the reason is never "None".
        reason = task_result.get("error") or task_result.get("message") or status
        print(f"  ❌ 任务失败: {reason}")
        return False

    print("  ✅ 任务完成")

    # 3. Download results. Both mandatory downloads are always attempted so
    # the log shows every failure, not just the first one.
    stem = file_path.stem
    result_subdir = RESULT_DIR / stem
    markdown_ok = download_markdown(task_id, result_subdir / f"{stem}.md")
    json_ok = download_json(task_id, result_subdir / f"{stem}.json")
    if return_images:
        # Best-effort: the downloader itself reports a missing ZIP as a warning.
        download_zip(task_id, result_subdir / f"{stem}.zip")

    if not (markdown_ok and json_ok):
        print("  ❌ 结果下载不完整")
        return False

    return True
|
|
|
+
|
|
|
+
|
|
|
def check_health():
    """Probe GET /health and report whether the API answered with HTTP 200.

    Prints the decoded JSON body on success, the status code on a non-200
    answer, or the exception text when the request (or decoding the body)
    fails. Returns True only in the first case.
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=10)
        healthy = response.status_code == 200
        if healthy:
            print(f"✅ API 正常: {response.json()}")
        else:
            print(f"❌ API 异常: HTTP {response.status_code}")
        return healthy
    except Exception as exc:
        print(f"❌ API 连接失败: {exc}")
        return False
|
|
|
+
|
|
|
+
|
|
|
def main():
    """Parse the command line, check API health, then run the conversion tests.

    With ``--file`` only that file (resolved under TEST_DIR) is tested;
    otherwise every ``*.pdf`` directly inside TEST_DIR is tested and a summary
    is printed.

    Exit status: 0 when every tested file passed; 1 when the health check
    fails, no PDF is found, or at least one file failed. Batch mode therefore
    now signals failures the same way single-file mode does.
    """
    parser = argparse.ArgumentParser(description="测试 /pdf_to_markdown 接口")
    parser.add_argument("--file", type=str, help="指定单个测试文件名（在 TEST_DIR 下）")
    parser.add_argument("--backend", type=str, default="mineru", choices=["mineru", "paddle"], help="识别后端")
    parser.add_argument("--remove-watermark", action="store_true", help="开启去水印")
    parser.add_argument("--crop-header-footer", action="store_true", help="裁剪页眉页脚")
    parser.add_argument("--return-images", action="store_true", help="同时返回图片（ZIP）")
    parser.add_argument("--api-url", type=str, help="API 地址，覆盖默认值")
    args = parser.parse_args()

    global API_BASE_URL
    if args.api_url:
        # Strip a trailing slash so the endpoint f-strings never produce "//path".
        API_BASE_URL = args.api_url.rstrip("/")

    print(f"API 地址: {API_BASE_URL}")
    print(f"测试目录: {TEST_DIR}")
    print(f"结果目录: {RESULT_DIR}")

    if not check_health():
        sys.exit(1)

    RESULT_DIR.mkdir(parents=True, exist_ok=True)

    # Conversion options shared by the single-file and batch modes.
    options = {
        "backend": args.backend,
        "remove_watermark": args.remove_watermark,
        "crop_header_footer": args.crop_header_footer,
        "return_images": args.return_images,
    }

    if args.file:
        # Single-file mode.
        success = test_one_file(TEST_DIR / args.file, **options)
        sys.exit(0 if success else 1)

    # Batch mode: every PDF directly inside TEST_DIR, in sorted order.
    pdf_files = sorted(TEST_DIR.glob("*.pdf"))
    if not pdf_files:
        print(f"❌ 测试目录下没有 PDF 文件: {TEST_DIR}")
        sys.exit(1)

    print(f"\n共找到 {len(pdf_files)} 个 PDF 文件")
    results = {}
    for pdf_file in pdf_files:
        results[pdf_file.name] = test_one_file(pdf_file, **options)

    # Summary
    print(f"\n{'='*60}")
    print("  测试汇总")
    print(f"{'='*60}")
    for name, ok in results.items():
        status = "✅" if ok else "❌"
        print(f"  {status} {name}")
    total = len(results)
    passed = sum(1 for ok in results.values() if ok)
    print(f"\n  通过: {passed}/{total}")

    # Propagate failures so CI and shell callers can detect them.
    sys.exit(0 if passed == total else 1)


if __name__ == "__main__":
    main()
|