| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 测试 /pdf_to_markdown 接口
- 用法:
- python3 test_pdf_to_markdown.py # 测试所有文件
- python3 test_pdf_to_markdown.py --file 某个文件.pdf # 测试单个文件
- python3 test_pdf_to_markdown.py --backend paddle # 指定后端
- python3 test_pdf_to_markdown.py --remove-watermark # 开启去水印
- """
- import os
- import sys
- import json
- import time
- import argparse
- import requests
- from pathlib import Path
# API configuration: base URL of the service under test. The value of the
# PDF_CONVERTER_API_URL environment variable is used when it is set.
API_BASE_URL = os.getenv("PDF_CONVERTER_API_URL", "http://localhost:4214")
# Directory that holds the input PDF files
TEST_DIR = Path("/root/test/test")
# Polling configuration
POLL_INTERVAL = 3 # seconds between two status polls
POLL_TIMEOUT = 600 # maximum number of seconds to wait for a task
# Directory where downloaded results are saved
RESULT_DIR = Path("/root/test/results")
def upload_and_convert(
    file_path: Path,
    backend: str = "mineru",
    remove_watermark: bool = False,
    crop_header_footer: bool = False,
    return_images: bool = False,
) -> dict:
    """Upload a PDF to the ``/pdf_to_markdown`` endpoint and return its reply.

    Args:
        file_path: Path of the PDF to upload; sent as the ``file`` part.
        backend: Value of the ``backend`` form field.
        remove_watermark: Sent as the lowercased string form of the flag.
        crop_header_footer: Sent as the lowercased string form of the flag.
        return_images: Sent as the lowercased string form of the flag.

    Returns:
        The decoded JSON body of the response.

    Raises:
        OSError: If ``file_path`` cannot be opened.
        requests.HTTPError: If the server answers with an error status.
    """
    endpoint = f"{API_BASE_URL}/pdf_to_markdown"

    # Boolean options travel as lowercase "true"/"false" form fields.
    switches = {
        "remove_watermark": remove_watermark,
        "crop_header_footer": crop_header_footer,
        "return_images": return_images,
    }
    form = {"backend": backend}
    for field, enabled in switches.items():
        form[field] = str(enabled).lower()

    with open(file_path, "rb") as pdf:
        print(f" 上传文件: {file_path.name} (backend={backend})")
        response = requests.post(
            endpoint,
            files={"file": (file_path.name, pdf, "application/pdf")},
            data=form,
            timeout=60,
        )
        response.raise_for_status()
        return response.json()
def poll_task(task_id: str) -> dict:
    """Poll ``/task/{task_id}`` until the task ends, fails, or times out.

    The endpoint is queried every ``POLL_INTERVAL`` seconds and a one-line
    progress indicator is rewritten in place on stdout.

    Args:
        task_id: Identifier of the task to watch.

    Returns:
        One of:

        * the status payload returned by the server, once its ``status`` is
          ``"completed"``, ``"failed"`` or ``"error"``;
        * ``{"status": "timeout"}`` when more than ``POLL_TIMEOUT`` seconds
          have elapsed;
        * ``{"status": "error", "error": <text>}`` when the status endpoint
          could not be queried, or returned something other than a JSON
          object, several times in a row.
    """
    url = f"{API_BASE_URL}/task/{task_id}"
    # Consecutive unusable replies tolerated before giving up, so that one
    # transient hiccup does not fail the run but a dead server does.
    max_failures = 5
    failures = 0
    # Monotonic clock: the timeout must not be affected by wall-clock jumps.
    start = time.monotonic()
    while True:
        elapsed = time.monotonic() - start
        if elapsed > POLL_TIMEOUT:
            print()  # terminate the "\r" progress line before the message
            print(f" ⏰ 超时({POLL_TIMEOUT}s),任务未完成")
            return {"status": "timeout"}

        try:
            resp = requests.get(url, timeout=10)
            result = resp.json()
            if not isinstance(result, dict):
                raise ValueError(f"unexpected response: {result!r}")
        except (requests.RequestException, ValueError) as exc:
            failures += 1
            if failures >= max_failures:
                print()
                print(f" ❌ 查询任务状态失败: {exc}")
                return {"status": "error", "error": str(exc)}
            time.sleep(POLL_INTERVAL)
            continue
        failures = 0

        status = result.get("status", "unknown")
        message = result.get("message") or ""
        # The percent format only accepts numbers; a missing, null or
        # malformed progress value is displayed as 0%.
        try:
            progress = float(result.get("progress") or 0)
        except (TypeError, ValueError):
            progress = 0.0

        print(
            f" [{int(elapsed):>3d}s] 状态: {status} 进度: {progress:.0%} {message}",
            end="\r",
            flush=True,  # no newline is written, so flush explicitly
        )
        if status in ("completed", "failed", "error"):
            print()  # move past the in-place progress line
            return result
        time.sleep(POLL_INTERVAL)
def download_markdown(task_id: str, save_path: Path) -> bool:
    """Fetch a task's Markdown output and write it to ``save_path``.

    Args:
        task_id: Identifier of the finished task.
        save_path: Destination file; missing parent directories are created.

    Returns:
        True when the server answered HTTP 200 and the file was written.
        False for any other status code or when any exception occurred; the
        problem is printed rather than raised.
    """
    endpoint = f"{API_BASE_URL}/download/{task_id}/markdown"
    try:
        response = requests.get(endpoint, timeout=30)
        if response.status_code != 200:
            print(f" ❌ 下载 Markdown 失败: HTTP {response.status_code}")
            return False
        save_path.parent.mkdir(parents=True, exist_ok=True)
        payload = response.content
        save_path.write_bytes(payload)
        print(f" 📄 Markdown 已保存: {save_path} ({len(payload)} bytes)")
        return True
    except Exception as exc:
        print(f" ❌ 下载 Markdown 异常: {exc}")
        return False
def download_json(task_id: str, save_path: Path) -> bool:
    """Fetch a task's JSON result and save it pretty-printed to ``save_path``.

    The response body is decoded and re-serialised with two-space indentation
    and non-ASCII characters kept as-is, then written as UTF-8.

    Args:
        task_id: Identifier of the finished task.
        save_path: Destination file; missing parent directories are created.

    Returns:
        True when the server answered HTTP 200 and the file was written.
        False for any other status code or when any exception occurred; the
        problem is printed rather than raised.
    """
    endpoint = f"{API_BASE_URL}/task/{task_id}/json"
    try:
        response = requests.get(endpoint, timeout=30)
        if response.status_code != 200:
            print(f" ❌ 下载 JSON 失败: HTTP {response.status_code}")
            return False
        save_path.parent.mkdir(parents=True, exist_ok=True)
        document = response.json()
        rendered = json.dumps(document, ensure_ascii=False, indent=2)
        save_path.write_text(rendered, encoding="utf-8")
        print(f" 📋 JSON 已保存: {save_path}")
        return True
    except Exception as exc:
        print(f" ❌ 下载 JSON 异常: {exc}")
        return False
def download_zip(task_id: str, save_path: Path) -> bool:
    """Fetch a task's ZIP archive and write it to ``save_path``.

    Args:
        task_id: Identifier of the finished task.
        save_path: Destination file; missing parent directories are created.

    Returns:
        True when the server answered HTTP 200 and the file was written.
        False for any other status code or when any exception occurred; the
        problem is printed as a warning rather than raised.
    """
    endpoint = f"{API_BASE_URL}/download/{task_id}/zip"
    try:
        response = requests.get(endpoint, timeout=60)
        if response.status_code != 200:
            print(f" ⚠️ ZIP 不可用: HTTP {response.status_code}")
            return False
        save_path.parent.mkdir(parents=True, exist_ok=True)
        archive = response.content
        save_path.write_bytes(archive)
        print(f" 📦 ZIP 已保存: {save_path} ({len(archive)} bytes)")
        return True
    except Exception as exc:
        print(f" ⚠️ 下载 ZIP 异常: {exc}")
        return False
def test_one_file(
    file_path: Path,
    backend: str = "mineru",
    remove_watermark: bool = False,
    crop_header_footer: bool = False,
    return_images: bool = False,
) -> bool:
    """Run the full upload -> poll -> download flow for one PDF.

    Args:
        file_path: PDF to convert.
        backend: Recognition backend forwarded to the upload step.
        remove_watermark: Forwarded to the upload step.
        crop_header_footer: Forwarded to the upload step.
        return_images: Forwarded to the upload step; when true the ZIP
            archive is downloaded as well.

    Returns:
        True when the task completed and both the Markdown and the JSON
        results were downloaded; False otherwise. A failed ZIP download does
        not affect the result. Failures are printed, never raised.
    """
    print(f"\n{'='*60}")
    print(f" 测试: {file_path.name}")
    print(f"{'='*60}")

    if not file_path.exists():
        print(f" ❌ 文件不存在: {file_path}")
        return False

    # 1. Upload
    try:
        result = upload_and_convert(
            file_path,
            backend=backend,
            remove_watermark=remove_watermark,
            crop_header_footer=crop_header_footer,
            return_images=return_images,
        )
    except Exception as e:
        print(f" ❌ 上传失败: {e}")
        return False

    task_id = result.get("task_id")
    if not task_id:
        print(f" ❌ 未返回 task_id: {result}")
        return False
    print(f" 任务ID: {task_id}")

    # 2. Poll. Guarded like the upload step so that one broken task cannot
    # abort a whole batch run.
    try:
        task_result = poll_task(task_id)
    except Exception as e:
        print()  # terminate a possibly unfinished progress line
        print(f" ❌ 轮询失败: {e}")
        return False

    status = task_result.get("status")
    if status != "completed":
        # Fall back to the bare status (e.g. "timeout") when the payload
        # carries neither an error nor a message.
        detail = task_result.get("error") or task_result.get("message") or status
        print(f" ❌ 任务失败: {detail}")
        return False
    print(" ✅ 任务完成")

    # 3. Download results. Both downloads are always attempted so that every
    # problem is reported in a single run.
    stem = file_path.stem
    result_subdir = RESULT_DIR / stem
    markdown_ok = download_markdown(task_id, result_subdir / f"{stem}.md")
    json_ok = download_json(task_id, result_subdir / f"{stem}.json")
    if return_images:
        # The archive is optional: its absence is only reported as a warning.
        download_zip(task_id, result_subdir / f"{stem}.zip")
    return markdown_ok and json_ok
def check_health():
    """Query the ``/health`` endpoint and report whether the API is usable.

    Returns:
        True when the endpoint answered HTTP 200 and its JSON body could be
        printed; False for any other status code or when any exception
        occurred. The outcome is printed in every case.
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=10)
        healthy = response.status_code == 200
        if healthy:
            print(f"✅ API 正常: {response.json()}")
        else:
            print(f"❌ API 异常: HTTP {response.status_code}")
        return healthy
    except Exception as exc:
        print(f"❌ API 连接失败: {exc}")
        return False
def main():
    """Command-line entry point: test one PDF or every PDF in ``TEST_DIR``.

    Parses the command line, optionally overrides ``API_BASE_URL``, checks
    the API health endpoint, then runs ``test_one_file`` either for the file
    named by ``--file`` or for every ``*.pdf`` found in ``TEST_DIR``.

    Exit status:
        * single-file mode: 0 on success, 1 on failure;
        * batch mode: returns normally when every file passed, exits with 1
          when at least one file failed;
        * 1 when the health check fails or no PDF is found.
    """
    global API_BASE_URL

    parser = argparse.ArgumentParser(description="测试 /pdf_to_markdown 接口")
    parser.add_argument("--file", type=str, help="指定单个测试文件名(在 TEST_DIR 下)")
    parser.add_argument("--backend", type=str, default="mineru", choices=["mineru", "paddle"], help="识别后端")
    parser.add_argument("--remove-watermark", action="store_true", help="开启去水印")
    parser.add_argument("--crop-header-footer", action="store_true", help="裁剪页眉页脚")
    parser.add_argument("--return-images", action="store_true", help="同时返回图片(ZIP)")
    parser.add_argument("--api-url", type=str, help="API 地址,覆盖默认值")
    args = parser.parse_args()

    if args.api_url:
        # Endpoints are appended as "/path", so drop a trailing slash to
        # avoid "//" in the request URLs.
        API_BASE_URL = args.api_url.rstrip("/")

    print(f"API 地址: {API_BASE_URL}")
    print(f"测试目录: {TEST_DIR}")
    print(f"结果目录: {RESULT_DIR}")

    if not check_health():
        sys.exit(1)

    RESULT_DIR.mkdir(parents=True, exist_ok=True)

    # Conversion options shared by both modes.
    options = {
        "backend": args.backend,
        "remove_watermark": args.remove_watermark,
        "crop_header_footer": args.crop_header_footer,
        "return_images": args.return_images,
    }

    if args.file:
        # Single-file mode
        success = test_one_file(TEST_DIR / args.file, **options)
        sys.exit(0 if success else 1)

    # Batch mode: every PDF in the test directory
    pdf_files = sorted(TEST_DIR.glob("*.pdf"))
    if not pdf_files:
        print(f"❌ 测试目录下没有 PDF 文件: {TEST_DIR}")
        sys.exit(1)
    print(f"\n共找到 {len(pdf_files)} 个 PDF 文件")

    results = {}
    for pdf_file in pdf_files:
        results[pdf_file.name] = test_one_file(pdf_file, **options)

    # Summary
    print(f"\n{'='*60}")
    print(" 测试汇总")
    print(f"{'='*60}")
    for name, ok in results.items():
        status = "✅" if ok else "❌"
        print(f" {status} {name}")
    total = len(results)
    passed = sum(1 for ok in results.values() if ok)
    print(f"\n 通过: {passed}/{total}")

    # Report failures through the exit status, as single-file mode does, so
    # that scripts and CI can detect a failed batch.
    if passed != total:
        sys.exit(1)
# Run the test driver only when this file is executed as a script.
if __name__ == "__main__":
    main()
|