"""
Relation extraction service.

Supported modes:
1. rule - simple rule-based relation extraction (default)
2. api  - relation extraction through an external API (not implemented yet;
          currently falls back to rule mode)
"""
import re
from typing import ClassVar, Dict, List, Tuple

from loguru import logger

from ..config import settings
from ..models import EntityInfo, RelationInfo
class RelationService:
    """Relation extraction service.

    Builds ``RelationInfo`` records between already-extracted entities.
    The mode is read from ``settings.ner_model`` at construction time:
    ``"api"`` routes to the external-API extractor (not implemented yet, it
    falls back to the rules); any other value uses the rule-based extractor.
    """

    # Relation type -> trigger patterns (regular expressions) searched for in
    # the text lying between two adjacent entities.
    _RELATION_PATTERNS: ClassVar[Dict[str, Tuple[str, ...]]] = {
        "负责": (r"负责", r"承担", r"主管", r"管理"),
        "属于": (r"属于", r"隶属", r"归属"),
        "位于": (r"位于", r"在", r"坐落于", r"地处"),
        "包含": (r"包含", r"包括", r"涵盖", r"含有"),
        "关联": (r"关联", r"相关", r"涉及", r"关于"),
        "生产": (r"生产", r"制造", r"制作"),
        "使用": (r"使用", r"采用", r"利用", r"应用"),
        "检测": (r"检测", r"检验", r"测试", r"测量"),
        "所有": (r"所有", r"拥有", r"持有"),
        "合作": (r"合作", r"协作", r"联合"),
    }

    # (source entity type, relation type, target entity type): relations that
    # are assumed whenever both entities occur in the same sentence.
    _IMPLICIT_RULES: ClassVar[Tuple[Tuple[str, str, str], ...]] = (
        ("PERSON", "属于", "ORG"),
        ("ORG", "位于", "LOC"),
        ("ORG", "负责", "PROJECT"),
        ("PROJECT", "使用", "DEVICE"),
    )

    # Markers treated as sentence boundaries. The full-width "!" / "?" are
    # written as escapes so they cannot be silently folded to ASCII.
    _SENTENCE_ENDERS: ClassVar[Tuple[str, ...]] = (
        "。",
        "\uff01",  # full-width exclamation mark
        "\uff1f",  # full-width question mark
        ".",
        "!",
        "?",
        "\n\n",
    )

    # Maximum number of characters allowed between two adjacent entities.
    _MAX_ADJACENT_GAP: ClassVar[int] = 100
    # Maximum span (first start .. last end) for the same-sentence check.
    _MAX_SENTENCE_SPAN: ClassVar[int] = 200
    # Confidence for relations found through a trigger pattern.
    _PATTERN_CONFIDENCE: ClassVar[float] = 0.75
    # Confidence for relations implied only by entity types.
    _IMPLICIT_CONFIDENCE: ClassVar[float] = 0.6

    def __init__(self) -> None:
        """Read the extraction mode from ``settings.ner_model`` and log it."""
        self.model_type = settings.ner_model
        logger.info(f"初始化关系抽取服务: model_type={self.model_type}")

    async def extract_relations(
        self,
        text: str,
        entities: "List[EntityInfo]",
    ) -> "List[RelationInfo]":
        """Extract relations from the text and its entities.

        Args:
            text: The original text.
            entities: Entities already extracted from ``text``.

        Returns:
            The extracted relations; an empty list when ``text`` is empty or
            fewer than two entities are supplied.
        """
        # A relation needs two endpoints, so there is nothing to do otherwise.
        if not text or not entities or len(entities) < 2:
            return []

        if self.model_type == "api":
            return await self._extract_by_api(text, entities)
        return await self._extract_by_rules(text, entities)

    async def _extract_by_rules(
        self,
        text: str,
        entities: "List[EntityInfo]",
    ) -> "List[RelationInfo]":
        """Rule-based relation extraction.

        Two strategies are applied in order and share one de-duplication set:

        1. Positional proximity: trigger patterns found between entities
           that are adjacent in the text.
        2. Type-implied relations between entities in the same sentence.

        Args:
            text: The original text.
            entities: Entities already extracted from ``text``.

        Returns:
            The relations found, pattern-based ones first.
        """
        relations: list = []
        seen: set = set()

        self._collect_adjacent_relations(text, entities, relations, seen)
        self._collect_implicit_relations(text, entities, relations, seen)

        logger.info(f"规则关系抽取完成: relation_count={len(relations)}")
        return relations

    @staticmethod
    def _add_relation(
        relations: list,
        seen: set,
        source: "EntityInfo",
        target: "EntityInfo",
        relation_type: str,
        confidence: float,
    ) -> None:
        """Append a relation unless the same (source, type, target) exists."""
        # A tuple key cannot collide the way a "-"-joined string can when an
        # entity name itself contains "-".
        key = (source.name, relation_type, target.name)
        if key in seen:
            return
        seen.add(key)
        relations.append(RelationInfo(
            from_entity=source.name,
            from_entity_id=source.temp_id,
            to_entity=target.name,
            to_entity_id=target.temp_id,
            relation_type=relation_type,
            confidence=confidence,
        ))

    def _collect_adjacent_relations(
        self,
        text: str,
        entities: "List[EntityInfo]",
        relations: list,
        seen: set,
    ) -> None:
        """Add relations whose trigger pattern lies between adjacent entities."""
        # Only entities with a known position can be ordered in the text.
        positioned = sorted(
            (e for e in entities if e.position),
            key=lambda e: e.position.char_start,
        )

        for left, right in zip(positioned, positioned[1:]):
            gap_start = left.position.char_end
            gap_end = right.position.char_start
            gap_length = gap_end - gap_start

            # Skip overlapping/touching mentions and pairs too far apart.
            if gap_length <= 0 or gap_length > self._MAX_ADJACENT_GAP:
                continue

            gap_text = text[gap_start:gap_end]
            for relation_type, patterns in self._RELATION_PATTERNS.items():
                # One hit per relation type is enough for a pair.
                if any(re.search(pattern, gap_text) for pattern in patterns):
                    self._add_relation(
                        relations, seen, left, right,
                        relation_type, self._PATTERN_CONFIDENCE,
                    )

    def _collect_implicit_relations(
        self,
        text: str,
        entities: "List[EntityInfo]",
        relations: list,
        seen: set,
    ) -> None:
        """Add type-implied relations for entity pairs in the same sentence."""
        for source_type, relation_type, target_type in self._IMPLICIT_RULES:
            sources = [e for e in entities if e.type == source_type]
            targets = [e for e in entities if e.type == target_type]
            for source in sources:
                for target in targets:
                    if self._in_same_sentence(text, source, target):
                        self._add_relation(
                            relations, seen, source, target,
                            relation_type, self._IMPLICIT_CONFIDENCE,
                        )

    def _in_same_sentence(
        self,
        text: str,
        entity1: "EntityInfo",
        entity2: "EntityInfo",
    ) -> bool:
        """Return whether two entities appear in the same sentence.

        The entities count as being in the same sentence when both have a
        position, the span from the earlier start to the later end is at most
        ``_MAX_SENTENCE_SPAN`` characters, and no sentence-ending marker
        occurs in the text between the two mentions.

        Args:
            text: The original text.
            entity1: First entity.
            entity2: Second entity (argument order does not matter).

        Returns:
            True when the entities are judged to share a sentence.
        """
        if not entity1.position or not entity2.position:
            return False

        first, second = sorted(
            (entity1.position, entity2.position),
            key=lambda position: position.char_start,
        )

        # Distance limit over the whole span covered by both mentions.
        span_end = max(first.char_end, second.char_end)
        if span_end - first.char_start > self._MAX_SENTENCE_SPAN:
            return False

        # Inspect only the gap between the mentions: punctuation that is part
        # of an entity name (e.g. "A.I.") is not a sentence boundary. The
        # slice is empty when the mentions touch or overlap.
        gap_text = text[first.char_end:second.char_start]
        return not any(ender in gap_text for ender in self._SENTENCE_ENDERS)

    async def _extract_by_api(
        self,
        text: str,
        entities: "List[EntityInfo]",
    ) -> "List[RelationInfo]":
        """Extract relations through an external API.

        Not implemented yet: logs a warning and delegates to the rule-based
        extractor.
        """
        # TODO: implement API-based relation extraction
        logger.warning("API 关系抽取尚未实现,回退到规则模式")
        return await self._extract_by_rules(text, entities)
# Create the module-level singleton instance.
relation_service = RelationService()