""" NER 路由 """ import json import asyncio from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from loguru import logger import time from ..models import NerRequest, NerResponse, EntityInfo from ..services.ner_service import ner_service router = APIRouter() async def sse_event(event: str, data: dict): """生成 SSE 事件格式""" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" @router.post("/extract", response_model=NerResponse) async def extract_entities(request: NerRequest): """ 从文本中提取命名实体 """ start_time = time.time() try: logger.info(f"开始提取实体: document_id={request.document_id}, text_length={len(request.text)}") # 验证文本长度 if len(request.text) > 50000: raise HTTPException(status_code=400, detail="文本长度超过限制(最大50000字符)") # 调用 NER 服务 entities = await ner_service.extract_entities( text=request.text, entity_types=request.entity_types ) # 如果需要提取关系 relations = [] if request.extract_relations and len(entities) > 1: from ..services.relation_service import relation_service relations = await relation_service.extract_relations( text=request.text, entities=entities ) processing_time = int((time.time() - start_time) * 1000) logger.info(f"实体提取完成: document_id={request.document_id}, " f"entity_count={len(entities)}, relation_count={len(relations)}, " f"processing_time={processing_time}ms") # 输出完整的实体列表 logger.info(f"========== 实体列表 ({len(entities)} 个) ==========") for i, entity in enumerate(entities, 1): logger.info(f" [{i}] {entity.type}: {entity.name}") # 输出完整的关系列表 if relations: logger.info(f"========== 关系列表 ({len(relations)} 个) ==========") for i, rel in enumerate(relations, 1): logger.info(f" [{i}] {rel.source_entity} --[{rel.relation_type}]--> {rel.target_entity}") return NerResponse.success_response( document_id=request.document_id, entities=entities, relations=relations, processing_time=processing_time ) except HTTPException: raise except Exception as e: logger.error(f"实体提取失败: document_id={request.document_id}, error={str(e)}") return NerResponse.error_response( document_id=request.document_id, error_message=str(e) ) @router.post("/extract/stream") async def extract_entities_stream(request: NerRequest): """ 从文本中提取命名实体(SSE 流式响应) 实时推送进度事件: - progress: 处理进度(分块处理时) - entity: 发现新实体 - complete: 处理完成 - error: 处理出错 """ async def generate(): start_time = time.time() try: # 发送开始事件 yield await sse_event("start", { "document_id": request.document_id, "text_length": len(request.text), "message": "开始 NER 提取" }) # 验证文本长度 if len(request.text) > 50000: yield await sse_event("error", { "document_id": request.document_id, "error": "文本长度超过限制(最大50000字符)" }) return # 使用带进度回调的提取方法 all_entities = [] all_relations = [] # 创建进度回调 async def progress_callback(chunk_index: int, total_chunks: int, chunk_entities: list): nonlocal all_entities all_entities.extend(chunk_entities) progress_data = { "document_id": request.document_id, "chunk_index": chunk_index, "total_chunks": total_chunks, "chunk_entities": len(chunk_entities), "total_entities": len(all_entities), "progress_percent": int((chunk_index / total_chunks) * 100) } return await sse_event("progress", progress_data) # 调用带进度的 NER 服务 from ..services.ner_service import ner_service # 检查是否支持流式提取 if hasattr(ner_service, 'extract_entities_with_progress'): async for event in ner_service.extract_entities_with_progress( text=request.text, entity_types=request.entity_types, progress_callback=progress_callback ): yield event all_entities = event.get('entities', all_entities) if isinstance(event, dict) else all_entities else: # 回退到普通提取 all_entities = await ner_service.extract_entities( text=request.text, entity_types=request.entity_types ) yield await sse_event("progress", { "document_id": request.document_id, "chunk_index": 1, "total_chunks": 1, "total_entities": len(all_entities), "progress_percent": 100 }) # 提取关系 if request.extract_relations and len(all_entities) > 1: yield await sse_event("progress", { "document_id": request.document_id, "message": "正在提取实体关系...", "stage": "relations" }) from ..services.relation_service import relation_service all_relations = await relation_service.extract_relations( text=request.text, entities=all_entities ) processing_time = int((time.time() - start_time) * 1000) # 发送完成事件(包含完整结果) response = NerResponse.success_response( document_id=request.document_id, entities=all_entities, relations=all_relations, processing_time=processing_time ) yield await sse_event("complete", response.dict()) logger.info(f"SSE 提取完成: document_id={request.document_id}, " f"entity_count={len(all_entities)}, relation_count={len(all_relations)}, " f"processing_time={processing_time}ms") except Exception as e: logger.error(f"SSE 提取失败: document_id={request.document_id}, error={str(e)}") yield await sse_event("error", { "document_id": request.document_id, "error": str(e) }) return StreamingResponse( generate(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } )