# ner.py — NER extraction routes
  1. """
  2. NER 路由
  3. """
  4. import json
  5. import asyncio
  6. from fastapi import APIRouter, HTTPException
  7. from fastapi.responses import StreamingResponse
  8. from loguru import logger
  9. import time
  10. from ..models import NerRequest, NerResponse, EntityInfo
  11. from ..services.ner_service import ner_service
  12. router = APIRouter()
  13. async def sse_event(event: str, data: dict):
  14. """生成 SSE 事件格式"""
  15. return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
  16. @router.post("/extract", response_model=NerResponse)
  17. async def extract_entities(request: NerRequest):
  18. """
  19. 从文本中提取命名实体
  20. """
  21. start_time = time.time()
  22. try:
  23. logger.info(f"开始提取实体: document_id={request.document_id}, text_length={len(request.text)}")
  24. # 验证文本长度
  25. if len(request.text) > 50000:
  26. raise HTTPException(status_code=400, detail="文本长度超过限制(最大50000字符)")
  27. # 调用 NER 服务
  28. entities = await ner_service.extract_entities(
  29. text=request.text,
  30. entity_types=request.entity_types
  31. )
  32. # 如果需要提取关系
  33. relations = []
  34. if request.extract_relations and len(entities) > 1:
  35. from ..services.relation_service import relation_service
  36. relations = await relation_service.extract_relations(
  37. text=request.text,
  38. entities=entities
  39. )
  40. processing_time = int((time.time() - start_time) * 1000)
  41. logger.info(f"实体提取完成: document_id={request.document_id}, "
  42. f"entity_count={len(entities)}, relation_count={len(relations)}, "
  43. f"processing_time={processing_time}ms")
  44. # 输出完整的实体列表
  45. logger.info(f"========== 实体列表 ({len(entities)} 个) ==========")
  46. for i, entity in enumerate(entities, 1):
  47. logger.info(f" [{i}] {entity.type}: {entity.name}")
  48. # 输出完整的关系列表
  49. if relations:
  50. logger.info(f"========== 关系列表 ({len(relations)} 个) ==========")
  51. for i, rel in enumerate(relations, 1):
  52. logger.info(f" [{i}] {rel.source_entity} --[{rel.relation_type}]--> {rel.target_entity}")
  53. return NerResponse.success_response(
  54. document_id=request.document_id,
  55. entities=entities,
  56. relations=relations,
  57. processing_time=processing_time
  58. )
  59. except HTTPException:
  60. raise
  61. except Exception as e:
  62. logger.error(f"实体提取失败: document_id={request.document_id}, error={str(e)}")
  63. return NerResponse.error_response(
  64. document_id=request.document_id,
  65. error_message=str(e)
  66. )
  67. @router.post("/extract/stream")
  68. async def extract_entities_stream(request: NerRequest):
  69. """
  70. 从文本中提取命名实体(SSE 流式响应)
  71. 实时推送进度事件:
  72. - progress: 处理进度(分块处理时)
  73. - entity: 发现新实体
  74. - complete: 处理完成
  75. - error: 处理出错
  76. """
  77. async def generate():
  78. start_time = time.time()
  79. try:
  80. # 发送开始事件
  81. yield await sse_event("start", {
  82. "document_id": request.document_id,
  83. "text_length": len(request.text),
  84. "message": "开始 NER 提取"
  85. })
  86. # 验证文本长度
  87. if len(request.text) > 50000:
  88. yield await sse_event("error", {
  89. "document_id": request.document_id,
  90. "error": "文本长度超过限制(最大50000字符)"
  91. })
  92. return
  93. # 使用带进度回调的提取方法
  94. all_entities = []
  95. all_relations = []
  96. # 调用带进度的 NER 服务
  97. from ..services.ner_service import ner_service
  98. # 检查是否支持流式提取
  99. if hasattr(ner_service, 'extract_entities_with_progress'):
  100. async for event_str in ner_service.extract_entities_with_progress(
  101. text=request.text,
  102. entity_types=request.entity_types
  103. ):
  104. # 转发进度事件
  105. yield event_str
  106. # 解析事件获取实体数据
  107. if "entities_data" in event_str:
  108. try:
  109. # 从 SSE 格式中提取 JSON 数据
  110. lines = event_str.strip().split('\n')
  111. for line in lines:
  112. if line.startswith('data:'):
  113. data = json.loads(line[5:].strip())
  114. if 'entities' in data:
  115. # 将字典转换回 EntityInfo 对象
  116. all_entities = [
  117. EntityInfo(**e) for e in data['entities']
  118. ]
  119. except Exception as parse_err:
  120. logger.warning(f"解析实体数据事件失败: {parse_err}")
  121. else:
  122. # 回退到普通提取
  123. all_entities = await ner_service.extract_entities(
  124. text=request.text,
  125. entity_types=request.entity_types
  126. )
  127. yield await sse_event("progress", {
  128. "document_id": request.document_id,
  129. "chunk_index": 1,
  130. "total_chunks": 1,
  131. "total_entities": len(all_entities),
  132. "progress_percent": 100
  133. })
  134. # 提取关系
  135. if request.extract_relations and len(all_entities) > 1:
  136. yield await sse_event("progress", {
  137. "document_id": request.document_id,
  138. "message": "正在提取实体关系...",
  139. "stage": "relations"
  140. })
  141. from ..services.relation_service import relation_service
  142. all_relations = await relation_service.extract_relations(
  143. text=request.text,
  144. entities=all_entities
  145. )
  146. processing_time = int((time.time() - start_time) * 1000)
  147. # 发送完成事件(包含完整结果)
  148. response = NerResponse.success_response(
  149. document_id=request.document_id,
  150. entities=all_entities,
  151. relations=all_relations,
  152. processing_time=processing_time
  153. )
  154. yield await sse_event("complete", response.dict())
  155. logger.info(f"SSE 提取完成: document_id={request.document_id}, "
  156. f"entity_count={len(all_entities)}, relation_count={len(all_relations)}, "
  157. f"processing_time={processing_time}ms")
  158. except Exception as e:
  159. logger.error(f"SSE 提取失败: document_id={request.document_id}, error={str(e)}")
  160. yield await sse_event("error", {
  161. "document_id": request.document_id,
  162. "error": str(e)
  163. })
  164. return StreamingResponse(
  165. generate(),
  166. media_type="text/event-stream",
  167. headers={
  168. "Cache-Control": "no-cache",
  169. "Connection": "keep-alive",
  170. "X-Accel-Buffering": "no"
  171. }
  172. )