"""
Ollama LLM 服务
用于调用本地 Ollama 模型进行 NER 提取
"""
import json
import re
import uuid
import httpx
from typing import List, Optional, Dict, Any
from loguru import logger
from ..config import settings
from ..models import EntityInfo, PositionInfo
class OllamaService:
    """Ollama LLM service used for named-entity recognition (NER).

    Two extraction strategies are supported and chosen from the configured
    model name:

    * general instruction-following LLMs (e.g. Qwen), prompted to answer
      with a JSON object of entities;
    * UniversalNER, queried once per entity type and answering with a JSON
      array of entity names.

    Long texts are split into overlapping chunks; entity positions are
    reported as character offsets into the full original text.
    """

    # Preferred cut points when splitting long text, tried in this order.
    # Full-width CJK punctuation sits next to its ASCII counterpart so that
    # Chinese sentences are not cut in the middle.
    _SPLIT_SEPARATORS = ('\n\n', '\n', '。', '；', ';', '！', '!', '？', '?', '.')

    # Internal entity type -> UniversalNER English type names.
    # Only the first name of each list is currently used in prompts.
    _UNIVERSAL_NER_TYPES: Dict[str, List[str]] = {
        "PERSON": ["person", "people", "human"],
        "ORG": ["organization", "company", "institution"],
        "LOC": ["location", "place", "address"],
        "DATE": ["date", "time"],
        "NUMBER": ["number", "quantity", "measurement"],
        "DEVICE": ["device", "equipment", "instrument"],
        "PROJECT": ["project", "program"],
        "METHOD": ["method", "standard", "specification"],
    }

    # Reasoning block emitted by "thinking" models such as Qwen3. An
    # unterminated block (e.g. output cut off by num_predict) is matched up
    # to the end of the response.
    _THINK_BLOCK_RE = re.compile(r'<think>[\s\S]*?(?:</think>|\Z)')

    def __init__(self):
        self.base_url = settings.ollama_url
        self.model = settings.ollama_model
        self.timeout = settings.ollama_timeout
        self.chunk_size = settings.chunk_size
        self.chunk_overlap = settings.chunk_overlap
        # UniversalNER uses a different prompt/response protocol; it is
        # detected from the model name.
        self.is_universal_ner = "universal-ner" in self.model.lower()
        logger.info(f"初始化 Ollama 服务: url={self.base_url}, model={self.model}, universal_ner={self.is_universal_ner}")

    def _split_text(self, text: str) -> List[Dict[str, Any]]:
        """Split a long text into (possibly overlapping) chunks.

        Each chunk is at most ``chunk_size`` characters long. When the text
        continues past the current window, the cut is moved back to the
        latest separator from ``_SPLIT_SEPARATORS`` found in the second half
        of the window, to avoid cutting sentences. Consecutive chunks overlap
        by ``chunk_overlap`` characters, except when that overlap would keep
        the window from advancing; then the next chunk starts right at the
        previous end (this prevents an infinite loop).

        Args:
            text: Original text.

        Returns:
            List of dicts with keys ``text``, ``start_pos`` and ``end_pos``
            (end-exclusive offsets into ``text``).
        """
        text_len = len(text)
        # Guard against misconfiguration: a non-positive size never advances.
        size = max(1, self.chunk_size)
        overlap = max(0, self.chunk_overlap)

        if text_len <= size:
            return [{"text": text, "start_pos": 0, "end_pos": text_len}]

        chunks = []
        start = 0
        while start < text_len:
            end = min(start + size, text_len)
            if end < text_len:
                # Prefer ending on a separator in the second half of the window.
                for sep in self._SPLIT_SEPARATORS:
                    sep_pos = text.rfind(sep, start + size // 2, end)
                    if sep_pos > start:
                        end = sep_pos + len(sep)
                        break

            chunks.append({
                "text": text[start:end],
                "start_pos": start,
                "end_pos": end
            })

            if end >= text_len:
                break
            # Step back by the overlap, but always make forward progress.
            next_start = end - overlap
            start = next_start if next_start > start else end

        logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
        return chunks

    def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
        """Build the NER prompt for a general LLM.

        Args:
            text: Text (chunk) to extract entities from.
            entity_types: Entity type labels to request; defaults to
                ``settings.entity_types``.

        Returns:
            The prompt. It ends with an opened ```json fence, nudging the
            model to answer with the JSON payload directly.
        """
        types = entity_types or settings.entity_types
        types_desc = ", ".join(types)
        # A concrete example helps the model follow the output format.
        example = '{"entities": [{"name": "成都市", "type": "LOC", "charStart": 10, "charEnd": 13}, {"name": "2024年5月", "type": "DATE", "charStart": 0, "charEnd": 7}]}'
        prompt = f"""从文本中提取命名实体,只输出JSON。
实体类型: {types_desc}
输出格式示例:
{example}
文本内容:
{text}
JSON结果:
```json"""
        return prompt

    async def _call_ollama(self, prompt: str) -> Optional[str]:
        """Send a non-streaming generate request to the Ollama API.

        Args:
            prompt: Full prompt text.

        Returns:
            The generated text (``""`` if the reply has no ``response``
            field), or ``None`` on timeout or any other error (logged).
        """
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.1,  # low temperature: more deterministic output
                "num_predict": 4096,  # maximum number of generated tokens
            }
        }
        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(url, json=payload)
                response.raise_for_status()
                result = response.json()
                return result.get("response", "")
        except httpx.TimeoutException:
            logger.error(f"Ollama 请求超时: timeout={self.timeout}s")
            return None
        except Exception as e:
            logger.error(f"Ollama 请求失败: {e}")
            return None

    @staticmethod
    def _to_int(value: Any, default: int = 0) -> int:
        """Coerce an LLM-supplied offset (int, float or numeric string) to int.

        Returns ``default`` when the value is missing or not numeric.
        """
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _resolve_span(name: str, start: int, end: int, chunk_text: Optional[str]) -> "tuple[int, int]":
        """Return the most plausible (start, end) span of ``name`` in ``chunk_text``.

        LLM-reported character offsets are often inaccurate. If ``chunk_text``
        is given and ``chunk_text[start:end]`` is not exactly ``name``, the
        occurrence of ``name`` closest to the reported ``start`` is used.
        When no chunk text is given, or the name does not occur in it, the
        reported offsets are returned unchanged.
        """
        if chunk_text is None:
            return start, end
        if start >= 0 and chunk_text[start:end] == name:
            return start, end

        best = -1
        pos = chunk_text.find(name)
        while pos >= 0:
            if best < 0 or abs(pos - start) < abs(best - start):
                best = pos
            pos = chunk_text.find(name, pos + 1)

        if best < 0:
            return start, end
        return best, best + len(name)

    def _parse_llm_response(
        self,
        response: str,
        chunk_start_pos: int = 0,
        chunk_text: Optional[str] = None,
    ) -> List[EntityInfo]:
        """Parse the JSON answer of a general LLM into entities.

        Args:
            response: Raw text returned by the LLM.
            chunk_start_pos: Offset of the chunk within the original text;
                added to every position so offsets refer to the full text.
            chunk_text: Optional chunk text. When given, reported offsets
                that do not match the entity name are corrected by locating
                the name in the chunk.

        Returns:
            Parsed entities. Names shorter than 2 characters and malformed
            items are skipped individually; an unparseable response yields
            an empty list (the problem is logged).
        """
        entities = []
        try:
            # Thinking models (e.g. Qwen3) may prepend a <think>...</think>
            # block, which can itself contain braces; remove it first.
            response = self._THINK_BLOCK_RE.sub('', response)
            if '</think>' in response:
                # A closing tag without an opening one (the opening tag may
                # have come from the prompt template): keep what follows it.
                response = response.rsplit('</think>', 1)[1]
            # Strip markdown code fences.
            response = re.sub(r'```json\s*', '', response)
            response = re.sub(r'```\s*$', '', response)
            # Take the outermost {...} span as the JSON payload.
            json_match = re.search(r'\{[\s\S]*\}', response)
            if not json_match:
                logger.warning(f"LLM 响应中未找到 JSON, response={response[:300]}...")
                return entities

            json_str = json_match.group()
            data = json.loads(json_str)
            entity_list = data.get("entities") or []
            if not isinstance(entity_list, list):
                logger.warning(f"LLM 响应 entities 字段不是数组: {type(entity_list).__name__}")
                return entities

            for item in entity_list:
                if not isinstance(item, dict):
                    continue
                # `or ""` also covers explicit nulls; str() tolerates non-string values.
                name = str(item.get("name") or "").strip()
                entity_type = str(item.get("type") or "").upper()
                if len(name) < 2:
                    continue

                char_start = self._to_int(item.get("charStart"))
                char_end = self._to_int(item.get("charEnd"))
                char_start, char_end = self._resolve_span(name, char_start, char_end, chunk_text)

                entity = EntityInfo(
                    name=name,
                    type=entity_type,
                    value=name,
                    position=PositionInfo(
                        # Shift chunk-relative offsets to full-text offsets.
                        char_start=char_start + chunk_start_pos,
                        char_end=char_end + chunk_start_pos,
                        line=1  # line numbers are not computed in LLM mode
                    ),
                    confidence=0.9,  # fixed default confidence for LLM mode
                    temp_id=str(uuid.uuid4())[:8]
                )
                entities.append(entity)
        except json.JSONDecodeError as e:
            logger.warning(f"JSON 解析失败: {e}, response={response[:200]}...")
        except Exception as e:
            logger.error(f"解析 LLM 响应失败: {e}")
        return entities

    async def extract_entities(
        self,
        text: str,
        entity_types: Optional[List[str]] = None
    ) -> List[EntityInfo]:
        """Extract entities from ``text`` with the configured Ollama model.

        Long texts are chunked automatically. UniversalNER models are
        detected from the model name and use their own extraction strategy.

        Args:
            text: Text to analyse; empty or whitespace-only text yields ``[]``.
            entity_types: Entity type labels to extract (strategy default
                when ``None``).

        Returns:
            Extracted entities, de-duplicated by (type, name).
        """
        if not text or not text.strip():
            return []
        # Pick the extraction strategy for the model type.
        if self.is_universal_ner:
            return await self._extract_with_universal_ner(text, entity_types)
        else:
            return await self._extract_with_general_llm(text, entity_types)

    async def _extract_with_general_llm(
        self,
        text: str,
        entity_types: Optional[List[str]] = None
    ) -> List[EntityInfo]:
        """Extract entities with a general LLM (e.g. Qwen), chunk by chunk.

        Entities are de-duplicated by (type, name); the first occurrence
        (earliest chunk) is kept.
        """
        chunks = self._split_text(text)
        all_entities = []
        seen_entities = set()  # "TYPE:name" keys already emitted
        for i, chunk in enumerate(chunks):
            logger.info(f"处理分块 {i+1}/{len(chunks)}: 长度={len(chunk['text'])}")
            prompt = self._build_ner_prompt(chunk["text"], entity_types)
            response = await self._call_ollama(prompt)
            if not response:
                logger.warning(f"分块 {i+1} Ollama 返回为空")
                continue
            # Log the first 500 characters for debugging.
            logger.debug(f"分块 {i+1} LLM 响应: {response[:500]}...")
            # Pass the chunk text so unreliable LLM offsets can be corrected.
            entities = self._parse_llm_response(response, chunk["start_pos"], chunk["text"])
            for entity in entities:
                entity_key = f"{entity.type}:{entity.name}"
                if entity_key not in seen_entities:
                    seen_entities.add(entity_key)
                    all_entities.append(entity)
            logger.info(f"分块 {i+1} 提取实体: {len(entities)} 个")
        logger.info(f"通用 LLM NER 提取完成: 总实体数={len(all_entities)}")
        return all_entities

    async def _extract_with_universal_ner(
        self,
        text: str,
        entity_types: Optional[List[str]] = None
    ) -> List[EntityInfo]:
        """Extract entities with a UniversalNER model.

        One request is made per (chunk, entity type). The prompt is
        ``"<chunk text> <english type name>"`` and the model is expected to
        answer with a JSON array of names, e.g. ``["name1", "name2"]``.

        Requested types are matched case-insensitively against
        ``_UNIVERSAL_NER_TYPES``; unsupported types are skipped with a warning.
        Entities are de-duplicated by (type, name).
        """
        # Normalize requested types, dropping duplicates and unsupported ones.
        requested = entity_types or list(self._UNIVERSAL_NER_TYPES)
        types_to_extract: List[str] = []
        for raw_type in requested:
            key = str(raw_type).upper()
            if key not in self._UNIVERSAL_NER_TYPES:
                logger.warning(f"UniversalNER 不支持的实体类型: {raw_type}")
                continue
            if key not in types_to_extract:
                types_to_extract.append(key)

        chunks = self._split_text(text)
        all_entities = []
        seen_entities = set()  # "TYPE:name" keys already emitted
        for i, chunk in enumerate(chunks):
            chunk_text = chunk["text"]
            chunk_start = chunk["start_pos"]
            logger.info(f"UniversalNER 处理分块 {i+1}/{len(chunks)}: 长度={len(chunk_text)}")
            new_in_chunk = 0
            for entity_type in types_to_extract:
                # Use the first (primary) English type name.
                english_type = self._UNIVERSAL_NER_TYPES[entity_type][0]
                prompt = f"{chunk_text} {english_type}"
                response = await self._call_ollama(prompt)
                if not response:
                    continue
                entities = self._parse_universal_ner_response(
                    response, entity_type, chunk_text, chunk_start
                )
                for entity in entities:
                    entity_key = f"{entity.type}:{entity.name}"
                    if entity_key not in seen_entities:
                        seen_entities.add(entity_key)
                        all_entities.append(entity)
                        new_in_chunk += 1
            # Number of entities first found in this chunk.
            logger.info(f"分块 {i+1} UniversalNER 提取实体: {new_in_chunk} 个")
        logger.info(f"UniversalNER 提取完成: 总实体数={len(all_entities)}")
        return all_entities

    def _parse_universal_ner_response(
        self,
        response: str,
        entity_type: str,
        original_text: str,
        chunk_start_pos: int = 0
    ) -> List[EntityInfo]:
        """Parse a UniversalNER answer (a JSON array of names) into entities.

        Args:
            response: Raw model output, expected to contain ``["a", "b", ...]``.
            entity_type: Type label assigned to every returned entity.
            original_text: Chunk text the names are located in.
            chunk_start_pos: Offset of the chunk within the full text.

        Returns:
            Entities for string names of at least 2 characters (after
            stripping). Each position is the first occurrence of the name in
            the chunk, shifted to full-text offsets, or (0, 0) if the name
            does not occur verbatim.
        """
        entities = []
        try:
            response = response.strip()
            # Find the first JSON array in the output.
            json_match = re.search(r'\[[\s\S]*?\]', response)
            if not json_match:
                logger.debug(f"UniversalNER 响应中未找到数组: {response[:100]}")
                return entities
            json_str = json_match.group()
            entity_names = json.loads(json_str)
            if not isinstance(entity_names, list):
                return entities
            for raw_name in entity_names:
                if not isinstance(raw_name, str):
                    continue
                # Strip before the length check so padded 1-char names are dropped.
                name = raw_name.strip()
                if len(name) < 2:
                    continue
                pos = original_text.find(name)
                if pos >= 0:
                    char_start = pos + chunk_start_pos
                    char_end = char_start + len(name)
                else:
                    # Not found verbatim in the chunk: position unknown.
                    char_start = 0
                    char_end = 0
                entity = EntityInfo(
                    name=name,
                    type=entity_type,
                    value=name,
                    position=PositionInfo(
                        char_start=char_start,
                        char_end=char_end,
                        line=1
                    ),
                    confidence=0.85,  # fixed default confidence for UniversalNER
                    temp_id=str(uuid.uuid4())[:8]
                )
                entities.append(entity)
        except json.JSONDecodeError as e:
            logger.debug(f"UniversalNER JSON 解析失败: {e}, response={response[:100]}")
        except Exception as e:
            logger.error(f"解析 UniversalNER 响应失败: {e}")
        return entities

    async def check_health(self) -> bool:
        """Return True if ``GET {base_url}/api/tags`` answers HTTP 200 within 5 s.

        Any exception (connection error, timeout, ...) yields False.
        """
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                response = await client.get(f"{self.base_url}/api/tags")
                return response.status_code == 200
        except Exception:
            return False
# Module-level shared instance (singleton), created when this module is
# imported: construction reads the Ollama settings and logs the init message.
ollama_service = OllamaService()