ollama_service.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. """
  2. Ollama LLM 服务
  3. 用于调用本地 Ollama 模型进行 NER 提取
  4. """
  5. import json
  6. import re
  7. import uuid
  8. import httpx
  9. from typing import List, Optional, Dict, Any
  10. from loguru import logger
  11. from ..config import settings
  12. from ..models import EntityInfo, PositionInfo
  13. class OllamaService:
  14. """Ollama LLM 服务"""
  15. def __init__(self):
  16. self.base_url = settings.ollama_url
  17. self.model = settings.ollama_model
  18. self.timeout = settings.ollama_timeout
  19. self.chunk_size = settings.chunk_size
  20. self.chunk_overlap = settings.chunk_overlap
  21. # 检测是否使用 UniversalNER
  22. self.is_universal_ner = "universal-ner" in self.model.lower()
  23. logger.info(f"初始化 Ollama 服务: url={self.base_url}, model={self.model}, universal_ner={self.is_universal_ner}")
  24. def _split_text(self, text: str) -> List[Dict[str, Any]]:
  25. """
  26. 将长文本分割成多个块
  27. Args:
  28. text: 原始文本
  29. Returns:
  30. 分块列表,每个块包含 text, start_pos, end_pos
  31. """
  32. if len(text) <= self.chunk_size:
  33. return [{"text": text, "start_pos": 0, "end_pos": len(text)}]
  34. chunks = []
  35. start = 0
  36. while start < len(text):
  37. end = min(start + self.chunk_size, len(text))
  38. # 尝试在句号、换行处分割,避免截断句子
  39. if end < len(text):
  40. # 向前查找最近的分隔符
  41. for sep in ['\n\n', '\n', '。', ';', '!', '?', '.']:
  42. sep_pos = text.rfind(sep, start + self.chunk_size // 2, end)
  43. if sep_pos > start:
  44. end = sep_pos + len(sep)
  45. break
  46. chunk_text = text[start:end]
  47. chunks.append({
  48. "text": chunk_text,
  49. "start_pos": start,
  50. "end_pos": end
  51. })
  52. # 下一个块的起始位置(考虑重叠)
  53. start = end - self.chunk_overlap if end < len(text) else end
  54. logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
  55. return chunks
  56. def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
  57. """
  58. 构建 NER 提取的 Prompt
  59. """
  60. types = entity_types or settings.entity_types
  61. types_desc = ", ".join(types)
  62. # 示例帮助模型理解格式
  63. example = '{"entities": [{"name": "成都市", "type": "LOC", "charStart": 10, "charEnd": 13}, {"name": "2024年5月", "type": "DATE", "charStart": 0, "charEnd": 7}]}'
  64. # /no_think 指令用于禁用 Qwen3 的思考模式
  65. prompt = f"""/no_think
  66. 你是一个命名实体识别(NER)专家。请从以下文本中提取命名实体。
  67. 【任务要求】
  68. 1. 只输出JSON格式,不要输出任何解释或思考过程
  69. 2. 实体类型: {types_desc}
  70. 3. charStart和charEnd是实体在文本中的字符位置索引(从0开始)
  71. 【输出格式】
  72. {example}
  73. 【待处理文本】
  74. {text}
  75. 【JSON输出】"""
  76. return prompt
  77. async def _call_ollama(self, prompt: str, disable_thinking: bool = True) -> Optional[str]:
  78. """
  79. 调用 Ollama API
  80. Args:
  81. prompt: 输入提示词
  82. disable_thinking: 是否禁用思考模式(适用于 Qwen3 等支持思考的模型)
  83. """
  84. url = f"{self.base_url}/api/generate"
  85. payload = {
  86. "model": self.model,
  87. "prompt": prompt,
  88. "stream": False,
  89. "options": {
  90. "temperature": 0.1, # 低温度,更确定性的输出
  91. "num_predict": 20480, # 最大输出 token
  92. }
  93. }
  94. # Qwen3 思考模式:禁用思考,直接输出 JSON 结果
  95. # 思考模式会导致 token 用于推理过程,无法输出最终结果
  96. payload["think"] = False
  97. try:
  98. async with httpx.AsyncClient(timeout=self.timeout) as client:
  99. response = await client.post(url, json=payload)
  100. response.raise_for_status()
  101. result = response.json()
  102. return result.get("response", "")
  103. except httpx.TimeoutException:
  104. logger.error(f"Ollama 请求超时: timeout={self.timeout}s")
  105. return None
  106. except Exception as e:
  107. logger.error(f"Ollama 请求失败: {e}")
  108. return None
  109. def _parse_llm_response(self, response: str, chunk_start_pos: int = 0) -> List[EntityInfo]:
  110. """
  111. 解析 LLM 返回的 JSON 结果
  112. Args:
  113. response: LLM 返回的文本
  114. chunk_start_pos: 当前分块在原文中的起始位置(用于位置校正)
  115. """
  116. entities = []
  117. try:
  118. # Qwen3 思考模式处理:提取 </think> 之后的内容
  119. think_end = response.find('</think>')
  120. if think_end != -1:
  121. # 只保留思考结束后的内容
  122. response = response[think_end + len('</think>'):]
  123. logger.debug(f"提取思考后内容: {response[:200]}...")
  124. else:
  125. # 检查是否存在 <think> 但没有 </think>(思考未完成或被截断)
  126. think_start = response.find('<think>')
  127. if think_start != -1:
  128. # 尝试从 <think> 之前的内容或整个响应中查找 JSON
  129. # 有些情况下 JSON 可能在思考标签之前
  130. pre_think = response[:think_start].strip()
  131. if pre_think:
  132. response = pre_think
  133. logger.debug(f"使用思考前内容: {response[:200]}...")
  134. else:
  135. # 思考内容中可能包含 JSON,尝试直接从响应中提取
  136. logger.debug("检测到不完整的思考模式,尝试直接提取JSON")
  137. # 移除 markdown code block 标记
  138. response = re.sub(r'```json\s*', '', response)
  139. response = re.sub(r'```\s*', '', response)
  140. response = response.strip()
  141. # 方法1:直接尝试解析整个响应(如果是纯 JSON)
  142. data = None
  143. try:
  144. data = json.loads(response)
  145. except json.JSONDecodeError:
  146. pass
  147. # 方法2:查找包含 entities 的 JSON 对象(使用更宽松的匹配)
  148. if not data or "entities" not in data:
  149. # 匹配 {"entities": [...]} 格式,使用贪婪匹配以捕获完整的嵌套结构
  150. # 先尝试找到所有可能的 JSON 对象
  151. json_matches = re.findall(r'\{[^{}]*"entities"\s*:\s*\[[^\]]*\][^{}]*\}', response)
  152. for json_str in json_matches:
  153. try:
  154. data = json.loads(json_str)
  155. if "entities" in data:
  156. break
  157. except json.JSONDecodeError:
  158. continue
  159. # 方法3:尝试更宽松的正则匹配(处理多行和嵌套)
  160. if not data or "entities" not in data:
  161. # 匹配从 {"entities" 开始到最后一个 ]} 的内容
  162. json_match = re.search(r'\{\s*"entities"\s*:\s*\[[\s\S]*\]\s*\}', response)
  163. if json_match:
  164. try:
  165. data = json.loads(json_match.group())
  166. except json.JSONDecodeError:
  167. pass
  168. if not data or "entities" not in data:
  169. logger.warning(f"未找到有效的 entities JSON, response={response[:300]}...")
  170. return entities
  171. entity_list = data.get("entities", [])
  172. for item in entity_list:
  173. name = item.get("name", "").strip()
  174. entity_type = item.get("type", "").upper()
  175. char_start = item.get("charStart", 0)
  176. char_end = item.get("charEnd", 0)
  177. if not name or len(name) < 2:
  178. continue
  179. # 校正位置(加上分块的起始位置)
  180. adjusted_start = char_start + chunk_start_pos
  181. adjusted_end = char_end + chunk_start_pos
  182. entity = EntityInfo(
  183. name=name,
  184. type=entity_type,
  185. value=name,
  186. position=PositionInfo(
  187. char_start=adjusted_start,
  188. char_end=adjusted_end,
  189. line=1 # LLM 模式不计算行号
  190. ),
  191. confidence=0.9, # LLM 模式默认较高置信度
  192. temp_id=str(uuid.uuid4())[:8]
  193. )
  194. entities.append(entity)
  195. except json.JSONDecodeError as e:
  196. logger.warning(f"JSON 解析失败: {e}, response={response[:200]}...")
  197. except Exception as e:
  198. logger.error(f"解析 LLM 响应失败: {e}")
  199. return entities
  200. async def extract_entities(
  201. self,
  202. text: str,
  203. entity_types: Optional[List[str]] = None
  204. ) -> List[EntityInfo]:
  205. """
  206. 使用 Ollama LLM 提取实体
  207. 支持长文本自动分块处理
  208. 自动检测是否使用 UniversalNER 并切换提取策略
  209. """
  210. if not text or not text.strip():
  211. return []
  212. # 根据模型类型选择提取策略
  213. if self.is_universal_ner:
  214. return await self._extract_with_universal_ner(text, entity_types)
  215. else:
  216. return await self._extract_with_general_llm(text, entity_types)
  217. async def _extract_with_general_llm(
  218. self,
  219. text: str,
  220. entity_types: Optional[List[str]] = None
  221. ) -> List[EntityInfo]:
  222. """
  223. 使用通用 LLM(如 Qwen)提取实体
  224. """
  225. # 分割长文本
  226. chunks = self._split_text(text)
  227. all_entities = []
  228. seen_entities = set() # 用于去重
  229. for i, chunk in enumerate(chunks):
  230. logger.info(f"处理分块 {i+1}/{len(chunks)}: 长度={len(chunk['text'])}")
  231. # 构建 prompt
  232. prompt = self._build_ner_prompt(chunk["text"], entity_types)
  233. # 调用 Ollama
  234. response = await self._call_ollama(prompt)
  235. if not response:
  236. logger.warning(f"分块 {i+1} Ollama 返回为空")
  237. continue
  238. # 打印完整响应用于调试
  239. logger.debug(f"分块 {i+1} LLM 完整响应:\n{response}\n{'='*50}")
  240. # 解析结果
  241. entities = self._parse_llm_response(response, chunk["start_pos"])
  242. # 去重
  243. for entity in entities:
  244. entity_key = f"{entity.type}:{entity.name}"
  245. if entity_key not in seen_entities:
  246. seen_entities.add(entity_key)
  247. all_entities.append(entity)
  248. logger.info(f"分块 {i+1} 提取实体: {len(entities)} 个")
  249. logger.info(f"通用 LLM NER 提取完成: 总实体数={len(all_entities)}")
  250. return all_entities
  251. async def _extract_with_universal_ner(
  252. self,
  253. text: str,
  254. entity_types: Optional[List[str]] = None
  255. ) -> List[EntityInfo]:
  256. """
  257. 使用 UniversalNER 模型提取实体
  258. UniversalNER 的 Prompt 格式: "文本内容. 实体类型英文名"
  259. 返回格式: ["实体1", "实体2", ...]
  260. """
  261. # 实体类型映射(中文类型 -> UniversalNER 英文类型)
  262. type_mapping = {
  263. "PERSON": ["person", "people", "human"],
  264. "ORG": ["organization", "company", "institution"],
  265. "LOC": ["location", "place", "address"],
  266. "DATE": ["date", "time"],
  267. "NUMBER": ["number", "quantity", "measurement"],
  268. "DEVICE": ["device", "equipment", "instrument"],
  269. "PROJECT": ["project", "program"],
  270. "METHOD": ["method", "standard", "specification"],
  271. }
  272. types_to_extract = entity_types or list(type_mapping.keys())
  273. # 分割长文本
  274. chunks = self._split_text(text)
  275. all_entities = []
  276. seen_entities = set() # 用于去重
  277. for i, chunk in enumerate(chunks):
  278. chunk_text = chunk["text"]
  279. chunk_start = chunk["start_pos"]
  280. logger.info(f"UniversalNER 处理分块 {i+1}/{len(chunks)}: 长度={len(chunk_text)}")
  281. # 对每种实体类型分别提取
  282. for entity_type in types_to_extract:
  283. if entity_type not in type_mapping:
  284. continue
  285. # 使用第一个英文类型名
  286. english_type = type_mapping[entity_type][0]
  287. # UniversalNER 的 Prompt 格式
  288. prompt = f"{chunk_text} {english_type}"
  289. # 调用 Ollama
  290. response = await self._call_ollama(prompt)
  291. if not response:
  292. continue
  293. # 解析 UniversalNER 响应(返回格式如: ["实体1", "实体2"])
  294. entities = self._parse_universal_ner_response(
  295. response, entity_type, chunk_text, chunk_start
  296. )
  297. # 去重
  298. for entity in entities:
  299. entity_key = f"{entity.type}:{entity.name}"
  300. if entity_key not in seen_entities:
  301. seen_entities.add(entity_key)
  302. all_entities.append(entity)
  303. logger.info(f"分块 {i+1} UniversalNER 提取实体: {len([e for e in all_entities if e not in seen_entities])} 个")
  304. logger.info(f"UniversalNER 提取完成: 总实体数={len(all_entities)}")
  305. return all_entities
  306. def _parse_universal_ner_response(
  307. self,
  308. response: str,
  309. entity_type: str,
  310. original_text: str,
  311. chunk_start_pos: int = 0
  312. ) -> List[EntityInfo]:
  313. """
  314. 解析 UniversalNER 的响应
  315. UniversalNER 返回格式: ["实体1", "实体2", ...]
  316. """
  317. entities = []
  318. try:
  319. # 清理响应,提取 JSON 数组
  320. response = response.strip()
  321. # 尝试找到 JSON 数组
  322. json_match = re.search(r'\[[\s\S]*?\]', response)
  323. if not json_match:
  324. logger.debug(f"UniversalNER 响应中未找到数组: {response[:100]}")
  325. return entities
  326. json_str = json_match.group()
  327. entity_names = json.loads(json_str)
  328. if not isinstance(entity_names, list):
  329. return entities
  330. for name in entity_names:
  331. if not isinstance(name, str) or len(name) < 2:
  332. continue
  333. name = name.strip()
  334. # 在原文中查找位置
  335. pos = original_text.find(name)
  336. char_start = pos + chunk_start_pos if pos >= 0 else 0
  337. char_end = char_start + len(name) if pos >= 0 else 0
  338. entity = EntityInfo(
  339. name=name,
  340. type=entity_type,
  341. value=name,
  342. position=PositionInfo(
  343. char_start=char_start,
  344. char_end=char_end,
  345. line=1
  346. ),
  347. confidence=0.85, # UniversalNER 置信度
  348. temp_id=str(uuid.uuid4())[:8]
  349. )
  350. entities.append(entity)
  351. except json.JSONDecodeError as e:
  352. logger.debug(f"UniversalNER JSON 解析失败: {e}, response={response[:100]}")
  353. except Exception as e:
  354. logger.error(f"解析 UniversalNER 响应失败: {e}")
  355. return entities
  356. async def check_health(self) -> bool:
  357. """
  358. 检查 Ollama 服务是否可用
  359. """
  360. try:
  361. async with httpx.AsyncClient(timeout=5) as client:
  362. response = await client.get(f"{self.base_url}/api/tags")
  363. return response.status_code == 200
  364. except Exception:
  365. return False
  366. # 创建单例
  367. ollama_service = OllamaService()