"""
Ollama LLM service.

Calls a local Ollama model to perform NER (named-entity) extraction.
"""
import json
import re
import uuid
from typing import Any, Dict, List, Optional

import httpx
from loguru import logger

from ..config import settings
from ..models import EntityInfo, PositionInfo
  13. class OllamaService:
  14. """Ollama LLM 服务"""
  15. def __init__(self):
  16. self.base_url = settings.ollama_url
  17. self.model = settings.ollama_model
  18. self.timeout = settings.ollama_timeout
  19. self.chunk_size = settings.chunk_size
  20. self.chunk_overlap = settings.chunk_overlap
  21. logger.info(f"初始化 Ollama 服务: url={self.base_url}, model={self.model}")
  22. def _split_text(self, text: str) -> List[Dict[str, Any]]:
  23. """
  24. 将长文本分割成多个块
  25. Args:
  26. text: 原始文本
  27. Returns:
  28. 分块列表,每个块包含 text, start_pos, end_pos
  29. """
  30. if len(text) <= self.chunk_size:
  31. return [{"text": text, "start_pos": 0, "end_pos": len(text)}]
  32. chunks = []
  33. start = 0
  34. while start < len(text):
  35. end = min(start + self.chunk_size, len(text))
  36. # 尝试在句号、换行处分割,避免截断句子
  37. if end < len(text):
  38. # 向前查找最近的分隔符
  39. for sep in ['\n\n', '\n', '。', ';', '!', '?', '.']:
  40. sep_pos = text.rfind(sep, start + self.chunk_size // 2, end)
  41. if sep_pos > start:
  42. end = sep_pos + len(sep)
  43. break
  44. chunk_text = text[start:end]
  45. chunks.append({
  46. "text": chunk_text,
  47. "start_pos": start,
  48. "end_pos": end
  49. })
  50. # 下一个块的起始位置(考虑重叠)
  51. start = end - self.chunk_overlap if end < len(text) else end
  52. logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
  53. return chunks
  54. def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
  55. """
  56. 构建 NER 提取的 Prompt
  57. """
  58. types = entity_types or settings.entity_types
  59. types_desc = ", ".join(types)
  60. prompt = f"""你是一个专业的命名实体识别(NER)系统。请从以下文本中提取实体。
  61. ## 任务要求
  62. 1. 识别以下类型的实体: {types_desc}
  63. 2. 每个实体需要包含: 名称(name)、类型(type)、在文本中的起始位置(charStart)和结束位置(charEnd)
  64. 3. 只提取明确的、有意义的实体,避免提取过于泛化的词汇
  65. 4. 严格按照 JSON 格式输出
  66. ## 实体类型说明
  67. - PERSON: 人名(如:张三、李经理)
  68. - ORG: 机构/组织/公司(如:成都检测公司、环保局)
  69. - LOC: 地点/地址(如:成都市、高新区)
  70. - DATE: 日期时间(如:2024年5月15日、2024-05-15)
  71. - NUMBER: 带单位的数值(如:50分贝、100万元)
  72. - DEVICE: 设备仪器(如:噪音检测仪、分析仪器)
  73. - PROJECT: 项目/工程(如:环境监测项目、XX工程)
  74. - METHOD: 方法/标准(如:GB/T 12345、检测方法)
  75. ## 输出格式
  76. 请严格按以下 JSON 格式输出,不要包含其他内容:
  77. ```json
  78. {{
  79. "entities": [
  80. {{"name": "实体名称", "type": "实体类型", "charStart": 起始位置, "charEnd": 结束位置}}
  81. ]
  82. }}
  83. ```
  84. ## 待处理文本
  85. {text}
  86. ## 提取结果
  87. """
  88. return prompt
  89. async def _call_ollama(self, prompt: str) -> Optional[str]:
  90. """
  91. 调用 Ollama API
  92. """
  93. url = f"{self.base_url}/api/generate"
  94. payload = {
  95. "model": self.model,
  96. "prompt": prompt,
  97. "stream": False,
  98. "options": {
  99. "temperature": 0.1, # 低温度,更确定性的输出
  100. "num_predict": 4096, # 最大输出 token
  101. }
  102. }
  103. try:
  104. async with httpx.AsyncClient(timeout=self.timeout) as client:
  105. response = await client.post(url, json=payload)
  106. response.raise_for_status()
  107. result = response.json()
  108. return result.get("response", "")
  109. except httpx.TimeoutException:
  110. logger.error(f"Ollama 请求超时: timeout={self.timeout}s")
  111. return None
  112. except Exception as e:
  113. logger.error(f"Ollama 请求失败: {e}")
  114. return None
  115. def _parse_llm_response(self, response: str, chunk_start_pos: int = 0) -> List[EntityInfo]:
  116. """
  117. 解析 LLM 返回的 JSON 结果
  118. Args:
  119. response: LLM 返回的文本
  120. chunk_start_pos: 当前分块在原文中的起始位置(用于位置校正)
  121. """
  122. entities = []
  123. try:
  124. # 尝试提取 JSON 部分
  125. json_match = re.search(r'\{[\s\S]*\}', response)
  126. if not json_match:
  127. logger.warning("LLM 响应中未找到 JSON")
  128. return entities
  129. json_str = json_match.group()
  130. data = json.loads(json_str)
  131. entity_list = data.get("entities", [])
  132. for item in entity_list:
  133. name = item.get("name", "").strip()
  134. entity_type = item.get("type", "").upper()
  135. char_start = item.get("charStart", 0)
  136. char_end = item.get("charEnd", 0)
  137. if not name or len(name) < 2:
  138. continue
  139. # 校正位置(加上分块的起始位置)
  140. adjusted_start = char_start + chunk_start_pos
  141. adjusted_end = char_end + chunk_start_pos
  142. entity = EntityInfo(
  143. name=name,
  144. type=entity_type,
  145. value=name,
  146. position=PositionInfo(
  147. char_start=adjusted_start,
  148. char_end=adjusted_end,
  149. line=1 # LLM 模式不计算行号
  150. ),
  151. confidence=0.9, # LLM 模式默认较高置信度
  152. temp_id=str(uuid.uuid4())[:8]
  153. )
  154. entities.append(entity)
  155. except json.JSONDecodeError as e:
  156. logger.warning(f"JSON 解析失败: {e}, response={response[:200]}...")
  157. except Exception as e:
  158. logger.error(f"解析 LLM 响应失败: {e}")
  159. return entities
  160. async def extract_entities(
  161. self,
  162. text: str,
  163. entity_types: Optional[List[str]] = None
  164. ) -> List[EntityInfo]:
  165. """
  166. 使用 Ollama LLM 提取实体
  167. 支持长文本自动分块处理
  168. """
  169. if not text or not text.strip():
  170. return []
  171. # 分割长文本
  172. chunks = self._split_text(text)
  173. all_entities = []
  174. seen_entities = set() # 用于去重
  175. for i, chunk in enumerate(chunks):
  176. logger.info(f"处理分块 {i+1}/{len(chunks)}: 长度={len(chunk['text'])}")
  177. # 构建 prompt
  178. prompt = self._build_ner_prompt(chunk["text"], entity_types)
  179. # 调用 Ollama
  180. response = await self._call_ollama(prompt)
  181. if not response:
  182. logger.warning(f"分块 {i+1} Ollama 返回为空")
  183. continue
  184. # 解析结果
  185. entities = self._parse_llm_response(response, chunk["start_pos"])
  186. # 去重
  187. for entity in entities:
  188. entity_key = f"{entity.type}:{entity.name}"
  189. if entity_key not in seen_entities:
  190. seen_entities.add(entity_key)
  191. all_entities.append(entity)
  192. logger.info(f"分块 {i+1} 提取实体: {len(entities)} 个")
  193. logger.info(f"Ollama NER 提取完成: 总实体数={len(all_entities)}")
  194. return all_entities
  195. async def check_health(self) -> bool:
  196. """
  197. 检查 Ollama 服务是否可用
  198. """
  199. try:
  200. async with httpx.AsyncClient(timeout=5) as client:
  201. response = await client.get(f"{self.base_url}/api/tags")
  202. return response.status_code == 200
  203. except Exception:
  204. return False
  205. # 创建单例
  206. ollama_service = OllamaService()