""" NER 路由 """ import json import asyncio import uuid from typing import Dict, Any, Optional from fastapi import APIRouter, HTTPException, BackgroundTasks from fastapi.responses import StreamingResponse from loguru import logger import time from pydantic import BaseModel from ..models import NerRequest, NerResponse, EntityInfo from ..services.ner_service import ner_service router = APIRouter() # ============== 任务存储 ============== # 存储异步任务状态和结果 _tasks: Dict[str, Dict[str, Any]] = {} class TaskStatus(BaseModel): """任务状态响应""" task_id: str document_id: str status: str # pending, processing, completed, failed progress: int = 0 # 0-100 message: str = "" result: Optional[Dict] = None error: Optional[str] = None created_at: float = 0 updated_at: float = 0 async def sse_event(event: str, data: dict): """生成 SSE 事件格式""" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" @router.post("/extract", response_model=NerResponse) async def extract_entities(request: NerRequest): """ 从文本中提取命名实体 """ start_time = time.time() try: logger.info(f"开始提取实体: document_id={request.document_id}, text_length={len(request.text)}") # 验证文本长度 if len(request.text) > 50000: raise HTTPException(status_code=400, detail="文本长度超过限制(最大50000字符)") # 调用 NER 服务 entities = await ner_service.extract_entities( text=request.text, entity_types=request.entity_types ) # 如果需要提取关系 relations = [] if request.extract_relations and len(entities) > 1: from ..services.relation_service import relation_service relations = await relation_service.extract_relations( text=request.text, entities=entities ) processing_time = int((time.time() - start_time) * 1000) logger.info(f"实体提取完成: document_id={request.document_id}, " f"entity_count={len(entities)}, relation_count={len(relations)}, " f"processing_time={processing_time}ms") # 输出完整的实体列表 logger.info(f"========== 实体列表 ({len(entities)} 个) ==========") for i, entity in enumerate(entities, 1): logger.info(f" [{i}] {entity.type}: {entity.name}") # 输出完整的关系列表 if relations: logger.info(f"========== 关系列表 ({len(relations)} 个) ==========") for i, rel in enumerate(relations, 1): logger.info(f" [{i}] {rel.source_entity} --[{rel.relation_type}]--> {rel.target_entity}") return NerResponse.success_response( document_id=request.document_id, entities=entities, relations=relations, processing_time=processing_time ) except HTTPException: raise except Exception as e: logger.error(f"实体提取失败: document_id={request.document_id}, error={str(e)}") return NerResponse.error_response( document_id=request.document_id, error_message=str(e) ) # ============== 异步任务接口(轮询模式) ============== @router.post("/extract/async") async def extract_entities_async(request: NerRequest, background_tasks: BackgroundTasks): """ 异步提取命名实体,立即返回任务 ID 使用方式: 1. 调用此接口,获取 task_id 2. 轮询 /ner/task/{task_id} 查询进度和结果 """ task_id = str(uuid.uuid4()) now = time.time() # 初始化任务状态 _tasks[task_id] = { "task_id": task_id, "document_id": request.document_id, "status": "pending", "progress": 0, "message": "任务已创建,等待处理", "result": None, "error": None, "created_at": now, "updated_at": now } # 启动后台任务 background_tasks.add_task( _process_ner_task, task_id, request.document_id, request.text, request.entity_types, request.extract_relations ) logger.info(f"创建异步 NER 任务: task_id={task_id}, document_id={request.document_id}") return { "task_id": task_id, "document_id": request.document_id, "status": "pending", "message": "任务已创建,请轮询 /ner/task/{task_id} 获取进度" } @router.get("/task/{task_id}", response_model=TaskStatus) async def get_task_status(task_id: str): """ 查询异步任务状态和结果 状态说明: - pending: 等待处理 - processing: 正在处理(progress 字段表示进度 0-100) - completed: 处理完成(result 字段包含结果) - failed: 处理失败(error 字段包含错误信息) """ if task_id not in _tasks: raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}") task = _tasks[task_id] return TaskStatus(**task) @router.delete("/task/{task_id}") async def delete_task(task_id: str): """删除已完成的任务(释放内存)""" if task_id not in _tasks: raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}") del _tasks[task_id] return {"message": f"任务已删除: {task_id}"} async def _process_ner_task( task_id: str, document_id: str, text: str, entity_types: list, extract_relations: bool ): """后台处理 NER 任务""" start_time = time.time() def update_progress(progress: int, message: str): """更新任务进度""" if task_id in _tasks: _tasks[task_id]["status"] = "processing" _tasks[task_id]["progress"] = progress _tasks[task_id]["message"] = message _tasks[task_id]["updated_at"] = time.time() try: update_progress(0, "开始处理") # 验证文本长度 if len(text) > 50000: raise ValueError("文本长度超过限制(最大50000字符)") # 使用带进度回调的提取(如果支持) all_entities = [] if hasattr(ner_service, 'extract_entities_with_progress'): chunk_count = 0 total_chunks = 1 async for event_str in ner_service.extract_entities_with_progress( text=text, entity_types=entity_types ): # 解析进度事件 try: lines = event_str.strip().split('\n') for line in lines: if line.startswith('data:'): data = json.loads(line[5:].strip()) # 更新进度 if 'progress_percent' in data: progress = min(data['progress_percent'], 90) # 预留 10% 给关系提取 message = data.get('message', f"处理中 {progress}%") update_progress(progress, message) if 'total_chunks' in data: total_chunks = data['total_chunks'] if 'chunk_index' in data: chunk_count = data['chunk_index'] # 获取实体数据 if 'entities' in data: all_entities = [EntityInfo(**e) for e in data['entities']] except: pass else: update_progress(10, "正在提取实体...") all_entities = await ner_service.extract_entities( text=text, entity_types=entity_types ) update_progress(90, f"实体提取完成,共 {len(all_entities)} 个") # 提取关系 all_relations = [] if extract_relations and len(all_entities) > 1: update_progress(92, "正在提取实体关系...") from ..services.relation_service import relation_service all_relations = await relation_service.extract_relations( text=text, entities=all_entities ) processing_time = int((time.time() - start_time) * 1000) # 构建结果 response = NerResponse.success_response( document_id=document_id, entities=all_entities, relations=all_relations, processing_time=processing_time ) # 更新任务状态为完成 if task_id in _tasks: _tasks[task_id]["status"] = "completed" _tasks[task_id]["progress"] = 100 _tasks[task_id]["message"] = f"处理完成: {len(all_entities)} 个实体, {len(all_relations)} 个关系" _tasks[task_id]["result"] = response.dict() _tasks[task_id]["updated_at"] = time.time() logger.info(f"异步 NER 任务完成: task_id={task_id}, document_id={document_id}, " f"entities={len(all_entities)}, relations={len(all_relations)}, time={processing_time}ms") except Exception as e: logger.error(f"异步 NER 任务失败: task_id={task_id}, document_id={document_id}, error={str(e)}") if task_id in _tasks: _tasks[task_id]["status"] = "failed" _tasks[task_id]["error"] = str(e) _tasks[task_id]["message"] = f"处理失败: {str(e)}" _tasks[task_id]["updated_at"] = time.time()