ner.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
"""
NER routes.

Exposes entity-extraction endpoints: a plain request/response endpoint
and an SSE streaming variant.
"""
import json
import asyncio
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
import time
from ..models import NerRequest, NerResponse, EntityInfo
from ..services.ner_service import ner_service

router = APIRouter()
  13. async def sse_event(event: str, data: dict):
  14. """生成 SSE 事件格式"""
  15. return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
  16. @router.post("/extract", response_model=NerResponse)
  17. async def extract_entities(request: NerRequest):
  18. """
  19. 从文本中提取命名实体
  20. """
  21. start_time = time.time()
  22. try:
  23. logger.info(f"开始提取实体: document_id={request.document_id}, text_length={len(request.text)}")
  24. # 验证文本长度
  25. if len(request.text) > 50000:
  26. raise HTTPException(status_code=400, detail="文本长度超过限制(最大50000字符)")
  27. # 调用 NER 服务
  28. entities = await ner_service.extract_entities(
  29. text=request.text,
  30. entity_types=request.entity_types
  31. )
  32. # 如果需要提取关系
  33. relations = []
  34. if request.extract_relations and len(entities) > 1:
  35. from ..services.relation_service import relation_service
  36. relations = await relation_service.extract_relations(
  37. text=request.text,
  38. entities=entities
  39. )
  40. processing_time = int((time.time() - start_time) * 1000)
  41. logger.info(f"实体提取完成: document_id={request.document_id}, "
  42. f"entity_count={len(entities)}, relation_count={len(relations)}, "
  43. f"processing_time={processing_time}ms")
  44. # 输出完整的实体列表
  45. logger.info(f"========== 实体列表 ({len(entities)} 个) ==========")
  46. for i, entity in enumerate(entities, 1):
  47. logger.info(f" [{i}] {entity.type}: {entity.name}")
  48. # 输出完整的关系列表
  49. if relations:
  50. logger.info(f"========== 关系列表 ({len(relations)} 个) ==========")
  51. for i, rel in enumerate(relations, 1):
  52. logger.info(f" [{i}] {rel.source_entity} --[{rel.relation_type}]--> {rel.target_entity}")
  53. return NerResponse.success_response(
  54. document_id=request.document_id,
  55. entities=entities,
  56. relations=relations,
  57. processing_time=processing_time
  58. )
  59. except HTTPException:
  60. raise
  61. except Exception as e:
  62. logger.error(f"实体提取失败: document_id={request.document_id}, error={str(e)}")
  63. return NerResponse.error_response(
  64. document_id=request.document_id,
  65. error_message=str(e)
  66. )
@router.post("/extract/stream")
async def extract_entities_stream(request: NerRequest):
    """
    Extract named entities from text (SSE streaming response).

    Progress events pushed in real time:
    - progress: processing progress (when text is processed in chunks)
    - entity: a new entity was found
    - complete: processing finished
    - error: processing failed
    """
    async def generate():
        # Inner async generator that produces the SSE byte stream.
        start_time = time.time()
        try:
            # Emit the start event before any processing.
            yield await sse_event("start", {
                "document_id": request.document_id,
                "text_length": len(request.text),
                "message": "开始 NER 提取"
            })
            # Validate text length; streamed errors are SSE events, not
            # HTTP errors, since headers are already committed.
            if len(request.text) > 50000:
                yield await sse_event("error", {
                    "document_id": request.document_id,
                    "error": "文本长度超过限制(最大50000字符)"
                })
                return
            # Accumulators for results gathered across chunks.
            all_entities = []
            all_relations = []
            # Progress callback: extends the accumulated entity list and
            # returns a formatted SSE "progress" event string.
            # NOTE(review): the return value only reaches the client if the
            # service yields it back — confirm against
            # extract_entities_with_progress.
            async def progress_callback(chunk_index: int, total_chunks: int, chunk_entities: list):
                nonlocal all_entities
                all_entities.extend(chunk_entities)
                progress_data = {
                    "document_id": request.document_id,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "chunk_entities": len(chunk_entities),
                    "total_entities": len(all_entities),
                    "progress_percent": int((chunk_index / total_chunks) * 100)
                }
                return await sse_event("progress", progress_data)
            # Call the NER service (re-imported locally; presumably to
            # avoid a circular import — same object as the module-level one).
            from ..services.ner_service import ner_service
            # Prefer the streaming API when the service provides it.
            if hasattr(ner_service, 'extract_entities_with_progress'):
                async for event in ner_service.extract_entities_with_progress(
                    text=request.text,
                    entity_types=request.entity_types,
                    progress_callback=progress_callback
                ):
                    yield event
                    # If the service yields dicts, pick up its latest
                    # entity list; string events leave the accumulator as-is.
                    all_entities = event.get('entities', all_entities) if isinstance(event, dict) else all_entities
            else:
                # Fall back to the plain (non-streaming) extraction path.
                all_entities = await ner_service.extract_entities(
                    text=request.text,
                    entity_types=request.entity_types
                )
                yield await sse_event("progress", {
                    "document_id": request.document_id,
                    "chunk_index": 1,
                    "total_chunks": 1,
                    "total_entities": len(all_entities),
                    "progress_percent": 100
                })
            # Optional relation extraction (only meaningful with 2+ entities).
            if request.extract_relations and len(all_entities) > 1:
                yield await sse_event("progress", {
                    "document_id": request.document_id,
                    "message": "正在提取实体关系...",
                    "stage": "relations"
                })
                from ..services.relation_service import relation_service
                all_relations = await relation_service.extract_relations(
                    text=request.text,
                    entities=all_entities
                )
            processing_time = int((time.time() - start_time) * 1000)
            # Emit the completion event carrying the full result payload.
            response = NerResponse.success_response(
                document_id=request.document_id,
                entities=all_entities,
                relations=all_relations,
                processing_time=processing_time
            )
            yield await sse_event("complete", response.dict())
            logger.info(f"SSE 提取完成: document_id={request.document_id}, "
                        f"entity_count={len(all_entities)}, relation_count={len(all_relations)}, "
                        f"processing_time={processing_time}ms")
        except Exception as e:
            # Surface any failure to the client as a terminal SSE error event.
            logger.error(f"SSE 提取失败: document_id={request.document_id}, error={str(e)}")
            yield await sse_event("error", {
                "document_id": request.document_id,
                "error": str(e)
            })
    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable nginx proxy buffering so events flush immediately.
            "X-Accel-Buffering": "no"
        }
    )