element_extractor.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. """
  2. 要素提取器:使用NER和LLM从文档中提取实体
  3. 支持分章节提取和实体去重。
  4. """
  5. import json
  6. import asyncio
  7. from typing import Dict, List, Any, Optional
  8. from loguru import logger
  9. from .ner_service import ner_service
  10. class ElementExtractor:
  11. """
  12. 要素提取器
  13. 使用NER服务识别文档中的实体,可选使用LLM进行智能提取。
  14. 不预定义要素结构,返回动态识别的实体。
  15. """
  16. def __init__(self):
  17. self._deepseek_service = None
  18. @property
  19. def deepseek_service(self):
  20. """延迟加载deepseek服务"""
  21. if self._deepseek_service is None:
  22. try:
  23. from .deepseek_service import deepseek_service
  24. self._deepseek_service = deepseek_service
  25. except ImportError:
  26. logger.warning("DeepSeek服务未配置,LLM提取将跳过")
  27. self._deepseek_service = None
  28. return self._deepseek_service
  29. async def extract_from_text(
  30. self,
  31. text: str,
  32. attachment_id: int = 0,
  33. use_llm: bool = True
  34. ) -> Dict[str, Any]:
  35. """
  36. 从纯文本中提取实体(主接口)
  37. Args:
  38. text: 文档纯文本
  39. attachment_id: 附件ID
  40. use_llm: 是否使用LLM提取
  41. Returns:
  42. {
  43. "entities": [...], # NER识别的实体列表
  44. "llm_extractions": [...], # LLM提取的内容(可选)
  45. "statistics": {...}
  46. }
  47. """
  48. logger.info(f"开始提取实体: attachment_id={attachment_id}, "
  49. f"text_length={len(text)}, use_llm={use_llm}")
  50. # 1. 使用NER服务提取实体
  51. ner_entities = await self._extract_by_ner(text)
  52. logger.info(f"NER提取完成: {len(ner_entities)} 个实体")
  53. # 2. LLM智能提取(可选)
  54. llm_extractions = []
  55. if use_llm and self.deepseek_service:
  56. llm_extractions = await self._extract_by_llm(text)
  57. logger.info(f"LLM提取完成: {len(llm_extractions)} 个内容")
  58. return {
  59. "entities": ner_entities,
  60. "llm_extractions": llm_extractions,
  61. "statistics": {
  62. "ner_entity_count": len(ner_entities),
  63. "llm_extraction_count": len(llm_extractions),
  64. "text_length": len(text)
  65. }
  66. }
  67. async def _extract_by_ner(self, text: str) -> List[Dict]:
  68. """
  69. 使用NER服务提取实体
  70. 返回实体列表,每个实体包含:
  71. - text: 实体文本
  72. - type: 实体类型(DATE, ORG, PERSON, NUMBER, CODE等)
  73. - label: 实体标签
  74. - confidence: 置信度
  75. - position: 位置信息
  76. """
  77. try:
  78. # 调用现有的NER服务,返回EntityInfo对象列表
  79. entities = await ner_service.extract_entities(text)
  80. # 格式化输出(EntityInfo是Pydantic模型,使用属性访问)
  81. result = []
  82. for entity in entities:
  83. result.append({
  84. "text": entity.name,
  85. "type": entity.type,
  86. "label": entity.type,
  87. "confidence": entity.confidence,
  88. "position": {
  89. "start": entity.position.char_start if entity.position else 0,
  90. "end": entity.position.char_end if entity.position else 0
  91. }
  92. })
  93. return result
  94. except Exception as e:
  95. logger.error(f"NER提取失败: {e}")
  96. return []
  97. async def _extract_by_llm(self, text: str) -> List[Dict]:
  98. """
  99. 使用LLM智能提取关键信息
  100. 让LLM自动识别文档中的重要信息,不预设要提取什么。
  101. """
  102. if not self.deepseek_service:
  103. return []
  104. try:
  105. # 截取文档前部分进行分析
  106. sample_text = text[:8000] if len(text) > 8000 else text
  107. prompt = f"""请分析以下文档,提取其中的关键信息。
  108. 要求:
  109. 1. 识别文档类型(如:报告、合同、通知等)
  110. 2. 提取关键实体(如:组织名称、日期、金额、编号等)
  111. 3. 提取关键数据(如:得分、级别、数量等)
  112. 4. 以JSON格式返回
  113. 返回格式:
  114. {{
  115. "document_type": "文档类型",
  116. "key_entities": [
  117. {{"name": "实体名称", "type": "实体类型", "value": "实体值"}}
  118. ],
  119. "key_data": [
  120. {{"name": "数据名称", "value": "数据值", "unit": "单位"}}
  121. ],
  122. "summary": "文档摘要(50字以内)"
  123. }}
  124. 文档内容:
  125. {sample_text}
  126. 只返回JSON,不要其他内容。"""
  127. response = await self.deepseek_service.chat(prompt)
  128. if response:
  129. # 尝试解析JSON
  130. try:
  131. # 清理响应,提取JSON部分
  132. json_str = response.strip()
  133. if json_str.startswith("```"):
  134. json_str = json_str.split("```")[1]
  135. if json_str.startswith("json"):
  136. json_str = json_str[4:]
  137. data = json.loads(json_str)
  138. extractions = []
  139. # 文档类型
  140. if data.get("document_type"):
  141. extractions.append({
  142. "name": "文档类型",
  143. "value": data["document_type"],
  144. "source": "llm"
  145. })
  146. # 关键实体
  147. for entity in data.get("key_entities", []):
  148. extractions.append({
  149. "name": entity.get("name", ""),
  150. "type": entity.get("type", ""),
  151. "value": entity.get("value", ""),
  152. "source": "llm"
  153. })
  154. # 关键数据
  155. for item in data.get("key_data", []):
  156. value = item.get("value", "")
  157. if item.get("unit"):
  158. value = f"{value}{item['unit']}"
  159. extractions.append({
  160. "name": item.get("name", ""),
  161. "value": value,
  162. "source": "llm"
  163. })
  164. # 摘要
  165. if data.get("summary"):
  166. extractions.append({
  167. "name": "文档摘要",
  168. "value": data["summary"],
  169. "source": "llm"
  170. })
  171. return extractions
  172. except json.JSONDecodeError:
  173. logger.warning(f"LLM返回的不是有效JSON: {response[:200]}")
  174. return []
  175. return []
  176. except Exception as e:
  177. logger.error(f"LLM提取失败: {e}")
  178. return []
  179. async def extract_from_chapters(
  180. self,
  181. chapters: List[Dict],
  182. attachment_id: int = 0,
  183. use_llm: bool = True,
  184. parallel: bool = True
  185. ) -> Dict[str, Any]:
  186. """
  187. 分章节提取实体,最后去重合并
  188. Args:
  189. chapters: 章节列表,每个章节包含 {chapter_id, title, text}
  190. attachment_id: 附件ID
  191. use_llm: 是否使用LLM提取
  192. parallel: 是否并行处理章节
  193. Returns:
  194. {
  195. "entities": [...], # 去重后的实体列表
  196. "chapter_entities": {...}, # 按章节分组的实体
  197. "llm_extractions": [...],
  198. "statistics": {...}
  199. }
  200. """
  201. logger.info(f"开始分章节提取: {len(chapters)} 个章节, parallel={parallel}")
  202. chapter_results = {}
  203. all_entities = []
  204. all_llm_extractions = []
  205. if parallel and len(chapters) > 1:
  206. # 并行处理章节
  207. tasks = []
  208. for chapter in chapters:
  209. task = self._extract_chapter(chapter, attachment_id, use_llm)
  210. tasks.append(task)
  211. results = await asyncio.gather(*tasks, return_exceptions=True)
  212. for chapter, result in zip(chapters, results):
  213. if isinstance(result, Exception):
  214. logger.error(f"章节 {chapter['chapter_id']} 提取失败: {result}")
  215. continue
  216. chapter_results[chapter['chapter_id']] = result
  217. else:
  218. # 串行处理章节
  219. for chapter in chapters:
  220. try:
  221. result = await self._extract_chapter(chapter, attachment_id, use_llm)
  222. chapter_results[chapter['chapter_id']] = result
  223. except Exception as e:
  224. logger.error(f"章节 {chapter['chapter_id']} 提取失败: {e}")
  225. # 合并所有章节的实体
  226. for chapter_id, result in chapter_results.items():
  227. for entity in result.get('entities', []):
  228. entity['chapter_id'] = chapter_id
  229. all_entities.append(entity)
  230. all_llm_extractions.extend(result.get('llm_extractions', []))
  231. # 去重
  232. unique_entities = self._deduplicate_entities(all_entities)
  233. unique_llm = self._deduplicate_llm_extractions(all_llm_extractions)
  234. logger.info(f"分章节提取完成: 原始 {len(all_entities)} 个实体, 去重后 {len(unique_entities)} 个")
  235. return {
  236. "entities": unique_entities,
  237. "chapter_entities": chapter_results,
  238. "llm_extractions": unique_llm,
  239. "statistics": {
  240. "chapter_count": len(chapters),
  241. "total_entities_before_dedup": len(all_entities),
  242. "unique_entity_count": len(unique_entities),
  243. "llm_extraction_count": len(unique_llm)
  244. }
  245. }
  246. async def _extract_chapter(
  247. self,
  248. chapter: Dict,
  249. attachment_id: int,
  250. use_llm: bool
  251. ) -> Dict[str, Any]:
  252. """提取单个章节的实体"""
  253. chapter_id = chapter.get('chapter_id', 'unknown')
  254. title = chapter.get('title', '')
  255. text = chapter.get('text', '')
  256. if not text or len(text.strip()) < 10:
  257. return {"entities": [], "llm_extractions": []}
  258. logger.debug(f"提取章节 {chapter_id}: {title[:30]}... (长度: {len(text)})")
  259. # NER提取
  260. entities = await self._extract_by_ner(text)
  261. # 为每个实体添加章节信息
  262. for entity in entities:
  263. entity['chapter_id'] = chapter_id
  264. entity['chapter_title'] = title
  265. # LLM提取(可选)
  266. llm_extractions = []
  267. if use_llm and self.deepseek_service:
  268. llm_extractions = await self._extract_by_llm(text)
  269. for item in llm_extractions:
  270. item['chapter_id'] = chapter_id
  271. item['chapter_title'] = title
  272. return {
  273. "entities": entities,
  274. "llm_extractions": llm_extractions
  275. }
  276. def _deduplicate_entities(self, entities: List[Dict]) -> List[Dict]:
  277. """
  278. 实体去重
  279. 去重规则:
  280. 1. 相同类型+相同文本 -> 保留第一个出现的
  281. 2. 包含关系 -> 保留更长的实体
  282. """
  283. if not entities:
  284. return []
  285. # 按 (type, text) 去重
  286. seen = {}
  287. for entity in entities:
  288. key = (entity.get('type', ''), entity.get('text', ''))
  289. if key not in seen:
  290. seen[key] = entity
  291. else:
  292. # 保留置信度更高的
  293. if entity.get('confidence', 0) > seen[key].get('confidence', 0):
  294. seen[key] = entity
  295. unique = list(seen.values())
  296. # 处理包含关系(可选,较复杂)
  297. # 例如:"中国电建集团" 和 "中国电建集团成都勘测设计研究院有限公司"
  298. # 保留更长的
  299. final = []
  300. texts = set()
  301. # 按文本长度降序排序
  302. unique.sort(key=lambda x: len(x.get('text', '')), reverse=True)
  303. for entity in unique:
  304. text = entity.get('text', '')
  305. # 检查是否被更长的实体包含
  306. is_substring = False
  307. for existing_text in texts:
  308. if text in existing_text and text != existing_text:
  309. is_substring = True
  310. break
  311. if not is_substring:
  312. final.append(entity)
  313. texts.add(text)
  314. # 恢复原始顺序(按位置)
  315. final.sort(key=lambda x: x.get('position', {}).get('start', 0))
  316. return final
  317. def _deduplicate_llm_extractions(self, extractions: List[Dict]) -> List[Dict]:
  318. """LLM提取结果去重"""
  319. if not extractions:
  320. return []
  321. seen = {}
  322. for item in extractions:
  323. key = (item.get('name', ''), item.get('value', ''))
  324. if key not in seen:
  325. seen[key] = item
  326. return list(seen.values())
  327. # 创建单例
  328. element_extractor = ElementExtractor()