| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- """
- NER 路由
- """
- import json
- import asyncio
- from fastapi import APIRouter, HTTPException
- from fastapi.responses import StreamingResponse
- from loguru import logger
- import time
- from ..models import NerRequest, NerResponse, EntityInfo
- from ..services.ner_service import ner_service
# FastAPI router on which this module's endpoints ("/extract" and
# "/extract/stream") are registered via @router.post.
router = APIRouter()
async def sse_event(event: str, data: dict):
    """Render one Server-Sent Events message.

    The message consists of an ``event:`` line naming the event, a ``data:``
    line holding *data* serialized as JSON (non-ASCII characters are kept
    verbatim rather than escaped), and the blank line that terminates an SSE
    message.

    Args:
        event: SSE event name written on the ``event:`` line.
        data: JSON-serializable payload written on the ``data:`` line.

    Returns:
        The formatted message text.
    """
    payload = json.dumps(data, ensure_ascii=False)
    # The two trailing empty entries produce the "\n\n" message terminator.
    return "\n".join((f"event: {event}", f"data: {payload}", "", ""))
@router.post("/extract", response_model=NerResponse)
async def extract_entities(request: NerRequest):
    """
    Extract named entities (and optionally relations) from a text.

    Relations are only extracted when ``request.extract_relations`` is set and
    more than one entity was found.

    Args:
        request: NER request; this endpoint reads its ``document_id``,
            ``text``, ``entity_types`` and ``extract_relations`` fields.

    Returns:
        ``NerResponse.success_response(...)`` with the entities, relations and
        processing time in milliseconds. Any unexpected error is logged (with
        traceback) and returned as ``NerResponse.error_response(...)`` in the
        normal response body rather than as an HTTP error.

    Raises:
        HTTPException: 400 when the text is longer than 50000 characters.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    try:
        logger.info(f"开始提取实体: document_id={request.document_id}, text_length={len(request.text)}")

        # Validate text length.
        if len(request.text) > 50000:
            raise HTTPException(status_code=400, detail="文本长度超过限制(最大50000字符)")

        # Run the NER service.
        entities = await ner_service.extract_entities(
            text=request.text,
            entity_types=request.entity_types,
        )

        # A relation needs at least two entities, so skip the call otherwise.
        relations = []
        if request.extract_relations and len(entities) > 1:
            # Imported lazily; presumably to avoid loading the relation
            # service (or an import cycle) unless it is needed — verify.
            from ..services.relation_service import relation_service
            relations = await relation_service.extract_relations(
                text=request.text,
                entities=entities,
            )

        processing_time = int((time.perf_counter() - start_time) * 1000)

        logger.info(f"实体提取完成: document_id={request.document_id}, "
                    f"entity_count={len(entities)}, relation_count={len(relations)}, "
                    f"processing_time={processing_time}ms")

        # Log the full entity list.
        logger.info(f"========== 实体列表 ({len(entities)} 个) ==========")
        for i, entity in enumerate(entities, 1):
            logger.info(f" [{i}] {entity.type}: {entity.name}")

        # Log the full relation list.
        if relations:
            logger.info(f"========== 关系列表 ({len(relations)} 个) ==========")
            for i, rel in enumerate(relations, 1):
                logger.info(f" [{i}] {rel.source_entity} --[{rel.relation_type}]--> {rel.target_entity}")

        return NerResponse.success_response(
            document_id=request.document_id,
            entities=entities,
            relations=relations,
            processing_time=processing_time,
        )

    except HTTPException:
        # Let FastAPI turn validation failures into real HTTP error responses.
        raise
    except Exception as e:
        # logger.exception also records the traceback, which logger.error
        # with only str(e) discarded.
        logger.exception(f"实体提取失败: document_id={request.document_id}, error={str(e)}")
        return NerResponse.error_response(
            document_id=request.document_id,
            error_message=str(e),
        )
@router.post("/extract/stream")
async def extract_entities_stream(request: NerRequest):
    """
    Extract named entities from a text, streaming progress as Server-Sent Events.

    Events pushed to the client:
    - start: extraction started (document id and text length)
    - progress: per-chunk progress, or the start of the relation stage
    - complete: extraction finished; payload is the full NerResponse
    - error: text longer than 50000 characters, or an unexpected failure

    The HTTP response has already started when processing begins, so errors
    are reported as an ``error`` event rather than via the HTTP status code.
    """

    async def generate():
        # Monotonic clock: durations are immune to wall-clock adjustments.
        start_time = time.perf_counter()

        try:
            # Announce the start of processing.
            yield await sse_event("start", {
                "document_id": request.document_id,
                "text_length": len(request.text),
                "message": "开始 NER 提取"
            })

            # Validate text length.
            if len(request.text) > 50000:
                yield await sse_event("error", {
                    "document_id": request.document_id,
                    "error": "文本长度超过限制(最大50000字符)"
                })
                return

            all_entities = []
            all_relations = []

            async def progress_callback(chunk_index: int, total_chunks: int, chunk_entities: list):
                # Accumulate this chunk's entities and build an SSE "progress"
                # message. The message is returned to the service rather than
                # yielded here; presumably the service yields it back through
                # the async-for loop below — verify.
                all_entities.extend(chunk_entities)

                # Guard against a zero chunk count (would divide by zero).
                if total_chunks > 0:
                    percent = int((chunk_index / total_chunks) * 100)
                else:
                    percent = 100

                return await sse_event("progress", {
                    "document_id": request.document_id,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "chunk_entities": len(chunk_entities),
                    "total_entities": len(all_entities),
                    "progress_percent": percent
                })

            # Prefer the progress-reporting extractor when the service has one
            # (ner_service is the module-level import).
            if hasattr(ner_service, 'extract_entities_with_progress'):
                async for event in ner_service.extract_entities_with_progress(
                    text=request.text,
                    entity_types=request.entity_types,
                    progress_callback=progress_callback
                ):
                    if isinstance(event, dict):
                        # StreamingResponse can only send str/bytes; yielding a
                        # raw dict would abort the stream. Dict events are
                        # treated as result carriers whose "entities" (when
                        # present) replace the accumulated list.
                        # NOTE(review): assumed contract of
                        # extract_entities_with_progress — confirm.
                        all_entities = event.get('entities', all_entities)
                    else:
                        yield event
            else:
                # Fall back to the plain, single-shot extractor.
                all_entities = await ner_service.extract_entities(
                    text=request.text,
                    entity_types=request.entity_types
                )
                # Same keys as the chunked progress payload, so clients can
                # parse both paths identically.
                yield await sse_event("progress", {
                    "document_id": request.document_id,
                    "chunk_index": 1,
                    "total_chunks": 1,
                    "chunk_entities": len(all_entities),
                    "total_entities": len(all_entities),
                    "progress_percent": 100
                })

            # A relation needs at least two entities.
            if request.extract_relations and len(all_entities) > 1:
                yield await sse_event("progress", {
                    "document_id": request.document_id,
                    "message": "正在提取实体关系...",
                    "stage": "relations"
                })

                from ..services.relation_service import relation_service
                all_relations = await relation_service.extract_relations(
                    text=request.text,
                    entities=all_entities
                )

            processing_time = int((time.perf_counter() - start_time) * 1000)

            # Final event carries the complete result.
            response = NerResponse.success_response(
                document_id=request.document_id,
                entities=all_entities,
                relations=all_relations,
                processing_time=processing_time
            )

            yield await sse_event("complete", response.dict())

            logger.info(f"SSE 提取完成: document_id={request.document_id}, "
                        f"entity_count={len(all_entities)}, relation_count={len(all_relations)}, "
                        f"processing_time={processing_time}ms")

        except Exception as e:
            # Include the traceback in the log; report the failure in-band.
            logger.exception(f"SSE 提取失败: document_id={request.document_id}, error={str(e)}")
            yield await sse_event("error", {
                "document_id": request.document_id,
                "error": str(e)
            })

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy (nginx) buffering so events reach the client promptly.
            "X-Accel-Buffering": "no"
        }
    )
|