deepseek_service.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. """
  2. DeepSeek API 服务(阿里云百炼平台)
  3. 用于调用 DeepSeek 模型进行 NER 提取
  4. """
  5. import json
  6. import re
  7. import uuid
  8. import httpx
  9. from typing import List, Optional, Dict, Any
  10. from loguru import logger
  11. from ..config import settings
  12. from ..models import EntityInfo, PositionInfo
  13. class DeepSeekService:
  14. """DeepSeek API 服务"""
  15. def __init__(self):
  16. self.api_key = settings.deepseek_api_key
  17. self.base_url = settings.deepseek_base_url
  18. self.model = settings.deepseek_model
  19. self.timeout = settings.deepseek_timeout
  20. self.temperature = settings.deepseek_temperature
  21. self.max_tokens = settings.deepseek_max_tokens
  22. self.max_retries = settings.deepseek_max_retries
  23. self.chunk_size = settings.chunk_size
  24. self.chunk_overlap = settings.chunk_overlap
  25. logger.info(f"初始化 DeepSeek 服务: model={self.model}, base_url={self.base_url}")
  26. def _split_text(self, text: str) -> List[Dict[str, Any]]:
  27. """
  28. 将长文本分割成多个块
  29. """
  30. if len(text) <= self.chunk_size:
  31. return [{"text": text, "start_pos": 0, "end_pos": len(text)}]
  32. chunks = []
  33. start = 0
  34. while start < len(text):
  35. end = min(start + self.chunk_size, len(text))
  36. # 尝试在句号、换行处分割
  37. if end < len(text):
  38. for sep in ['\n\n', '\n', '。', ';', '!', '?', '.']:
  39. sep_pos = text.rfind(sep, start + self.chunk_size // 2, end)
  40. if sep_pos > start:
  41. end = sep_pos + len(sep)
  42. break
  43. chunk_text = text[start:end]
  44. chunks.append({
  45. "text": chunk_text,
  46. "start_pos": start,
  47. "end_pos": end
  48. })
  49. start = end - self.chunk_overlap if end < len(text) else end
  50. logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
  51. return chunks
  52. def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
  53. """
  54. 构建 NER 提取的 Prompt
  55. """
  56. types = entity_types or settings.entity_types
  57. prompt = f"""你是专业的命名实体识别(NER)专家。请从以下文本中提取有意义的命名实体。
  58. ## 实体类型定义
  59. 1. **PERSON** - 人名
  60. - 包括:姓名、职务+姓名(如"张总"、"李工程师")
  61. - 不包括:职位本身(如"经理"、"主任")
  62. 2. **ORG** - 机构/组织/部门
  63. - 包括:公司、政府机关、协会、院校、内部部门、委员会等
  64. - 例如:"中国电建集团成都勘测设计研究院有限公司"、"国家能源局"、"安全质量环保部"、"人力资源部"、"安委会"
  65. - 不包括:泛指的简称(如"公司"、"集团"),除非是特定缩写(如"成都院"、"股份公司")
  66. 3. **LOC** - 地点/位置
  67. - 包括:省市区县、具体地址、工程位置
  68. - 例如:"四川省成都市"、"龙滩水电站"
  69. 4. **DATE** - 日期/时间
  70. - 包括:具体日期、时间段、年份
  71. - 例如:"2024年7月13日"、"2019年版"
  72. - 不包括:单独的年份数字
  73. 5. **NUMBER** - 有意义的数值
  74. - 包括:带单位的数量、金额、评分、比例
  75. - 例如:"93.33分"、"7000名"、"70多年"、"5个"
  76. - **不包括**:章节编号(如"1.1"、"第3章")、纯序号(如"16"、"17")、表格序号
  77. 6. **PROJECT** - 项目/工程名称
  78. - 包括:具体项目名称、工程名称
  79. - 例如:"白鹤滩水电站工程"、"安全生产标准化建设项目"、"龙滩水电站"
  80. - **不包括**:文件编号(归类为 DOC_ID)
  81. 7. **METHOD** - 方法/标准/规范
  82. - 包括:技术标准、管理办法、评价标准的**完整名称**
  83. - 例如:"电力建设企业安全生产标准化评价标准"、"安全生产风险抵押金管理办法"
  84. - **不包括**:编号(归类为 DOC_ID)
  85. 8. **DEVICE** - 设备/系统
  86. - 包括:设备名称、信息系统名称
  87. - 例如:"OA系统"、"QHSE系统"、"造槽机"
  88. 9. **DOC_ID** - 文件编号/标准编号
  89. - 包括:文件编号、标准编号、规范编号、合同编号、项目编号、证书编号
  90. - 例如:"AY-BZ-0092-2024-Z01"、"SQE.01C0213"、"中电建股〔2019〕122号"、"ZGDIDBOY-083"、"蓉设安质〔2024〕18号"
  91. - 特征:通常包含字母、数字、连字符、年份、括号等组合
  92. 10. **CERT** - 证书/资质/等级
  93. - 包括:**完整的**资质证书名称、等级认证、荣誉称号、专业资质
  94. - 例如:"电力安全生产标准化一级企业证书"、"工程设计综合甲级资质"、"注册安全工程师证"、"电力、水利水电、市政公用工程施工总承包一级"
  95. - **不包括**:
  96. - 职务职称(归类为 TITLE)
  97. - 单独的行业名称(如单独的"电力"、"水利水电"不是资质,需要完整如"电力工程施工总承包一级")
  98. - 资质的简称片段(如"甲级"本身不提取,需要完整如"工程设计综合甲级")
  99. 11. **TITLE** - 职务/职称
  100. - 包括:行政职务、专业职称、岗位名称
  101. - 例如:"董事长"、"总经理"、"安全总监"、"总工程师"、"首席专家"
  102. - 不包括:人名(归类为 PERSON)
  103. 12. **POLICY** - 政策法规
  104. - 包括:法律、法规、条例、办法的名称(不含编号)
  105. - 例如:"安全生产法"、"企业安全生产费用提取和使用管理办法"、"劳动合同法"
  106. - 不包括:企业内部制度(归类为 METHOD)、文件编号(归类为 DOC_ID)
  107. ## 重要规则
  108. 1. **只提取有实际意义的实体**,忽略:
  109. - 章节编号(如"1"、"1.1"、"第一章")
  110. - 表格序号
  111. - 单独的数字(如"16"、"17")
  112. 2. **实体必须完整**,提取完整的名称而非片段:
  113. - 资质要完整:提取"工程设计综合甲级"而非"甲级"
  114. - 证书要完整:提取"注册安全工程师证"而非"安全工程师"
  115. - 机构要完整:提取"安全质量环保部"而非"环保部"
  116. 3. **不要拆分并列实体**:
  117. - 如"电力、水利水电、市政公用工程施工总承包一级"应作为一个完整的 CERT
  118. - 不要拆成"电力"、"水利水电"等单独的实体
  119. 4. **去除重复实体**,相同的实体只返回一次
  120. 5. **charStart和charEnd必须准确**,对应实体在原文中的字符位置(从0开始)
  121. ## 输出格式
  122. 直接输出JSON,格式如下:
  123. {{"entities": [{{"name": "实体名称", "type": "实体类型", "charStart": 起始位置, "charEnd": 结束位置}}]}}
  124. ## 待处理文本
  125. {text}
  126. 请直接输出JSON:"""
  127. return prompt
  128. async def _call_api(self, prompt: str) -> Optional[str]:
  129. """
  130. 调用 DeepSeek API(OpenAI 兼容格式)
  131. """
  132. url = f"{self.base_url}/v1/chat/completions"
  133. headers = {
  134. "Authorization": f"Bearer {self.api_key}",
  135. "Content-Type": "application/json"
  136. }
  137. payload = {
  138. "model": self.model,
  139. "messages": [
  140. {
  141. "role": "user",
  142. "content": prompt
  143. }
  144. ],
  145. "temperature": self.temperature,
  146. "max_tokens": self.max_tokens,
  147. }
  148. for attempt in range(self.max_retries):
  149. try:
  150. async with httpx.AsyncClient(timeout=self.timeout) as client:
  151. response = await client.post(url, headers=headers, json=payload)
  152. response.raise_for_status()
  153. result = response.json()
  154. # OpenAI 格式响应
  155. choices = result.get("choices", [])
  156. if choices:
  157. choice = choices[0]
  158. message = choice.get("message", {})
  159. content = message.get("content", "")
  160. # 检查是否因为 max_tokens 被截断
  161. finish_reason = choice.get("finish_reason", "")
  162. if finish_reason == "length":
  163. logger.warning(f"API 响应被截断 (finish_reason=length), 考虑增加 max_tokens 或减小分块大小")
  164. return content
  165. return None
  166. except httpx.TimeoutException:
  167. logger.warning(f"DeepSeek API 请求超时 (尝试 {attempt + 1}/{self.max_retries})")
  168. if attempt == self.max_retries - 1:
  169. logger.error(f"DeepSeek API 请求超时: timeout={self.timeout}s")
  170. return None
  171. except httpx.HTTPStatusError as e:
  172. logger.error(f"DeepSeek API HTTP 错误: {e.response.status_code} - {e.response.text}")
  173. return None
  174. except Exception as e:
  175. logger.error(f"DeepSeek API 请求失败: {e}")
  176. if attempt == self.max_retries - 1:
  177. return None
  178. return None
  179. def _parse_response(self, response: str, chunk_start_pos: int = 0) -> List[EntityInfo]:
  180. """
  181. 解析 API 返回的 JSON 结果
  182. """
  183. entities = []
  184. try:
  185. # 移除 markdown code block 标记(支持多行模式)
  186. response = re.sub(r'```json\s*\n?', '', response, flags=re.IGNORECASE)
  187. response = re.sub(r'\n?```\s*$', '', response)
  188. response = re.sub(r'```\s*', '', response)
  189. response = response.strip()
  190. # 方法1:直接解析
  191. data = None
  192. try:
  193. data = json.loads(response)
  194. except json.JSONDecodeError as e:
  195. logger.debug(f"直接解析 JSON 失败: {e}")
  196. # 方法2:查找 JSON 对象(使用更宽松的正则)
  197. if not data or "entities" not in data:
  198. # 尝试匹配从 { 开始到最后一个 } 的内容
  199. json_match = re.search(r'\{[^{}]*"entities"\s*:\s*\[[\s\S]*?\]\s*\}', response)
  200. if json_match:
  201. try:
  202. data = json.loads(json_match.group())
  203. except json.JSONDecodeError as e:
  204. logger.debug(f"正则匹配 JSON 解析失败: {e}")
  205. # 方法3:尝试提取 entities 数组
  206. if not data or "entities" not in data:
  207. array_match = re.search(r'"entities"\s*:\s*(\[[\s\S]*\])', response)
  208. if array_match:
  209. try:
  210. entity_list = json.loads(array_match.group(1))
  211. data = {"entities": entity_list}
  212. except json.JSONDecodeError as e:
  213. logger.debug(f"提取 entities 数组失败: {e}")
  214. # 方法4:处理被截断的 JSON,尝试逐个解析完整的实体对象
  215. if not data or "entities" not in data:
  216. logger.debug("尝试从截断的 JSON 中提取完整实体...")
  217. entity_pattern = r'\{\s*"name"\s*:\s*"([^"]+)"\s*,\s*"type"\s*:\s*"([^"]+)"\s*,\s*"charStart"\s*:\s*(\d+)\s*,\s*"charEnd"\s*:\s*(\d+)\s*\}'
  218. matches = re.findall(entity_pattern, response)
  219. if matches:
  220. data = {"entities": []}
  221. for match in matches:
  222. data["entities"].append({
  223. "name": match[0],
  224. "type": match[1],
  225. "charStart": int(match[2]),
  226. "charEnd": int(match[3])
  227. })
  228. logger.info(f"从截断 JSON 中恢复了 {len(matches)} 个实体")
  229. if not data or "entities" not in data:
  230. logger.warning(f"未找到有效的 entities JSON, response_length={len(response)}, response_preview={response[:500]}...")
  231. return entities
  232. entity_list = data.get("entities", [])
  233. for item in entity_list:
  234. name = item.get("name", "").strip()
  235. entity_type = item.get("type", "").upper()
  236. char_start = item.get("charStart", 0)
  237. char_end = item.get("charEnd", 0)
  238. if not name or len(name) < 2:
  239. continue
  240. # 校正位置
  241. adjusted_start = char_start + chunk_start_pos
  242. adjusted_end = char_end + chunk_start_pos
  243. entity = EntityInfo(
  244. name=name,
  245. type=entity_type,
  246. value=name,
  247. position=PositionInfo(
  248. char_start=adjusted_start,
  249. char_end=adjusted_end,
  250. line=1
  251. ),
  252. confidence=0.95, # DeepSeek 置信度较高
  253. temp_id=str(uuid.uuid4())[:8]
  254. )
  255. entities.append(entity)
  256. except Exception as e:
  257. logger.error(f"解析响应失败: {e}")
  258. return entities
  259. async def extract_entities(
  260. self,
  261. text: str,
  262. entity_types: Optional[List[str]] = None
  263. ) -> List[EntityInfo]:
  264. """
  265. 使用 DeepSeek API 提取实体
  266. """
  267. if not text or not text.strip():
  268. return []
  269. # 分割长文本
  270. chunks = self._split_text(text)
  271. all_entities = []
  272. seen_entities = set()
  273. for i, chunk in enumerate(chunks):
  274. logger.info(f"处理分块 {i+1}/{len(chunks)}: 长度={len(chunk['text'])}")
  275. prompt = self._build_ner_prompt(chunk["text"], entity_types)
  276. response = await self._call_api(prompt)
  277. if not response:
  278. logger.warning(f"分块 {i+1} API 返回为空")
  279. continue
  280. logger.debug(f"分块 {i+1} API 响应: {response[:500]}...")
  281. entities = self._parse_response(response, chunk["start_pos"])
  282. # 去重
  283. for entity in entities:
  284. entity_key = f"{entity.type}:{entity.name}"
  285. if entity_key not in seen_entities:
  286. seen_entities.add(entity_key)
  287. all_entities.append(entity)
  288. logger.info(f"分块 {i+1} 提取实体: {len(entities)} 个")
  289. logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
  290. return all_entities
  291. async def extract_entities_with_progress(
  292. self,
  293. text: str,
  294. entity_types: Optional[List[str]] = None
  295. ):
  296. """
  297. 使用 DeepSeek API 提取实体(带进度生成器)
  298. Yields:
  299. SSE 事件字符串
  300. """
  301. import json
  302. async def sse_event(event: str, data: dict):
  303. return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
  304. if not text or not text.strip():
  305. yield await sse_event("complete", {"entities": [], "total_entities": 0})
  306. return
  307. # 分割长文本
  308. chunks = self._split_text(text)
  309. total_chunks = len(chunks)
  310. all_entities = []
  311. seen_entities = set()
  312. for i, chunk in enumerate(chunks):
  313. chunk_index = i + 1
  314. logger.info(f"处理分块 {chunk_index}/{total_chunks}: 长度={len(chunk['text'])}")
  315. # 发送进度事件
  316. yield await sse_event("progress", {
  317. "chunk_index": chunk_index,
  318. "total_chunks": total_chunks,
  319. "chunk_length": len(chunk['text']),
  320. "total_entities": len(all_entities),
  321. "progress_percent": int((chunk_index - 1) / total_chunks * 100),
  322. "message": f"正在处理第 {chunk_index}/{total_chunks} 个文本块..."
  323. })
  324. prompt = self._build_ner_prompt(chunk["text"], entity_types)
  325. response = await self._call_api(prompt)
  326. if not response:
  327. logger.warning(f"分块 {chunk_index} API 返回为空")
  328. continue
  329. logger.debug(f"分块 {chunk_index} API 响应: {response[:500]}...")
  330. entities = self._parse_response(response, chunk["start_pos"])
  331. # 去重并收集新实体
  332. new_entities = []
  333. for entity in entities:
  334. entity_key = f"{entity.type}:{entity.name}"
  335. if entity_key not in seen_entities:
  336. seen_entities.add(entity_key)
  337. all_entities.append(entity)
  338. new_entities.append(entity)
  339. logger.info(f"分块 {chunk_index} 提取实体: {len(entities)} 个, 新增: {len(new_entities)} 个")
  340. # 发送分块完成事件
  341. yield await sse_event("chunk_complete", {
  342. "chunk_index": chunk_index,
  343. "total_chunks": total_chunks,
  344. "chunk_entities": len(entities),
  345. "new_entities": len(new_entities),
  346. "total_entities": len(all_entities),
  347. "progress_percent": int(chunk_index / total_chunks * 100)
  348. })
  349. logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
  350. # 发送实体数据事件(供调用方获取实体列表)
  351. yield await sse_event("entities_data", {
  352. "entities": [entity.model_dump(by_alias=True) for entity in all_entities],
  353. "total_entities": len(all_entities)
  354. })
  355. async def check_health(self) -> bool:
  356. """
  357. 检查 DeepSeek API 是否可用
  358. """
  359. try:
  360. url = f"{self.base_url}/v1/models"
  361. headers = {
  362. "Authorization": f"Bearer {self.api_key}"
  363. }
  364. async with httpx.AsyncClient(timeout=10) as client:
  365. response = await client.get(url, headers=headers)
  366. return response.status_code == 200
  367. except Exception:
  368. return False
  369. # 创建单例
  370. deepseek_service = DeepSeekService()