relation_service.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
"""
Relation extraction service implementation.

Supported modes:
1. rule - simple rule-based relation extraction (default)
2. api  - call an external API for relation extraction
"""
  7. import re
  8. from typing import List
  9. from loguru import logger
  10. from ..config import settings
  11. from ..models import EntityInfo, RelationInfo
  12. class RelationService:
  13. """关系抽取服务"""
  14. def __init__(self):
  15. self.model_type = settings.ner_model
  16. logger.info(f"初始化关系抽取服务: model_type={self.model_type}")
  17. async def extract_relations(
  18. self,
  19. text: str,
  20. entities: List[EntityInfo]
  21. ) -> List[RelationInfo]:
  22. """
  23. 从文本和实体中抽取关系
  24. Args:
  25. text: 原始文本
  26. entities: 已提取的实体列表
  27. Returns:
  28. 关系列表
  29. """
  30. if not text or not entities or len(entities) < 2:
  31. return []
  32. if self.model_type == "api":
  33. return await self._extract_by_api(text, entities)
  34. else:
  35. return await self._extract_by_rules(text, entities)
  36. async def _extract_by_rules(
  37. self,
  38. text: str,
  39. entities: List[EntityInfo]
  40. ) -> List[RelationInfo]:
  41. """
  42. 基于规则的关系抽取
  43. 规则策略:
  44. 1. 基于位置邻近性
  45. 2. 基于语义模式匹配
  46. """
  47. relations = []
  48. # 关系模式定义
  49. relation_patterns = {
  50. "负责": [r"负责", r"承担", r"主管", r"管理"],
  51. "属于": [r"属于", r"隶属", r"归属"],
  52. "位于": [r"位于", r"在", r"坐落于", r"地处"],
  53. "包含": [r"包含", r"包括", r"涵盖", r"含有"],
  54. "关联": [r"关联", r"相关", r"涉及", r"关于"],
  55. "生产": [r"生产", r"制造", r"制作"],
  56. "使用": [r"使用", r"采用", r"利用", r"应用"],
  57. "检测": [r"检测", r"检验", r"测试", r"测量"],
  58. "所有": [r"所有", r"拥有", r"持有"],
  59. "合作": [r"合作", r"协作", r"联合"],
  60. }
  61. # 按位置排序实体
  62. sorted_entities = sorted(
  63. [e for e in entities if e.position],
  64. key=lambda e: e.position.char_start if e.position else 0
  65. )
  66. seen_relations = set()
  67. # 检查相邻实体间的关系
  68. for i in range(len(sorted_entities) - 1):
  69. entity1 = sorted_entities[i]
  70. entity2 = sorted_entities[i + 1]
  71. if not entity1.position or not entity2.position:
  72. continue
  73. # 获取两个实体之间的文本
  74. start = entity1.position.char_end
  75. end = entity2.position.char_start
  76. if end <= start or end - start > 100: # 距离太远则跳过
  77. continue
  78. between_text = text[start:end]
  79. # 匹配关系模式
  80. for relation_type, patterns in relation_patterns.items():
  81. for pattern in patterns:
  82. if re.search(pattern, between_text):
  83. relation_key = f"{entity1.name}-{relation_type}-{entity2.name}"
  84. if relation_key not in seen_relations:
  85. seen_relations.add(relation_key)
  86. relations.append(RelationInfo(
  87. from_entity=entity1.name,
  88. from_entity_id=entity1.temp_id,
  89. to_entity=entity2.name,
  90. to_entity_id=entity2.temp_id,
  91. relation_type=relation_type,
  92. confidence=0.75
  93. ))
  94. break
  95. # 基于实体类型的隐含关系
  96. org_entities = [e for e in entities if e.type == "ORG"]
  97. person_entities = [e for e in entities if e.type == "PERSON"]
  98. loc_entities = [e for e in entities if e.type == "LOC"]
  99. project_entities = [e for e in entities if e.type == "PROJECT"]
  100. device_entities = [e for e in entities if e.type == "DEVICE"]
  101. # 人员-机构 关系
  102. for person in person_entities:
  103. for org in org_entities:
  104. relation_key = f"{person.name}-属于-{org.name}"
  105. if relation_key not in seen_relations:
  106. # 检查是否在同一句中
  107. if self._in_same_sentence(text, person, org):
  108. seen_relations.add(relation_key)
  109. relations.append(RelationInfo(
  110. from_entity=person.name,
  111. from_entity_id=person.temp_id,
  112. to_entity=org.name,
  113. to_entity_id=org.temp_id,
  114. relation_type="属于",
  115. confidence=0.6
  116. ))
  117. # 机构-地点 关系
  118. for org in org_entities:
  119. for loc in loc_entities:
  120. relation_key = f"{org.name}-位于-{loc.name}"
  121. if relation_key not in seen_relations:
  122. if self._in_same_sentence(text, org, loc):
  123. seen_relations.add(relation_key)
  124. relations.append(RelationInfo(
  125. from_entity=org.name,
  126. from_entity_id=org.temp_id,
  127. to_entity=loc.name,
  128. to_entity_id=loc.temp_id,
  129. relation_type="位于",
  130. confidence=0.6
  131. ))
  132. # 机构-项目 关系
  133. for org in org_entities:
  134. for project in project_entities:
  135. relation_key = f"{org.name}-负责-{project.name}"
  136. if relation_key not in seen_relations:
  137. if self._in_same_sentence(text, org, project):
  138. seen_relations.add(relation_key)
  139. relations.append(RelationInfo(
  140. from_entity=org.name,
  141. from_entity_id=org.temp_id,
  142. to_entity=project.name,
  143. to_entity_id=project.temp_id,
  144. relation_type="负责",
  145. confidence=0.6
  146. ))
  147. # 项目-设备 关系
  148. for project in project_entities:
  149. for device in device_entities:
  150. relation_key = f"{project.name}-使用-{device.name}"
  151. if relation_key not in seen_relations:
  152. if self._in_same_sentence(text, project, device):
  153. seen_relations.add(relation_key)
  154. relations.append(RelationInfo(
  155. from_entity=project.name,
  156. from_entity_id=project.temp_id,
  157. to_entity=device.name,
  158. to_entity_id=device.temp_id,
  159. relation_type="使用",
  160. confidence=0.6
  161. ))
  162. logger.info(f"规则关系抽取完成: relation_count={len(relations)}")
  163. return relations
  164. def _in_same_sentence(self, text: str, entity1: EntityInfo, entity2: EntityInfo) -> bool:
  165. """判断两个实体是否在同一句中"""
  166. if not entity1.position or not entity2.position:
  167. return False
  168. # 获取两个实体的位置范围
  169. start = min(entity1.position.char_start, entity2.position.char_start)
  170. end = max(entity1.position.char_end, entity2.position.char_end)
  171. # 检查范围内是否有句号等标点
  172. between_text = text[start:end]
  173. sentence_enders = ["。", "!", "?", ".", "!", "?", "\n\n"]
  174. for ender in sentence_enders:
  175. if ender in between_text:
  176. return False
  177. # 距离限制
  178. if end - start > 200:
  179. return False
  180. return True
  181. async def _extract_by_api(
  182. self,
  183. text: str,
  184. entities: List[EntityInfo]
  185. ) -> List[RelationInfo]:
  186. """
  187. 调用外部 API 进行关系抽取
  188. """
  189. # TODO: 实现 API 关系抽取
  190. logger.warning("API 关系抽取尚未实现,回退到规则模式")
  191. return await self._extract_by_rules(text, entities)
# Create the module-level singleton instance
relation_service = RelationService()