deepseek_service.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. """
  2. DeepSeek API 服务(阿里云百炼平台)
  3. 用于调用 DeepSeek 模型进行 NER 提取
  4. """
  5. import json
  6. import re
  7. import uuid
  8. import httpx
  9. from typing import List, Optional, Dict, Any
  10. from loguru import logger
  11. from ..config import settings
  12. from ..models import EntityInfo, PositionInfo
  13. class DeepSeekService:
  14. """DeepSeek API 服务"""
  15. def __init__(self):
  16. self.api_key = settings.deepseek_api_key
  17. self.base_url = settings.deepseek_base_url
  18. self.model = settings.deepseek_model
  19. self.timeout = settings.deepseek_timeout
  20. self.temperature = settings.deepseek_temperature
  21. self.max_tokens = settings.deepseek_max_tokens
  22. self.max_retries = settings.deepseek_max_retries
  23. self.chunk_size = settings.chunk_size
  24. self.chunk_overlap = settings.chunk_overlap
  25. logger.info(f"初始化 DeepSeek 服务: model={self.model}, base_url={self.base_url}")
  26. def _split_text(self, text: str) -> List[Dict[str, Any]]:
  27. """
  28. 将长文本分割成多个块
  29. """
  30. if len(text) <= self.chunk_size:
  31. return [{"text": text, "start_pos": 0, "end_pos": len(text)}]
  32. chunks = []
  33. start = 0
  34. while start < len(text):
  35. end = min(start + self.chunk_size, len(text))
  36. # 尝试在句号、换行处分割
  37. if end < len(text):
  38. for sep in ['\n\n', '\n', '。', ';', '!', '?', '.']:
  39. sep_pos = text.rfind(sep, start + self.chunk_size // 2, end)
  40. if sep_pos > start:
  41. end = sep_pos + len(sep)
  42. break
  43. chunk_text = text[start:end]
  44. chunks.append({
  45. "text": chunk_text,
  46. "start_pos": start,
  47. "end_pos": end
  48. })
  49. start = end - self.chunk_overlap if end < len(text) else end
  50. logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
  51. return chunks
  52. def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
  53. """
  54. 构建 NER 提取的 Prompt
  55. """
  56. types = entity_types or settings.entity_types
  57. types_desc = ", ".join(types)
  58. example = '{"entities": [{"name": "成都市", "type": "LOC", "charStart": 10, "charEnd": 13}, {"name": "2024年5月", "type": "DATE", "charStart": 0, "charEnd": 7}]}'
  59. prompt = f"""请从以下文本中提取命名实体,直接输出JSON,不要解释。
  60. 实体类型: {types_desc}
  61. 输出格式示例:
  62. {example}
  63. 文本:
  64. {text}
  65. 请直接输出JSON:"""
  66. return prompt
  67. async def _call_api(self, prompt: str) -> Optional[str]:
  68. """
  69. 调用 DeepSeek API(OpenAI 兼容格式)
  70. """
  71. url = f"{self.base_url}/v1/chat/completions"
  72. headers = {
  73. "Authorization": f"Bearer {self.api_key}",
  74. "Content-Type": "application/json"
  75. }
  76. payload = {
  77. "model": self.model,
  78. "messages": [
  79. {
  80. "role": "user",
  81. "content": prompt
  82. }
  83. ],
  84. "temperature": self.temperature,
  85. "max_tokens": self.max_tokens,
  86. }
  87. for attempt in range(self.max_retries):
  88. try:
  89. async with httpx.AsyncClient(timeout=self.timeout) as client:
  90. response = await client.post(url, headers=headers, json=payload)
  91. response.raise_for_status()
  92. result = response.json()
  93. # OpenAI 格式响应
  94. choices = result.get("choices", [])
  95. if choices:
  96. message = choices[0].get("message", {})
  97. return message.get("content", "")
  98. return None
  99. except httpx.TimeoutException:
  100. logger.warning(f"DeepSeek API 请求超时 (尝试 {attempt + 1}/{self.max_retries})")
  101. if attempt == self.max_retries - 1:
  102. logger.error(f"DeepSeek API 请求超时: timeout={self.timeout}s")
  103. return None
  104. except httpx.HTTPStatusError as e:
  105. logger.error(f"DeepSeek API HTTP 错误: {e.response.status_code} - {e.response.text}")
  106. return None
  107. except Exception as e:
  108. logger.error(f"DeepSeek API 请求失败: {e}")
  109. if attempt == self.max_retries - 1:
  110. return None
  111. return None
  112. def _parse_response(self, response: str, chunk_start_pos: int = 0) -> List[EntityInfo]:
  113. """
  114. 解析 API 返回的 JSON 结果
  115. """
  116. entities = []
  117. try:
  118. # 移除 markdown code block 标记
  119. response = re.sub(r'```json\s*', '', response)
  120. response = re.sub(r'```\s*', '', response)
  121. response = response.strip()
  122. # 方法1:直接解析
  123. data = None
  124. try:
  125. data = json.loads(response)
  126. except json.JSONDecodeError:
  127. pass
  128. # 方法2:查找 JSON 对象
  129. if not data or "entities" not in data:
  130. json_match = re.search(r'\{\s*"entities"\s*:\s*\[[\s\S]*\]\s*\}', response)
  131. if json_match:
  132. try:
  133. data = json.loads(json_match.group())
  134. except json.JSONDecodeError:
  135. pass
  136. if not data or "entities" not in data:
  137. logger.warning(f"未找到有效的 entities JSON, response={response[:300]}...")
  138. return entities
  139. entity_list = data.get("entities", [])
  140. for item in entity_list:
  141. name = item.get("name", "").strip()
  142. entity_type = item.get("type", "").upper()
  143. char_start = item.get("charStart", 0)
  144. char_end = item.get("charEnd", 0)
  145. if not name or len(name) < 2:
  146. continue
  147. # 校正位置
  148. adjusted_start = char_start + chunk_start_pos
  149. adjusted_end = char_end + chunk_start_pos
  150. entity = EntityInfo(
  151. name=name,
  152. type=entity_type,
  153. value=name,
  154. position=PositionInfo(
  155. char_start=adjusted_start,
  156. char_end=adjusted_end,
  157. line=1
  158. ),
  159. confidence=0.95, # DeepSeek 置信度较高
  160. temp_id=str(uuid.uuid4())[:8]
  161. )
  162. entities.append(entity)
  163. except Exception as e:
  164. logger.error(f"解析响应失败: {e}")
  165. return entities
  166. async def extract_entities(
  167. self,
  168. text: str,
  169. entity_types: Optional[List[str]] = None
  170. ) -> List[EntityInfo]:
  171. """
  172. 使用 DeepSeek API 提取实体
  173. """
  174. if not text or not text.strip():
  175. return []
  176. # 分割长文本
  177. chunks = self._split_text(text)
  178. all_entities = []
  179. seen_entities = set()
  180. for i, chunk in enumerate(chunks):
  181. logger.info(f"处理分块 {i+1}/{len(chunks)}: 长度={len(chunk['text'])}")
  182. prompt = self._build_ner_prompt(chunk["text"], entity_types)
  183. response = await self._call_api(prompt)
  184. if not response:
  185. logger.warning(f"分块 {i+1} API 返回为空")
  186. continue
  187. logger.debug(f"分块 {i+1} API 响应: {response[:500]}...")
  188. entities = self._parse_response(response, chunk["start_pos"])
  189. # 去重
  190. for entity in entities:
  191. entity_key = f"{entity.type}:{entity.name}"
  192. if entity_key not in seen_entities:
  193. seen_entities.add(entity_key)
  194. all_entities.append(entity)
  195. logger.info(f"分块 {i+1} 提取实体: {len(entities)} 个")
  196. logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
  197. return all_entities
  198. async def extract_entities_with_progress(
  199. self,
  200. text: str,
  201. entity_types: Optional[List[str]] = None
  202. ):
  203. """
  204. 使用 DeepSeek API 提取实体(带进度生成器)
  205. Yields:
  206. SSE 事件字符串
  207. """
  208. import json
  209. async def sse_event(event: str, data: dict):
  210. return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
  211. if not text or not text.strip():
  212. yield await sse_event("complete", {"entities": [], "total_entities": 0})
  213. return
  214. # 分割长文本
  215. chunks = self._split_text(text)
  216. total_chunks = len(chunks)
  217. all_entities = []
  218. seen_entities = set()
  219. for i, chunk in enumerate(chunks):
  220. chunk_index = i + 1
  221. logger.info(f"处理分块 {chunk_index}/{total_chunks}: 长度={len(chunk['text'])}")
  222. # 发送进度事件
  223. yield await sse_event("progress", {
  224. "chunk_index": chunk_index,
  225. "total_chunks": total_chunks,
  226. "chunk_length": len(chunk['text']),
  227. "total_entities": len(all_entities),
  228. "progress_percent": int((chunk_index - 1) / total_chunks * 100),
  229. "message": f"正在处理第 {chunk_index}/{total_chunks} 个文本块..."
  230. })
  231. prompt = self._build_ner_prompt(chunk["text"], entity_types)
  232. response = await self._call_api(prompt)
  233. if not response:
  234. logger.warning(f"分块 {chunk_index} API 返回为空")
  235. continue
  236. logger.debug(f"分块 {chunk_index} API 响应: {response[:500]}...")
  237. entities = self._parse_response(response, chunk["start_pos"])
  238. # 去重并收集新实体
  239. new_entities = []
  240. for entity in entities:
  241. entity_key = f"{entity.type}:{entity.name}"
  242. if entity_key not in seen_entities:
  243. seen_entities.add(entity_key)
  244. all_entities.append(entity)
  245. new_entities.append(entity)
  246. logger.info(f"分块 {chunk_index} 提取实体: {len(entities)} 个, 新增: {len(new_entities)} 个")
  247. # 发送分块完成事件
  248. yield await sse_event("chunk_complete", {
  249. "chunk_index": chunk_index,
  250. "total_chunks": total_chunks,
  251. "chunk_entities": len(entities),
  252. "new_entities": len(new_entities),
  253. "total_entities": len(all_entities),
  254. "progress_percent": int(chunk_index / total_chunks * 100)
  255. })
  256. logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
  257. # 发送实体数据事件(供调用方获取实体列表)
  258. yield await sse_event("entities_data", {
  259. "entities": [entity.model_dump(by_alias=True) for entity in all_entities],
  260. "total_entities": len(all_entities)
  261. })
  262. async def check_health(self) -> bool:
  263. """
  264. 检查 DeepSeek API 是否可用
  265. """
  266. try:
  267. url = f"{self.base_url}/v1/models"
  268. headers = {
  269. "Authorization": f"Bearer {self.api_key}"
  270. }
  271. async with httpx.AsyncClient(timeout=10) as client:
  272. response = await client.get(url, headers=headers)
  273. return response.status_code == 200
  274. except Exception:
  275. return False
  276. # 创建单例
  277. deepseek_service = DeepSeekService()