|
|
@@ -0,0 +1,263 @@
|
|
|
+"""
|
|
|
+DeepSeek API 服务(阿里云百炼平台)
|
|
|
+用于调用 DeepSeek 模型进行 NER 提取
|
|
|
+"""
|
|
|
import asyncio
import json
import re
import uuid
from typing import Any, Dict, List, Optional

import httpx
from loguru import logger

from ..config import settings
from ..models import EntityInfo, PositionInfo
|
|
|
+
|
|
|
+
|
|
|
class DeepSeekService:
    """DeepSeek NER service backed by the Aliyun Bailian OpenAI-compatible API."""

    def __init__(self):
        # Snapshot all connection/generation parameters from the global
        # settings object so later calls don't depend on settings mutating.
        cfg = settings
        self.api_key = cfg.deepseek_api_key
        self.base_url = cfg.deepseek_base_url
        self.model = cfg.deepseek_model
        self.timeout = cfg.deepseek_timeout
        self.temperature = cfg.deepseek_temperature
        self.max_tokens = cfg.deepseek_max_tokens
        self.max_retries = cfg.deepseek_max_retries
        # Text-chunking parameters consumed by _split_text().
        self.chunk_size = cfg.chunk_size
        self.chunk_overlap = cfg.chunk_overlap

        logger.info(f"初始化 DeepSeek 服务: model={self.model}, base_url={self.base_url}")
|
|
|
+
|
|
|
+ def _split_text(self, text: str) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 将长文本分割成多个块
|
|
|
+ """
|
|
|
+ if len(text) <= self.chunk_size:
|
|
|
+ return [{"text": text, "start_pos": 0, "end_pos": len(text)}]
|
|
|
+
|
|
|
+ chunks = []
|
|
|
+ start = 0
|
|
|
+
|
|
|
+ while start < len(text):
|
|
|
+ end = min(start + self.chunk_size, len(text))
|
|
|
+
|
|
|
+ # 尝试在句号、换行处分割
|
|
|
+ if end < len(text):
|
|
|
+ for sep in ['\n\n', '\n', '。', ';', '!', '?', '.']:
|
|
|
+ sep_pos = text.rfind(sep, start + self.chunk_size // 2, end)
|
|
|
+ if sep_pos > start:
|
|
|
+ end = sep_pos + len(sep)
|
|
|
+ break
|
|
|
+
|
|
|
+ chunk_text = text[start:end]
|
|
|
+ chunks.append({
|
|
|
+ "text": chunk_text,
|
|
|
+ "start_pos": start,
|
|
|
+ "end_pos": end
|
|
|
+ })
|
|
|
+
|
|
|
+ start = end - self.chunk_overlap if end < len(text) else end
|
|
|
+
|
|
|
+ logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
|
|
|
+ return chunks
|
|
|
+
|
|
|
+ def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
|
|
|
+ """
|
|
|
+ 构建 NER 提取的 Prompt
|
|
|
+ """
|
|
|
+ types = entity_types or settings.entity_types
|
|
|
+ types_desc = ", ".join(types)
|
|
|
+
|
|
|
+ example = '{"entities": [{"name": "成都市", "type": "LOC", "charStart": 10, "charEnd": 13}, {"name": "2024年5月", "type": "DATE", "charStart": 0, "charEnd": 7}]}'
|
|
|
+
|
|
|
+ prompt = f"""请从以下文本中提取命名实体,直接输出JSON,不要解释。
|
|
|
+
|
|
|
+实体类型: {types_desc}
|
|
|
+
|
|
|
+输出格式示例:
|
|
|
+{example}
|
|
|
+
|
|
|
+文本:
|
|
|
+{text}
|
|
|
+
|
|
|
+请直接输出JSON:"""
|
|
|
+ return prompt
|
|
|
+
|
|
|
+ async def _call_api(self, prompt: str) -> Optional[str]:
|
|
|
+ """
|
|
|
+ 调用 DeepSeek API(OpenAI 兼容格式)
|
|
|
+ """
|
|
|
+ url = f"{self.base_url}/v1/chat/completions"
|
|
|
+ headers = {
|
|
|
+ "Authorization": f"Bearer {self.api_key}",
|
|
|
+ "Content-Type": "application/json"
|
|
|
+ }
|
|
|
+ payload = {
|
|
|
+ "model": self.model,
|
|
|
+ "messages": [
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": prompt
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "temperature": self.temperature,
|
|
|
+ "max_tokens": self.max_tokens,
|
|
|
+ }
|
|
|
+
|
|
|
+ for attempt in range(self.max_retries):
|
|
|
+ try:
|
|
|
+ async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
|
+ response = await client.post(url, headers=headers, json=payload)
|
|
|
+ response.raise_for_status()
|
|
|
+ result = response.json()
|
|
|
+
|
|
|
+ # OpenAI 格式响应
|
|
|
+ choices = result.get("choices", [])
|
|
|
+ if choices:
|
|
|
+ message = choices[0].get("message", {})
|
|
|
+ return message.get("content", "")
|
|
|
+ return None
|
|
|
+
|
|
|
+ except httpx.TimeoutException:
|
|
|
+ logger.warning(f"DeepSeek API 请求超时 (尝试 {attempt + 1}/{self.max_retries})")
|
|
|
+ if attempt == self.max_retries - 1:
|
|
|
+ logger.error(f"DeepSeek API 请求超时: timeout={self.timeout}s")
|
|
|
+ return None
|
|
|
+ except httpx.HTTPStatusError as e:
|
|
|
+ logger.error(f"DeepSeek API HTTP 错误: {e.response.status_code} - {e.response.text}")
|
|
|
+ return None
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"DeepSeek API 请求失败: {e}")
|
|
|
+ if attempt == self.max_retries - 1:
|
|
|
+ return None
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+ def _parse_response(self, response: str, chunk_start_pos: int = 0) -> List[EntityInfo]:
|
|
|
+ """
|
|
|
+ 解析 API 返回的 JSON 结果
|
|
|
+ """
|
|
|
+ entities = []
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 移除 markdown code block 标记
|
|
|
+ response = re.sub(r'```json\s*', '', response)
|
|
|
+ response = re.sub(r'```\s*', '', response)
|
|
|
+ response = response.strip()
|
|
|
+
|
|
|
+ # 方法1:直接解析
|
|
|
+ data = None
|
|
|
+ try:
|
|
|
+ data = json.loads(response)
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # 方法2:查找 JSON 对象
|
|
|
+ if not data or "entities" not in data:
|
|
|
+ json_match = re.search(r'\{\s*"entities"\s*:\s*\[[\s\S]*\]\s*\}', response)
|
|
|
+ if json_match:
|
|
|
+ try:
|
|
|
+ data = json.loads(json_match.group())
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ if not data or "entities" not in data:
|
|
|
+ logger.warning(f"未找到有效的 entities JSON, response={response[:300]}...")
|
|
|
+ return entities
|
|
|
+
|
|
|
+ entity_list = data.get("entities", [])
|
|
|
+
|
|
|
+ for item in entity_list:
|
|
|
+ name = item.get("name", "").strip()
|
|
|
+ entity_type = item.get("type", "").upper()
|
|
|
+ char_start = item.get("charStart", 0)
|
|
|
+ char_end = item.get("charEnd", 0)
|
|
|
+
|
|
|
+ if not name or len(name) < 2:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 校正位置
|
|
|
+ adjusted_start = char_start + chunk_start_pos
|
|
|
+ adjusted_end = char_end + chunk_start_pos
|
|
|
+
|
|
|
+ entity = EntityInfo(
|
|
|
+ name=name,
|
|
|
+ type=entity_type,
|
|
|
+ value=name,
|
|
|
+ position=PositionInfo(
|
|
|
+ char_start=adjusted_start,
|
|
|
+ char_end=adjusted_end,
|
|
|
+ line=1
|
|
|
+ ),
|
|
|
+ confidence=0.95, # DeepSeek 置信度较高
|
|
|
+ temp_id=str(uuid.uuid4())[:8]
|
|
|
+ )
|
|
|
+ entities.append(entity)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"解析响应失败: {e}")
|
|
|
+
|
|
|
+ return entities
|
|
|
+
|
|
|
+ async def extract_entities(
|
|
|
+ self,
|
|
|
+ text: str,
|
|
|
+ entity_types: Optional[List[str]] = None
|
|
|
+ ) -> List[EntityInfo]:
|
|
|
+ """
|
|
|
+ 使用 DeepSeek API 提取实体
|
|
|
+ """
|
|
|
+ if not text or not text.strip():
|
|
|
+ return []
|
|
|
+
|
|
|
+ # 分割长文本
|
|
|
+ chunks = self._split_text(text)
|
|
|
+
|
|
|
+ all_entities = []
|
|
|
+ seen_entities = set()
|
|
|
+
|
|
|
+ for i, chunk in enumerate(chunks):
|
|
|
+ logger.info(f"处理分块 {i+1}/{len(chunks)}: 长度={len(chunk['text'])}")
|
|
|
+
|
|
|
+ prompt = self._build_ner_prompt(chunk["text"], entity_types)
|
|
|
+ response = await self._call_api(prompt)
|
|
|
+
|
|
|
+ if not response:
|
|
|
+ logger.warning(f"分块 {i+1} API 返回为空")
|
|
|
+ continue
|
|
|
+
|
|
|
+ logger.debug(f"分块 {i+1} API 响应: {response[:500]}...")
|
|
|
+
|
|
|
+ entities = self._parse_response(response, chunk["start_pos"])
|
|
|
+
|
|
|
+ # 去重
|
|
|
+ for entity in entities:
|
|
|
+ entity_key = f"{entity.type}:{entity.name}"
|
|
|
+ if entity_key not in seen_entities:
|
|
|
+ seen_entities.add(entity_key)
|
|
|
+ all_entities.append(entity)
|
|
|
+
|
|
|
+ logger.info(f"分块 {i+1} 提取实体: {len(entities)} 个")
|
|
|
+
|
|
|
+ logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
|
|
|
+ return all_entities
|
|
|
+
|
|
|
+ async def check_health(self) -> bool:
|
|
|
+ """
|
|
|
+ 检查 DeepSeek API 是否可用
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ url = f"{self.base_url}/v1/models"
|
|
|
+ headers = {
|
|
|
+ "Authorization": f"Bearer {self.api_key}"
|
|
|
+ }
|
|
|
+ async with httpx.AsyncClient(timeout=10) as client:
|
|
|
+ response = await client.get(url, headers=headers)
|
|
|
+ return response.status_code == 200
|
|
|
+ except Exception:
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
# Module-level singleton; importers share this one service instance.
deepseek_service = DeepSeekService()
|