ner.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
"""
NER routes: synchronous entity extraction plus an asynchronous,
polling-based task API.
"""
  4. import json
  5. import asyncio
  6. import uuid
  7. from typing import Dict, Any, Optional
  8. from fastapi import APIRouter, HTTPException, BackgroundTasks
  9. from fastapi.responses import StreamingResponse
  10. from loguru import logger
  11. import time
  12. from pydantic import BaseModel
  13. from ..models import NerRequest, NerResponse, EntityInfo
  14. from ..services.ner_service import ner_service
  15. router = APIRouter()
  16. # ============== 任务存储 ==============
  17. # 存储异步任务状态和结果
  18. _tasks: Dict[str, Dict[str, Any]] = {}
  19. class TaskStatus(BaseModel):
  20. """任务状态响应"""
  21. task_id: str
  22. document_id: str
  23. status: str # pending, processing, completed, failed
  24. progress: int = 0 # 0-100
  25. message: str = ""
  26. result: Optional[Dict] = None
  27. error: Optional[str] = None
  28. created_at: float = 0
  29. updated_at: float = 0
  30. async def sse_event(event: str, data: dict):
  31. """生成 SSE 事件格式"""
  32. return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
  33. @router.post("/extract", response_model=NerResponse)
  34. async def extract_entities(request: NerRequest):
  35. """
  36. 从文本中提取命名实体
  37. """
  38. start_time = time.time()
  39. try:
  40. logger.info(f"开始提取实体: document_id={request.document_id}, text_length={len(request.text)}")
  41. # 验证文本长度
  42. if len(request.text) > 50000:
  43. raise HTTPException(status_code=400, detail="文本长度超过限制(最大50000字符)")
  44. # 调用 NER 服务
  45. entities = await ner_service.extract_entities(
  46. text=request.text,
  47. entity_types=request.entity_types
  48. )
  49. # 如果需要提取关系
  50. relations = []
  51. if request.extract_relations and len(entities) > 1:
  52. from ..services.relation_service import relation_service
  53. relations = await relation_service.extract_relations(
  54. text=request.text,
  55. entities=entities
  56. )
  57. processing_time = int((time.time() - start_time) * 1000)
  58. logger.info(f"实体提取完成: document_id={request.document_id}, "
  59. f"entity_count={len(entities)}, relation_count={len(relations)}, "
  60. f"processing_time={processing_time}ms")
  61. # 输出完整的实体列表
  62. logger.info(f"========== 实体列表 ({len(entities)} 个) ==========")
  63. for i, entity in enumerate(entities, 1):
  64. logger.info(f" [{i}] {entity.type}: {entity.name}")
  65. # 输出完整的关系列表
  66. if relations:
  67. logger.info(f"========== 关系列表 ({len(relations)} 个) ==========")
  68. for i, rel in enumerate(relations, 1):
  69. logger.info(f" [{i}] {rel.source_entity} --[{rel.relation_type}]--> {rel.target_entity}")
  70. return NerResponse.success_response(
  71. document_id=request.document_id,
  72. entities=entities,
  73. relations=relations,
  74. processing_time=processing_time
  75. )
  76. except HTTPException:
  77. raise
  78. except Exception as e:
  79. logger.error(f"实体提取失败: document_id={request.document_id}, error={str(e)}")
  80. return NerResponse.error_response(
  81. document_id=request.document_id,
  82. error_message=str(e)
  83. )
# ============== Async task endpoints (polling mode) ==============
  85. @router.post("/extract/async")
  86. async def extract_entities_async(request: NerRequest, background_tasks: BackgroundTasks):
  87. """
  88. 异步提取命名实体,立即返回任务 ID
  89. 使用方式:
  90. 1. 调用此接口,获取 task_id
  91. 2. 轮询 /ner/task/{task_id} 查询进度和结果
  92. """
  93. task_id = str(uuid.uuid4())
  94. now = time.time()
  95. # 初始化任务状态
  96. _tasks[task_id] = {
  97. "task_id": task_id,
  98. "document_id": request.document_id,
  99. "status": "pending",
  100. "progress": 0,
  101. "message": "任务已创建,等待处理",
  102. "result": None,
  103. "error": None,
  104. "created_at": now,
  105. "updated_at": now
  106. }
  107. # 启动后台任务
  108. background_tasks.add_task(
  109. _process_ner_task,
  110. task_id,
  111. request.document_id,
  112. request.text,
  113. request.entity_types,
  114. request.extract_relations
  115. )
  116. logger.info(f"创建异步 NER 任务: task_id={task_id}, document_id={request.document_id}")
  117. return {
  118. "task_id": task_id,
  119. "document_id": request.document_id,
  120. "status": "pending",
  121. "message": "任务已创建,请轮询 /ner/task/{task_id} 获取进度"
  122. }
  123. @router.get("/task/{task_id}", response_model=TaskStatus)
  124. async def get_task_status(task_id: str):
  125. """
  126. 查询异步任务状态和结果
  127. 状态说明:
  128. - pending: 等待处理
  129. - processing: 正在处理(progress 字段表示进度 0-100)
  130. - completed: 处理完成(result 字段包含结果)
  131. - failed: 处理失败(error 字段包含错误信息)
  132. """
  133. if task_id not in _tasks:
  134. raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}")
  135. task = _tasks[task_id]
  136. return TaskStatus(**task)
  137. @router.delete("/task/{task_id}")
  138. async def delete_task(task_id: str):
  139. """删除已完成的任务(释放内存)"""
  140. if task_id not in _tasks:
  141. raise HTTPException(status_code=404, detail=f"任务不存在: {task_id}")
  142. del _tasks[task_id]
  143. return {"message": f"任务已删除: {task_id}"}
  144. async def _process_ner_task(
  145. task_id: str,
  146. document_id: str,
  147. text: str,
  148. entity_types: list,
  149. extract_relations: bool
  150. ):
  151. """后台处理 NER 任务"""
  152. start_time = time.time()
  153. def update_progress(progress: int, message: str):
  154. """更新任务进度"""
  155. if task_id in _tasks:
  156. _tasks[task_id]["status"] = "processing"
  157. _tasks[task_id]["progress"] = progress
  158. _tasks[task_id]["message"] = message
  159. _tasks[task_id]["updated_at"] = time.time()
  160. try:
  161. update_progress(0, "开始处理")
  162. # 验证文本长度
  163. if len(text) > 50000:
  164. raise ValueError("文本长度超过限制(最大50000字符)")
  165. # 使用带进度回调的提取(如果支持)
  166. all_entities = []
  167. if hasattr(ner_service, 'extract_entities_with_progress'):
  168. chunk_count = 0
  169. total_chunks = 1
  170. async for event_str in ner_service.extract_entities_with_progress(
  171. text=text,
  172. entity_types=entity_types
  173. ):
  174. # 解析进度事件
  175. try:
  176. lines = event_str.strip().split('\n')
  177. for line in lines:
  178. if line.startswith('data:'):
  179. data = json.loads(line[5:].strip())
  180. # 更新进度
  181. if 'progress_percent' in data:
  182. progress = min(data['progress_percent'], 90) # 预留 10% 给关系提取
  183. message = data.get('message', f"处理中 {progress}%")
  184. update_progress(progress, message)
  185. if 'total_chunks' in data:
  186. total_chunks = data['total_chunks']
  187. if 'chunk_index' in data:
  188. chunk_count = data['chunk_index']
  189. # 获取实体数据
  190. if 'entities' in data:
  191. all_entities = [EntityInfo(**e) for e in data['entities']]
  192. except:
  193. pass
  194. else:
  195. update_progress(10, "正在提取实体...")
  196. all_entities = await ner_service.extract_entities(
  197. text=text,
  198. entity_types=entity_types
  199. )
  200. update_progress(90, f"实体提取完成,共 {len(all_entities)} 个")
  201. # 提取关系
  202. all_relations = []
  203. if extract_relations and len(all_entities) > 1:
  204. update_progress(92, "正在提取实体关系...")
  205. from ..services.relation_service import relation_service
  206. all_relations = await relation_service.extract_relations(
  207. text=text,
  208. entities=all_entities
  209. )
  210. processing_time = int((time.time() - start_time) * 1000)
  211. # 构建结果
  212. response = NerResponse.success_response(
  213. document_id=document_id,
  214. entities=all_entities,
  215. relations=all_relations,
  216. processing_time=processing_time
  217. )
  218. # 更新任务状态为完成
  219. if task_id in _tasks:
  220. _tasks[task_id]["status"] = "completed"
  221. _tasks[task_id]["progress"] = 100
  222. _tasks[task_id]["message"] = f"处理完成: {len(all_entities)} 个实体, {len(all_relations)} 个关系"
  223. _tasks[task_id]["result"] = response.dict()
  224. _tasks[task_id]["updated_at"] = time.time()
  225. logger.info(f"异步 NER 任务完成: task_id={task_id}, document_id={document_id}, "
  226. f"entities={len(all_entities)}, relations={len(all_relations)}, time={processing_time}ms")
  227. except Exception as e:
  228. logger.error(f"异步 NER 任务失败: task_id={task_id}, document_id={document_id}, error={str(e)}")
  229. if task_id in _tasks:
  230. _tasks[task_id]["status"] = "failed"
  231. _tasks[task_id]["error"] = str(e)
  232. _tasks[task_id]["message"] = f"处理失败: {str(e)}"
  233. _tasks[task_id]["updated_at"] = time.time()