| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
"""
NER routes.
"""
- import json
- import asyncio
- import uuid
- from typing import Dict, Any, Optional
- from fastapi import APIRouter, HTTPException, BackgroundTasks
- from fastapi.responses import StreamingResponse
- from loguru import logger
- import time
- from pydantic import BaseModel
- from ..models import NerRequest, NerResponse, EntityInfo
- from ..services.ner_service import ner_service
# Router that every endpoint in this module registers itself on.
router = APIRouter()

# ============== Task storage ==============
# Module-level, in-memory registry of asynchronous task state and results,
# keyed by task_id. Each value is a plain dict of task fields.
_tasks: Dict[str, Dict[str, Any]] = {}
class TaskStatus(BaseModel):
    """Response model describing the state of one asynchronous task."""

    task_id: str
    document_id: str
    # One of: pending, processing, completed, failed.
    status: str
    # Completion percentage, 0-100.
    progress: int = 0
    # Human-readable description of the current state.
    message: str = ""
    # Result payload; None unless one has been stored for the task.
    result: Optional[Dict] = None
    # Error text; None unless one has been stored for the task.
    error: Optional[str] = None
    # Creation / last-update timestamps (float; defaults to 0 when not supplied).
    created_at: float = 0
    updated_at: float = 0
async def sse_event(event: str, data: dict):
    """Format one Server-Sent Events (SSE) message.

    Args:
        event: Value written on the ``event:`` line.
        data: Payload serialised as JSON on the ``data:`` line; non-ASCII
            characters are emitted as-is rather than escaped.

    Returns:
        The message text, terminated by a blank line.
    """
    payload = json.dumps(data, ensure_ascii=False)
    return "event: {}\ndata: {}\n\n".format(event, payload)
@router.post("/extract", response_model=NerResponse)
async def extract_entities(request: NerRequest):
    """Extract named entities (and optionally relations) from the request text.

    Args:
        request: Extraction request providing ``document_id``, ``text``,
            ``entity_types`` and ``extract_relations``.

    Returns:
        ``NerResponse.success_response(...)`` with the entities, relations and
        elapsed time in milliseconds, or ``NerResponse.error_response(...)``
        when any non-HTTP exception occurs during processing.

    Raises:
        HTTPException: 400 when the text is longer than 50000 characters.
    """
    t0 = time.time()
    doc_id = request.document_id

    try:
        body = request.text
        logger.info(f"开始提取实体: document_id={doc_id}, text_length={len(body)}")

        # Reject oversized input before calling the NER service.
        if len(body) > 50000:
            raise HTTPException(status_code=400, detail="文本长度超过限制(最大50000字符)")

        found = await ner_service.extract_entities(
            text=body,
            entity_types=request.entity_types,
        )

        # Relation extraction is attempted only when requested and when there
        # is more than one entity.
        links = []
        if request.extract_relations and len(found) > 1:
            # Imported here, at the point of use, exactly as in the original.
            from ..services.relation_service import relation_service
            links = await relation_service.extract_relations(
                text=body,
                entities=found,
            )

        elapsed_ms = int((time.time() - t0) * 1000)

        summary = (
            f"实体提取完成: document_id={doc_id}, "
            f"entity_count={len(found)}, relation_count={len(links)}, "
            f"processing_time={elapsed_ms}ms"
        )
        logger.info(summary)

        # Dump the complete entity list to the log.
        logger.info(f"========== 实体列表 ({len(found)} 个) ==========")
        for idx, ent in enumerate(found, start=1):
            logger.info(f" [{idx}] {ent.type}: {ent.name}")

        # Dump the complete relation list to the log (only when non-empty).
        if links:
            logger.info(f"========== 关系列表 ({len(links)} 个) ==========")
            for idx, link in enumerate(links, start=1):
                logger.info(f" [{idx}] {link.source_entity} --[{link.relation_type}]--> {link.target_entity}")

        return NerResponse.success_response(
            document_id=doc_id,
            entities=found,
            relations=links,
            processing_time=elapsed_ms,
        )

    except HTTPException:
        # Let HTTP errors (such as the 400 above) reach FastAPI unchanged.
        raise
    except Exception as e:
        logger.error(f"实体提取失败: document_id={doc_id}, error={str(e)}")
        return NerResponse.error_response(
            document_id=doc_id,
            error_message=str(e),
        )
# ============== Async task API (polling mode) ==============
@router.post("/extract/async")
async def extract_entities_async(request: NerRequest, background_tasks: BackgroundTasks):
    """Queue a named-entity extraction job and return its task ID immediately.

    Usage:
        1. Call this endpoint to obtain a ``task_id``.
        2. Poll ``/ner/task/{task_id}`` for progress and the final result.

    The text-length limit is not checked here; it is enforced inside
    ``_process_ner_task``, so an oversized text produces a task that ends in
    the ``failed`` state rather than an HTTP error from this endpoint.

    Args:
        request: Extraction request; its ``document_id``, ``text``,
            ``entity_types`` and ``extract_relations`` are handed to the
            background worker.
        background_tasks: FastAPI background-task collection on which
            ``_process_ner_task`` is scheduled.

    Returns:
        A dict with ``task_id``, ``document_id``, ``status`` ("pending") and a
        ``message`` containing the concrete polling path for this task.
    """
    task_id = str(uuid.uuid4())
    now = time.time()

    # Register the initial task state before scheduling the worker.
    _tasks[task_id] = {
        "task_id": task_id,
        "document_id": request.document_id,
        "status": "pending",
        "progress": 0,
        "message": "任务已创建,等待处理",
        "result": None,
        "error": None,
        "created_at": now,
        "updated_at": now,
    }

    # Schedule the background worker.
    background_tasks.add_task(
        _process_ner_task,
        task_id,
        request.document_id,
        request.text,
        request.entity_types,
        request.extract_relations,
    )

    logger.info(f"创建异步 NER 任务: task_id={task_id}, document_id={request.document_id}")

    return {
        "task_id": task_id,
        "document_id": request.document_id,
        "status": "pending",
        # Fix: this was a plain string, so clients received the literal text
        # "{task_id}" instead of the actual polling path.
        "message": f"任务已创建,请轮询 /ner/task/{task_id} 获取进度",
    }
@router.get("/task/{task_id}", response_model=TaskStatus)
async def get_task_status(task_id: str):
    """Return the current status (and result, if any) of an asynchronous task.

    Status values:
        - pending: waiting to be processed
        - processing: in progress (``progress`` holds 0-100)
        - completed: finished (``result`` holds the outcome)
        - failed: processing failed (``error`` holds the error text)

    Raises:
        HTTPException: 404 when ``task_id`` is not present in the task store.
    """
    try:
        record = _tasks[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}") from None

    return TaskStatus(**record)
@router.delete("/task/{task_id}")
async def delete_task(task_id: str):
    """Remove a task entry from the in-memory store to free memory.

    The entry is removed whatever its current status; no status check is made.

    Raises:
        HTTPException: 404 when ``task_id`` is not present in the task store.
    """
    try:
        del _tasks[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}") from None

    return {"message": f"任务已删除: {task_id}"}
def _iter_sse_payloads(event_str: str):
    """Yield the JSON-decoded payload of each ``data:`` line in one SSE event string.

    Raises whatever ``json.loads`` raises when a payload is not valid JSON.
    """
    for line in event_str.strip().split('\n'):
        if line.startswith('data:'):
            yield json.loads(line[5:].strip())


async def _process_ner_task(
    task_id: str,
    document_id: str,
    text: str,
    entity_types: list,
    extract_relations: bool
):
    """Run one NER task in the background and record its state in ``_tasks``.

    Progress, the final result, or the failure reason are written to the
    ``_tasks[task_id]`` entry. Every write is skipped when that entry no
    longer exists (for example after it was deleted while the task was
    running). Exceptions never propagate to the caller: any failure is
    logged and stored as ``status="failed"``.

    Args:
        task_id: Key of the task entry in ``_tasks``.
        document_id: Identifier echoed into logs and the response.
        text: Text to analyse; more than 50000 characters fails the task.
        entity_types: Entity types passed through to the NER service.
        extract_relations: Whether to run relation extraction afterwards
            (only attempted when more than one entity was found).
    """
    start_time = time.time()

    def update_progress(progress: int, message: str):
        """Mark the task as processing and record progress, if it still exists."""
        task = _tasks.get(task_id)
        if task is None:
            return
        task["status"] = "processing"
        task["progress"] = progress
        task["message"] = message
        task["updated_at"] = time.time()

    try:
        update_progress(0, "开始处理")

        # Validate the text length.
        if len(text) > 50000:
            raise ValueError("文本长度超过限制(最大50000字符)")

        all_entities = []

        # Prefer the progress-reporting extraction when the service offers it.
        if hasattr(ner_service, 'extract_entities_with_progress'):
            async for event_str in ner_service.extract_entities_with_progress(
                text=text,
                entity_types=entity_types
            ):
                # Event parsing is best-effort: a malformed event is logged
                # and skipped instead of aborting the task. Only Exception is
                # caught, so KeyboardInterrupt/SystemExit are not swallowed.
                try:
                    for data in _iter_sse_payloads(event_str):
                        if 'progress_percent' in data:
                            # TaskStatus.progress is declared as int, so coerce
                            # in case the service reports a float. Cap at 90 to
                            # reserve the last 10% for relation extraction.
                            progress = min(int(data['progress_percent']), 90)
                            message = data.get('message', f"处理中 {progress}%")
                            update_progress(progress, message)

                        # An event carrying entities replaces the current list.
                        if 'entities' in data:
                            all_entities = [EntityInfo(**e) for e in data['entities']]
                except Exception as exc:
                    logger.warning(f"忽略无法解析的进度事件: task_id={task_id}, error={exc}")
        else:
            update_progress(10, "正在提取实体...")
            all_entities = await ner_service.extract_entities(
                text=text,
                entity_types=entity_types
            )

        update_progress(90, f"实体提取完成,共 {len(all_entities)} 个")

        # Extract relations.
        all_relations = []
        if extract_relations and len(all_entities) > 1:
            update_progress(92, "正在提取实体关系...")
            # Imported here, at the point of use, exactly as in the original.
            from ..services.relation_service import relation_service
            all_relations = await relation_service.extract_relations(
                text=text,
                entities=all_entities
            )

        processing_time = int((time.time() - start_time) * 1000)

        # Build the result.
        response = NerResponse.success_response(
            document_id=document_id,
            entities=all_entities,
            relations=all_relations,
            processing_time=processing_time
        )

        # Mark the task as completed (only if the entry still exists).
        task = _tasks.get(task_id)
        if task is not None:
            task["status"] = "completed"
            task["progress"] = 100
            task["message"] = f"处理完成: {len(all_entities)} 个实体, {len(all_relations)} 个关系"
            # Pydantic v2: serialise with model_dump(), using field aliases.
            result_dict = response.model_dump(by_alias=True)
            task["result"] = result_dict
            task["updated_at"] = time.time()

            # Debug aid: log how the first entity (including its position)
            # was serialised.
            serialized_entities = result_dict.get("entities")
            if serialized_entities:
                first_entity = serialized_entities[0]
                logger.info(f"实体序列化示例: name={first_entity.get('name')}, position={first_entity.get('position')}")
                if first_entity.get('position'):
                    logger.info(f"Position 详情: {first_entity['position']}")

        logger.info(f"异步 NER 任务完成: task_id={task_id}, document_id={document_id}, "
                    f"entities={len(all_entities)}, relations={len(all_relations)}, time={processing_time}ms")

    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # which the previous logger.error call dropped.
        logger.exception(f"异步 NER 任务失败: task_id={task_id}, document_id={document_id}, error={str(e)}")
        task = _tasks.get(task_id)
        if task is not None:
            task["status"] = "failed"
            task["error"] = str(e)
            task["message"] = f"处理失败: {str(e)}"
            task["updated_at"] = time.time()
|