  1. """
  2. NER 服务实现 - 使用 HanLP
  3. HanLP 是一个中文NLP工具包,支持高质量的命名实体识别。
  4. """
  5. import uuid
  6. from typing import List, Optional
  7. from loguru import logger
  8. from ..models import EntityInfo, PositionInfo
# Maximum characters per segment (HanLP has input-length limits for long text)
MAX_SEGMENT_LENGTH = 500
# Entity types to filter out (these types are usually noise)
FILTER_TYPES = {
    'INTEGER', 'DECIMAL', 'FRACTION', 'ORDINAL', 'CARDINAL',
    'RATE', 'DURATION', 'NUMBER', 'POSTALCODE'
}
# Core entity types to keep
KEEP_TYPES = {'PERSON', 'ORG', 'LOC', 'DATE', 'TIME', 'MONEY', 'PERCENT'}
# Overly generic entity strings (blacklist)
BLACKLIST_ENTITIES = {
    '公司', '评审组', '部门', '单位', '组织', '机构', '中心', '委员会',
    '第一', '第二', '第三', '第四', '第五', '一级', '二级', '三级',
    '百千万', '十四五', '十三五'
}
  24. class NerService:
  25. """NER 服务 - 基于 HanLP"""
  26. def __init__(self):
  27. self._hanlp_ner = None
  28. self._hanlp_tokenizer = None
  29. logger.info("初始化 NER 服务: model=HanLP")
  30. def _load_model(self):
  31. """延迟加载HanLP模型"""
  32. if self._hanlp_ner is not None:
  33. return
  34. try:
  35. import hanlp
  36. logger.info("正在加载HanLP NER模型...")
  37. # 使用MTL多任务模型,更稳定
  38. self._hanlp_ner = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH)
  39. logger.info("HanLP NER模型加载完成")
  40. except ImportError:
  41. logger.error("HanLP未安装,请运行: pip install hanlp")
  42. raise
  43. except Exception as e:
  44. logger.error(f"HanLP模型加载失败: {e}")
  45. raise
  46. def _split_text(self, text: str) -> List[tuple]:
  47. """
  48. 将长文本分段,返回 [(segment, offset), ...]
  49. """
  50. segments = []
  51. lines = text.split('\n')
  52. current_segment = ""
  53. current_offset = 0
  54. segment_start = 0
  55. for line in lines:
  56. if len(current_segment) + len(line) + 1 > MAX_SEGMENT_LENGTH:
  57. if current_segment:
  58. segments.append((current_segment, segment_start))
  59. current_segment = line
  60. segment_start = current_offset
  61. else:
  62. if current_segment:
  63. current_segment += '\n' + line
  64. else:
  65. current_segment = line
  66. segment_start = current_offset
  67. current_offset += len(line) + 1
  68. if current_segment:
  69. segments.append((current_segment, segment_start))
  70. return segments
  71. async def extract_entities(
  72. self,
  73. text: str,
  74. entity_types: Optional[List[str]] = None
  75. ) -> List[EntityInfo]:
  76. """
  77. 从文本中提取实体
  78. """
  79. if not text or not text.strip():
  80. return []
  81. # 加载模型
  82. self._load_model()
  83. # HanLP实体类型映射
  84. type_mapping = {
  85. 'PERSON': 'PERSON', 'PER': 'PERSON', 'NR': 'PERSON',
  86. 'ORGANIZATION': 'ORG', 'ORG': 'ORG', 'NT': 'ORG',
  87. 'LOCATION': 'LOC', 'LOC': 'LOC', 'GPE': 'LOC', 'NS': 'LOC',
  88. 'DATE': 'DATE', 'TIME': 'DATE',
  89. 'MONEY': 'NUMBER', 'PERCENT': 'NUMBER', 'QUANTITY': 'NUMBER', 'CARDINAL': 'NUMBER',
  90. }
  91. entities = []
  92. seen_entities = set()
  93. # 分段处理
  94. segments = self._split_text(text)
  95. total_segments = len(segments)
  96. logger.info(f"开始NER提取: 文本长度={len(text)}, 分段数={total_segments}")
  97. for seg_idx, (segment, offset) in enumerate(segments):
  98. if seg_idx % 10 == 0:
  99. logger.info(f"NER进度: {seg_idx}/{total_segments} 段")
  100. try:
  101. # 调用HanLP MTL模型
  102. result = self._hanlp_ner(segment, tasks='ner')
  103. # MTL模型返回格式: {'ner/msra': [['实体', '类型', start, end], ...]}
  104. ner_results = []
  105. if isinstance(result, dict):
  106. for key in result:
  107. if 'ner' in key.lower():
  108. ner_results = result[key]
  109. break
  110. elif isinstance(result, list):
  111. ner_results = result
  112. # 处理结果
  113. for item in ner_results:
  114. entity_text = None
  115. entity_type = None
  116. char_start = 0
  117. char_end = 0
  118. if isinstance(item, (list, tuple)) and len(item) >= 2:
  119. entity_text = item[0]
  120. entity_type = item[1]
  121. if len(item) >= 4:
  122. char_start = item[2] + offset
  123. char_end = item[3] + offset
  124. else:
  125. pos = segment.find(str(entity_text))
  126. char_start = pos + offset if pos >= 0 else offset
  127. char_end = char_start + len(str(entity_text))
  128. elif isinstance(item, dict):
  129. entity_text = item.get('text', item.get('word', ''))
  130. entity_type = item.get('type', item.get('label', 'UNKNOWN'))
  131. char_start = item.get('start', 0) + offset
  132. char_end = item.get('end', char_start + len(entity_text))
  133. else:
  134. continue
  135. if not entity_text or not entity_type:
  136. continue
  137. entity_text = str(entity_text)
  138. entity_type = str(entity_type)
  139. # 映射实体类型
  140. mapped_type = type_mapping.get(entity_type.upper(), entity_type.upper())
  141. # 过滤噪音类型(数字、序号等)
  142. if mapped_type in FILTER_TYPES or entity_type.upper() in FILTER_TYPES:
  143. continue
  144. # 只保留核心类型
  145. if mapped_type not in KEEP_TYPES and entity_type.upper() not in KEEP_TYPES:
  146. continue
  147. # 过滤实体类型(用户指定)
  148. if entity_types and mapped_type not in entity_types:
  149. continue
  150. # 黑名单过滤
  151. if entity_text in BLACKLIST_ENTITIES:
  152. continue
  153. # 去重(忽略类型,只看文本)
  154. if entity_text in seen_entities:
  155. continue
  156. seen_entities.add(entity_text)
  157. # 跳过太短的实体
  158. if len(entity_text) < 2:
  159. continue
  160. # 跳过纯数字
  161. if entity_text.replace('.', '').replace('-', '').isdigit():
  162. continue
  163. # 计算行号
  164. line_num = text[:char_start].count('\n') + 1 if char_start > 0 else 1
  165. # 获取上下文
  166. context_start = max(0, char_start - 20)
  167. context_end = min(len(text), char_end + 20)
  168. context = text[context_start:context_end]
  169. entity = EntityInfo(
  170. name=entity_text,
  171. type=mapped_type,
  172. value=entity_text,
  173. position=PositionInfo(
  174. char_start=char_start,
  175. char_end=char_end,
  176. line=line_num
  177. ),
  178. context=context,
  179. confidence=0.9,
  180. temp_id=str(uuid.uuid4())[:8]
  181. )
  182. entities.append(entity)
  183. except Exception as e:
  184. logger.warning(f"分段 {seg_idx} NER失败: {e}")
  185. continue
  186. logger.info(f"HanLP NER 提取完成: entity_count={len(entities)}")
  187. return entities
# Module-level singleton instance
ner_service = NerService()