Преглед на файлове

feat: 添加 /pdf_to_markdown 测试脚本;config.yaml 配置 vlm-http-client backend

何文松 преди 1 седмица
родител
ревизия
dcc86d8002
променени са 1 файла, в които са добавени 276 реда и са изтрити 0 реда
  1. 276 0
      pdf_converter_v2/test_pdf_to_markdown.py

+ 276 - 0
pdf_converter_v2/test_pdf_to_markdown.py

@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+测试 /pdf_to_markdown 接口
+
+用法:
+  python3 test_pdf_to_markdown.py                          # 测试所有文件
+  python3 test_pdf_to_markdown.py --file 某个文件.pdf       # 测试单个文件
+  python3 test_pdf_to_markdown.py --backend paddle          # 指定后端
+  python3 test_pdf_to_markdown.py --remove-watermark        # 开启去水印
+"""
+
+import os
+import sys
+import json
+import time
+import argparse
+import requests
+from pathlib import Path
+
+# API 配置
+API_BASE_URL = os.getenv("PDF_CONVERTER_API_URL", "http://localhost:4214")
+
+# 测试文件目录
+TEST_DIR = Path("/root/test/test")
+
+# 轮询配置
+POLL_INTERVAL = 3       # 轮询间隔(秒)
+POLL_TIMEOUT = 600      # 最大等待时间(秒)
+
+# 结果保存目录
+RESULT_DIR = Path("/root/test/results")
+
+
+def upload_and_convert(
+    file_path: Path,
+    backend: str = "mineru",
+    remove_watermark: bool = False,
+    crop_header_footer: bool = False,
+    return_images: bool = False,
+) -> dict:
+    """上传文件到 /pdf_to_markdown 并返回响应"""
+    url = f"{API_BASE_URL}/pdf_to_markdown"
+
+    with open(file_path, "rb") as f:
+        files = {"file": (file_path.name, f, "application/pdf")}
+        data = {
+            "backend": backend,
+            "remove_watermark": str(remove_watermark).lower(),
+            "crop_header_footer": str(crop_header_footer).lower(),
+            "return_images": str(return_images).lower(),
+        }
+        print(f"  上传文件: {file_path.name} (backend={backend})")
+        resp = requests.post(url, files=files, data=data, timeout=60)
+
+    resp.raise_for_status()
+    return resp.json()
+
+
+def poll_task(task_id: str) -> dict:
+    """轮询任务状态直到完成或超时"""
+    url = f"{API_BASE_URL}/task/{task_id}"
+    start = time.time()
+
+    while True:
+        elapsed = time.time() - start
+        if elapsed > POLL_TIMEOUT:
+            print(f"  ⏰ 超时({POLL_TIMEOUT}s),任务未完成")
+            return {"status": "timeout"}
+
+        resp = requests.get(url, timeout=10)
+        result = resp.json()
+        status = result.get("status", "unknown")
+        progress = result.get("progress", 0)
+        message = result.get("message", "")
+
+        print(f"  [{int(elapsed):>3d}s] 状态: {status}  进度: {progress:.0%}  {message}", end="\r")
+
+        if status in ("completed", "failed", "error"):
+            print()  # 换行
+            return result
+
+        time.sleep(POLL_INTERVAL)
+
+
+def download_markdown(task_id: str, save_path: Path) -> bool:
+    """下载 Markdown 结果"""
+    url = f"{API_BASE_URL}/download/{task_id}/markdown"
+    try:
+        resp = requests.get(url, timeout=30)
+        if resp.status_code == 200:
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            save_path.write_bytes(resp.content)
+            print(f"  📄 Markdown 已保存: {save_path} ({len(resp.content)} bytes)")
+            return True
+        else:
+            print(f"  ❌ 下载 Markdown 失败: HTTP {resp.status_code}")
+            return False
+    except Exception as e:
+        print(f"  ❌ 下载 Markdown 异常: {e}")
+        return False
+
+
+def download_json(task_id: str, save_path: Path) -> bool:
+    """下载 JSON 结果"""
+    url = f"{API_BASE_URL}/task/{task_id}/json"
+    try:
+        resp = requests.get(url, timeout=30)
+        if resp.status_code == 200:
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            save_path.write_text(json.dumps(resp.json(), ensure_ascii=False, indent=2), encoding="utf-8")
+            print(f"  📋 JSON 已保存: {save_path}")
+            return True
+        else:
+            print(f"  ❌ 下载 JSON 失败: HTTP {resp.status_code}")
+            return False
+    except Exception as e:
+        print(f"  ❌ 下载 JSON 异常: {e}")
+        return False
+
+
+def download_zip(task_id: str, save_path: Path) -> bool:
+    """下载 ZIP 压缩包"""
+    url = f"{API_BASE_URL}/download/{task_id}/zip"
+    try:
+        resp = requests.get(url, timeout=60)
+        if resp.status_code == 200:
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            save_path.write_bytes(resp.content)
+            print(f"  📦 ZIP 已保存: {save_path} ({len(resp.content)} bytes)")
+            return True
+        else:
+            print(f"  ⚠️  ZIP 不可用: HTTP {resp.status_code}")
+            return False
+    except Exception as e:
+        print(f"  ⚠️  下载 ZIP 异常: {e}")
+        return False
+
+
+def test_one_file(
+    file_path: Path,
+    backend: str = "mineru",
+    remove_watermark: bool = False,
+    crop_header_footer: bool = False,
+    return_images: bool = False,
+):
+    """测试单个文件的完整流程"""
+    print(f"\n{'='*60}")
+    print(f" 测试: {file_path.name}")
+    print(f"{'='*60}")
+
+    if not file_path.exists():
+        print(f"  ❌ 文件不存在: {file_path}")
+        return False
+
+    # 1. 上传
+    try:
+        result = upload_and_convert(
+            file_path,
+            backend=backend,
+            remove_watermark=remove_watermark,
+            crop_header_footer=crop_header_footer,
+            return_images=return_images,
+        )
+    except Exception as e:
+        print(f"  ❌ 上传失败: {e}")
+        return False
+
+    task_id = result.get("task_id")
+    if not task_id:
+        print(f"  ❌ 未返回 task_id: {result}")
+        return False
+    print(f"  任务ID: {task_id}")
+
+    # 2. 轮询
+    task_result = poll_task(task_id)
+    status = task_result.get("status")
+
+    if status != "completed":
+        print(f"  ❌ 任务失败: {task_result.get('error') or task_result.get('message')}")
+        return False
+
+    print(f"  ✅ 任务完成")
+
+    # 3. 下载结果
+    stem = file_path.stem
+    result_subdir = RESULT_DIR / stem
+    download_markdown(task_id, result_subdir / f"{stem}.md")
+    download_json(task_id, result_subdir / f"{stem}.json")
+    if return_images:
+        download_zip(task_id, result_subdir / f"{stem}.zip")
+
+    return True
+
+
+def check_health():
+    """检查 API 健康状态"""
+    try:
+        resp = requests.get(f"{API_BASE_URL}/health", timeout=10)
+        if resp.status_code == 200:
+            print(f"✅ API 正常: {resp.json()}")
+            return True
+        print(f"❌ API 异常: HTTP {resp.status_code}")
+        return False
+    except Exception as e:
+        print(f"❌ API 连接失败: {e}")
+        return False
+
+
+def main():
+    parser = argparse.ArgumentParser(description="测试 /pdf_to_markdown 接口")
+    parser.add_argument("--file", type=str, help="指定单个测试文件名(在 TEST_DIR 下)")
+    parser.add_argument("--backend", type=str, default="mineru", choices=["mineru", "paddle"], help="识别后端")
+    parser.add_argument("--remove-watermark", action="store_true", help="开启去水印")
+    parser.add_argument("--crop-header-footer", action="store_true", help="裁剪页眉页脚")
+    parser.add_argument("--return-images", action="store_true", help="同时返回图片(ZIP)")
+    parser.add_argument("--api-url", type=str, help="API 地址,覆盖默认值")
+    args = parser.parse_args()
+
+    global API_BASE_URL
+    if args.api_url:
+        API_BASE_URL = args.api_url
+
+    print(f"API 地址: {API_BASE_URL}")
+    print(f"测试目录: {TEST_DIR}")
+    print(f"结果目录: {RESULT_DIR}")
+
+    if not check_health():
+        sys.exit(1)
+
+    RESULT_DIR.mkdir(parents=True, exist_ok=True)
+
+    if args.file:
+        # 测试单个文件
+        file_path = TEST_DIR / args.file
+        success = test_one_file(
+            file_path,
+            backend=args.backend,
+            remove_watermark=args.remove_watermark,
+            crop_header_footer=args.crop_header_footer,
+            return_images=args.return_images,
+        )
+        sys.exit(0 if success else 1)
+    else:
+        # 测试目录下所有 PDF
+        pdf_files = sorted(TEST_DIR.glob("*.pdf"))
+        if not pdf_files:
+            print(f"❌ 测试目录下没有 PDF 文件: {TEST_DIR}")
+            sys.exit(1)
+
+        print(f"\n共找到 {len(pdf_files)} 个 PDF 文件")
+        results = {}
+        for pdf_file in pdf_files:
+            ok = test_one_file(
+                pdf_file,
+                backend=args.backend,
+                remove_watermark=args.remove_watermark,
+                crop_header_footer=args.crop_header_footer,
+                return_images=args.return_images,
+            )
+            results[pdf_file.name] = ok
+
+        # 汇总
+        print(f"\n{'='*60}")
+        print(f" 测试汇总")
+        print(f"{'='*60}")
+        for name, ok in results.items():
+            status = "✅" if ok else "❌"
+            print(f"  {status} {name}")
+        total = len(results)
+        passed = sum(1 for v in results.values() if v)
+        print(f"\n  通过: {passed}/{total}")
+
+
+if __name__ == "__main__":
+    main()