| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- """
- DeepSeek API 服务(阿里云百炼平台)
- 用于调用 DeepSeek 模型进行 NER 提取
- """
import asyncio
import json
import re
import uuid
from typing import Any, Dict, List, Optional

import httpx
from loguru import logger

from ..config import settings
from ..models import EntityInfo, PositionInfo
class DeepSeekService:
    """Async client for the DeepSeek chat-completions API (NER extraction)."""

    def __init__(self):
        """Load connection, generation and chunking parameters from settings."""
        # API connection parameters
        self.api_key = settings.deepseek_api_key
        self.base_url = settings.deepseek_base_url
        self.model = settings.deepseek_model
        self.timeout = settings.deepseek_timeout
        # Generation / retry parameters
        self.temperature = settings.deepseek_temperature
        self.max_tokens = settings.deepseek_max_tokens
        self.max_retries = settings.deepseek_max_retries
        # Long-text chunking parameters
        self.chunk_size = settings.chunk_size
        self.chunk_overlap = settings.chunk_overlap

        logger.info(f"初始化 DeepSeek 服务: model={self.model}, base_url={self.base_url}")
-
- def _split_text(self, text: str) -> List[Dict[str, Any]]:
- """
- 将长文本分割成多个块
- """
- if len(text) <= self.chunk_size:
- return [{"text": text, "start_pos": 0, "end_pos": len(text)}]
-
- chunks = []
- start = 0
-
- while start < len(text):
- end = min(start + self.chunk_size, len(text))
-
- # 尝试在句号、换行处分割
- if end < len(text):
- for sep in ['\n\n', '\n', '。', ';', '!', '?', '.']:
- sep_pos = text.rfind(sep, start + self.chunk_size // 2, end)
- if sep_pos > start:
- end = sep_pos + len(sep)
- break
-
- chunk_text = text[start:end]
- chunks.append({
- "text": chunk_text,
- "start_pos": start,
- "end_pos": end
- })
-
- start = end - self.chunk_overlap if end < len(text) else end
-
- logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
- return chunks
-
- def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
- """
- 构建 NER 提取的 Prompt
- """
- types = entity_types or settings.entity_types
- types_desc = ", ".join(types)
-
- example = '{"entities": [{"name": "成都市", "type": "LOC", "charStart": 10, "charEnd": 13}, {"name": "2024年5月", "type": "DATE", "charStart": 0, "charEnd": 7}]}'
-
- prompt = f"""请从以下文本中提取命名实体,直接输出JSON,不要解释。
- 实体类型: {types_desc}
- 输出格式示例:
- {example}
- 文本:
- {text}
- 请直接输出JSON:"""
- return prompt
-
- async def _call_api(self, prompt: str) -> Optional[str]:
- """
- 调用 DeepSeek API(OpenAI 兼容格式)
- """
- url = f"{self.base_url}/v1/chat/completions"
- headers = {
- "Authorization": f"Bearer {self.api_key}",
- "Content-Type": "application/json"
- }
- payload = {
- "model": self.model,
- "messages": [
- {
- "role": "user",
- "content": prompt
- }
- ],
- "temperature": self.temperature,
- "max_tokens": self.max_tokens,
- }
-
- for attempt in range(self.max_retries):
- try:
- async with httpx.AsyncClient(timeout=self.timeout) as client:
- response = await client.post(url, headers=headers, json=payload)
- response.raise_for_status()
- result = response.json()
-
- # OpenAI 格式响应
- choices = result.get("choices", [])
- if choices:
- message = choices[0].get("message", {})
- return message.get("content", "")
- return None
-
- except httpx.TimeoutException:
- logger.warning(f"DeepSeek API 请求超时 (尝试 {attempt + 1}/{self.max_retries})")
- if attempt == self.max_retries - 1:
- logger.error(f"DeepSeek API 请求超时: timeout={self.timeout}s")
- return None
- except httpx.HTTPStatusError as e:
- logger.error(f"DeepSeek API HTTP 错误: {e.response.status_code} - {e.response.text}")
- return None
- except Exception as e:
- logger.error(f"DeepSeek API 请求失败: {e}")
- if attempt == self.max_retries - 1:
- return None
-
- return None
-
- def _parse_response(self, response: str, chunk_start_pos: int = 0) -> List[EntityInfo]:
- """
- 解析 API 返回的 JSON 结果
- """
- entities = []
-
- try:
- # 移除 markdown code block 标记
- response = re.sub(r'```json\s*', '', response)
- response = re.sub(r'```\s*', '', response)
- response = response.strip()
-
- # 方法1:直接解析
- data = None
- try:
- data = json.loads(response)
- except json.JSONDecodeError:
- pass
-
- # 方法2:查找 JSON 对象
- if not data or "entities" not in data:
- json_match = re.search(r'\{\s*"entities"\s*:\s*\[[\s\S]*\]\s*\}', response)
- if json_match:
- try:
- data = json.loads(json_match.group())
- except json.JSONDecodeError:
- pass
-
- if not data or "entities" not in data:
- logger.warning(f"未找到有效的 entities JSON, response={response[:300]}...")
- return entities
-
- entity_list = data.get("entities", [])
-
- for item in entity_list:
- name = item.get("name", "").strip()
- entity_type = item.get("type", "").upper()
- char_start = item.get("charStart", 0)
- char_end = item.get("charEnd", 0)
-
- if not name or len(name) < 2:
- continue
-
- # 校正位置
- adjusted_start = char_start + chunk_start_pos
- adjusted_end = char_end + chunk_start_pos
-
- entity = EntityInfo(
- name=name,
- type=entity_type,
- value=name,
- position=PositionInfo(
- char_start=adjusted_start,
- char_end=adjusted_end,
- line=1
- ),
- confidence=0.95, # DeepSeek 置信度较高
- temp_id=str(uuid.uuid4())[:8]
- )
- entities.append(entity)
-
- except Exception as e:
- logger.error(f"解析响应失败: {e}")
-
- return entities
-
- async def extract_entities(
- self,
- text: str,
- entity_types: Optional[List[str]] = None
- ) -> List[EntityInfo]:
- """
- 使用 DeepSeek API 提取实体
- """
- if not text or not text.strip():
- return []
-
- # 分割长文本
- chunks = self._split_text(text)
-
- all_entities = []
- seen_entities = set()
-
- for i, chunk in enumerate(chunks):
- logger.info(f"处理分块 {i+1}/{len(chunks)}: 长度={len(chunk['text'])}")
-
- prompt = self._build_ner_prompt(chunk["text"], entity_types)
- response = await self._call_api(prompt)
-
- if not response:
- logger.warning(f"分块 {i+1} API 返回为空")
- continue
-
- logger.debug(f"分块 {i+1} API 响应: {response[:500]}...")
-
- entities = self._parse_response(response, chunk["start_pos"])
-
- # 去重
- for entity in entities:
- entity_key = f"{entity.type}:{entity.name}"
- if entity_key not in seen_entities:
- seen_entities.add(entity_key)
- all_entities.append(entity)
-
- logger.info(f"分块 {i+1} 提取实体: {len(entities)} 个")
-
- logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
- return all_entities
-
- async def extract_entities_with_progress(
- self,
- text: str,
- entity_types: Optional[List[str]] = None
- ):
- """
- 使用 DeepSeek API 提取实体(带进度生成器)
-
- Yields:
- SSE 事件字符串
- """
- import json
-
- async def sse_event(event: str, data: dict):
- return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
-
- if not text or not text.strip():
- yield await sse_event("complete", {"entities": [], "total_entities": 0})
- return
-
- # 分割长文本
- chunks = self._split_text(text)
- total_chunks = len(chunks)
-
- all_entities = []
- seen_entities = set()
-
- for i, chunk in enumerate(chunks):
- chunk_index = i + 1
- logger.info(f"处理分块 {chunk_index}/{total_chunks}: 长度={len(chunk['text'])}")
-
- # 发送进度事件
- yield await sse_event("progress", {
- "chunk_index": chunk_index,
- "total_chunks": total_chunks,
- "chunk_length": len(chunk['text']),
- "total_entities": len(all_entities),
- "progress_percent": int((chunk_index - 1) / total_chunks * 100),
- "message": f"正在处理第 {chunk_index}/{total_chunks} 个文本块..."
- })
-
- prompt = self._build_ner_prompt(chunk["text"], entity_types)
- response = await self._call_api(prompt)
-
- if not response:
- logger.warning(f"分块 {chunk_index} API 返回为空")
- continue
-
- logger.debug(f"分块 {chunk_index} API 响应: {response[:500]}...")
-
- entities = self._parse_response(response, chunk["start_pos"])
-
- # 去重并收集新实体
- new_entities = []
- for entity in entities:
- entity_key = f"{entity.type}:{entity.name}"
- if entity_key not in seen_entities:
- seen_entities.add(entity_key)
- all_entities.append(entity)
- new_entities.append(entity)
-
- logger.info(f"分块 {chunk_index} 提取实体: {len(entities)} 个, 新增: {len(new_entities)} 个")
-
- # 发送分块完成事件
- yield await sse_event("chunk_complete", {
- "chunk_index": chunk_index,
- "total_chunks": total_chunks,
- "chunk_entities": len(entities),
- "new_entities": len(new_entities),
- "total_entities": len(all_entities),
- "progress_percent": int(chunk_index / total_chunks * 100)
- })
-
- logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
-
- # 发送实体数据事件(供调用方获取实体列表)
- yield await sse_event("entities_data", {
- "entities": [entity.model_dump(by_alias=True) for entity in all_entities],
- "total_entities": len(all_entities)
- })
-
- async def check_health(self) -> bool:
- """
- 检查 DeepSeek API 是否可用
- """
- try:
- url = f"{self.base_url}/v1/models"
- headers = {
- "Authorization": f"Bearer {self.api_key}"
- }
- async with httpx.AsyncClient(timeout=10) as client:
- response = await client.get(url, headers=headers)
- return response.status_code == 200
- except Exception:
- return False
# Module-level singleton shared by the application
deepseek_service = DeepSeekService()
|