| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
"""
DeepSeek API service (Alibaba Cloud Bailian platform).

Calls a DeepSeek model to perform named-entity recognition (NER) extraction.
"""
- import json
- import re
- import uuid
- import httpx
- from typing import List, Optional, Dict, Any
- from loguru import logger
- from ..config import settings
- from ..models import EntityInfo, PositionInfo
class DeepSeekService:
    """Client for a DeepSeek chat-completions endpoint used for NER extraction.

    Long inputs are split into overlapping chunks, each chunk is sent to the
    model with an NER prompt, the JSON answers are parsed into ``EntityInfo``
    objects and de-duplicated by ``(type, name)`` across chunks.
    """

    # Separators tried, in priority order, when looking for a natural place
    # to end a chunk.
    _CHUNK_SEPARATORS = ('\n\n', '\n', '。', ';', '!', '?', '.')

    # Matches an ``{"entities": [...]}`` object embedded in surrounding text.
    _ENTITIES_JSON_RE = re.compile(r'\{\s*"entities"\s*:\s*\[[\s\S]*\]\s*\}')

    def __init__(self):
        """Read all connection and chunking parameters from ``settings``."""
        self.api_key = settings.deepseek_api_key
        self.base_url = settings.deepseek_base_url
        self.model = settings.deepseek_model
        self.timeout = settings.deepseek_timeout
        self.temperature = settings.deepseek_temperature
        self.max_tokens = settings.deepseek_max_tokens
        self.max_retries = settings.deepseek_max_retries
        self.chunk_size = settings.chunk_size
        self.chunk_overlap = settings.chunk_overlap

        logger.info(f"初始化 DeepSeek 服务: model={self.model}, base_url={self.base_url}")

    def _endpoint(self, path: str) -> str:
        """Join ``base_url`` and ``path`` without producing a double slash."""
        return f"{self.base_url.rstrip('/')}{path}"

    @staticmethod
    def _format_sse(event: str, data: dict) -> str:
        """Serialize one Server-Sent Event frame (non-ASCII text kept as-is)."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

    def _iter_chunks(self, text: str):
        """Yield ``{"text", "start_pos", "end_pos"}`` dicts covering ``text``.

        A text no longer than ``chunk_size`` is yielded as a single chunk.
        Otherwise each chunk is at most ``chunk_size`` characters and, where
        possible, ends just after a separator found in its second half.
        Consecutive chunks overlap by ``chunk_overlap`` characters.

        Raises:
            ValueError: if ``chunk_size`` is not positive and the text would
                need splitting (the loop could never make progress).
        """
        total = len(text)
        if total <= self.chunk_size:
            yield {"text": text, "start_pos": 0, "end_pos": total}
            return
        if self.chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {self.chunk_size}")

        # A negative overlap would silently skip text between chunks.
        overlap = max(self.chunk_overlap, 0)
        start = 0
        while start < total:
            end = min(start + self.chunk_size, total)

            # Prefer ending on a paragraph/sentence boundary, but only look
            # in the second half so chunks do not become too small.
            if end < total:
                for sep in self._CHUNK_SEPARATORS:
                    sep_pos = text.rfind(sep, start + self.chunk_size // 2, end)
                    if sep_pos > start:
                        end = sep_pos + len(sep)
                        break

            yield {"text": text[start:end], "start_pos": start, "end_pos": end}

            if end >= total:
                break
            # Always advance by at least one character: with an overlap that
            # is >= the chunk length, ``end - overlap`` alone would not move
            # forward and the loop would never terminate.
            start = max(end - overlap, start + 1)

    def _split_text(self, text: str) -> List[Dict[str, Any]]:
        """Split ``text`` into overlapping chunks (see ``_iter_chunks``).

        Returns:
            A list of chunk dicts; a single-element list for short texts.
        """
        chunks = list(self._iter_chunks(text))
        # Only log when the text actually had to be split.
        if len(text) > self.chunk_size:
            logger.info(f"文本分割完成: 总长度={len(text)}, 分块数={len(chunks)}")
        return chunks

    def _build_ner_prompt(self, text: str, entity_types: Optional[List[str]] = None) -> str:
        """Build the NER extraction prompt for one chunk of text.

        Args:
            text: The chunk to analyse; embedded verbatim in the prompt.
            entity_types: Entity types to extract. Falls back to
                ``settings.entity_types`` when empty or ``None``.

        Returns:
            The complete prompt. When a non-empty type list is in effect, an
            extra rule restricts extraction to those types; otherwise the
            prompt describes all built-in types without restriction.
        """
        types = entity_types or settings.entity_types

        # Previously the resolved type list was computed and then dropped, so
        # the caller's selection had no effect on the request.
        type_rule = ""
        if types:
            joined = ", ".join(str(t) for t in types)
            type_rule = f"5. **只提取以下类型的实体**:{joined}\n"

        prompt = f"""你是专业的命名实体识别(NER)专家。请从以下文本中提取有意义的命名实体。
## 实体类型定义
1. **PERSON** - 人名
- 包括:姓名、职务+姓名(如"张总"、"李工程师")
- 不包括:职位本身(如"经理"、"主任")
2. **ORG** - 机构/组织
- 包括:公司、政府机关、协会、院校等完整名称
- 例如:"中国电建集团成都勘测设计研究院有限公司"、"国家能源局"
- 不包括:简称(如"公司"、"集团"),除非是特定缩写(如"成都院")
3. **LOC** - 地点/位置
- 包括:省市区县、具体地址、工程位置
- 例如:"四川省成都市"、"龙滩水电站"
4. **DATE** - 日期/时间
- 包括:具体日期、时间段、年份
- 例如:"2024年7月13日"、"2019年版"
- 不包括:单独的年份数字
5. **NUMBER** - 有意义的数值
- 包括:带单位的数量、金额、评分、比例
- 例如:"93.33分"、"7000名"、"70多年"、"5个"
- **不包括**:章节编号(如"1.1"、"第3章")、纯序号(如"16"、"17")、表格序号
6. **PROJECT** - 项目/工程名称
- 包括:具体项目名称、工程名称
- 例如:"白鹤滩水电站工程"、"安全生产标准化建设项目"
- **不包括**:文件编号(如"AY-BZ-0092-2024")、标准编号
7. **METHOD** - 方法/标准/规范
- 包括:技术标准、管理办法、评价标准的完整名称
- 例如:"电力建设企业安全生产标准化评价标准"、"安全生产风险抵押金管理办法"
- **不包括**:编号(如"SQE.01C0213"、"中电建协〔2014〕24号文")
8. **DEVICE** - 设备/系统
- 包括:设备名称、信息系统名称
- 例如:"OA系统"、"QHSE系统"、"造槽机"
## 重要规则
1. **只提取有实际意义的实体**,忽略:
- 章节编号(如"1"、"1.1"、"第一章")
- 表格序号
- 单独的数字(如"16"、"17")
- 文件编号和标准编号(归类为无意义数据,不要提取)
2. **实体边界要准确**,提取完整的名称而非片段
3. **去除重复实体**,相同的实体只返回一次
4. **charStart和charEnd必须准确**,对应实体在原文中的字符位置(从0开始)
{type_rule}## 输出格式
直接输出JSON,格式如下:
{{"entities": [{{"name": "实体名称", "type": "实体类型", "charStart": 起始位置, "charEnd": 结束位置}}]}}
## 待处理文本
{text}
请直接输出JSON:"""
        return prompt

    async def _call_api(self, prompt: str) -> Optional[str]:
        """Send ``prompt`` to the chat-completions endpoint (OpenAI format).

        Timeouts and unexpected errors are retried up to ``max_retries``
        times (at least one attempt is always made). An HTTP error status is
        not retried.

        Returns:
            The content of the first choice's message, or ``None`` when the
            request failed or the response contained no choices.
        """
        url = self._endpoint("/v1/chat/completions")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

        # A configured retry count of 0 (or less) must still send the request
        # once instead of silently returning None.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                async with httpx.AsyncClient(timeout=self.timeout) as client:
                    response = await client.post(url, headers=headers, json=payload)
                    response.raise_for_status()
                    result = response.json()

                    choices = result.get("choices", [])
                    if choices:
                        message = choices[0].get("message", {})
                        return message.get("content", "")
                    return None

            except httpx.TimeoutException:
                logger.warning(f"DeepSeek API 请求超时 (尝试 {attempt + 1}/{attempts})")
                if attempt == attempts - 1:
                    logger.error(f"DeepSeek API 请求超时: timeout={self.timeout}s")
            except httpx.HTTPStatusError as e:
                logger.error(f"DeepSeek API HTTP 错误: {e.response.status_code} - {e.response.text}")
                return None
            except Exception as e:
                # Boundary with the network/JSON layer: log and retry.
                logger.error(f"DeepSeek API 请求失败: {e}")

        return None

    @classmethod
    def _load_entity_items(cls, cleaned: str) -> Optional[List[Any]]:
        """Extract the ``entities`` list from the model's (fence-free) answer.

        Tries the whole string as JSON first, then the first embedded
        ``{"entities": [...]}`` object.

        Returns:
            The list under ``"entities"``, or ``None`` when no JSON object
            with a list-valued ``"entities"`` key could be decoded.
        """
        candidates = [cleaned]
        embedded = cls._ENTITIES_JSON_RE.search(cleaned)
        if embedded:
            candidates.append(embedded.group())

        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            # The model may answer with a bare list/number/string; only a
            # dict with a list under "entities" is usable.
            if isinstance(data, dict) and isinstance(data.get("entities"), list):
                return data["entities"]
        return None

    @staticmethod
    def _coerce_offset(value: Any) -> int:
        """Convert a model-supplied character offset to ``int`` (0 if unusable)."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return 0

    def _parse_response(self, response: str, chunk_start_pos: int = 0) -> "List[EntityInfo]":
        """Parse the model's JSON answer into ``EntityInfo`` objects.

        Args:
            response: Raw message content; may be wrapped in a markdown code
                fence or surrounded by extra text.
            chunk_start_pos: Offset of the chunk within the full text, added
                to the chunk-relative ``charStart`` / ``charEnd`` values.

        Returns:
            The parsed entities. Malformed items (non-dict, missing or
            non-string name, name shorter than two characters) are skipped
            individually. Never raises; returns what was parsed so far on
            unexpected errors.
        """
        entities = []
        if not response:
            return entities

        try:
            # Strip markdown code-fence markers.
            cleaned = re.sub(r'```json\s*', '', response)
            cleaned = re.sub(r'```\s*', '', cleaned).strip()

            items = self._load_entity_items(cleaned)
            if items is None:
                logger.warning(f"未找到有效的 entities JSON, response={cleaned[:300]}...")
                return entities

            for item in items:
                if not isinstance(item, dict):
                    continue
                raw_name = item.get("name")
                if not isinstance(raw_name, str):
                    continue
                name = raw_name.strip()
                if len(name) < 2:
                    continue

                raw_type = item.get("type")
                entity_type = raw_type.upper() if isinstance(raw_type, str) else ""

                # Shift chunk-relative offsets into full-text coordinates.
                adjusted_start = self._coerce_offset(item.get("charStart", 0)) + chunk_start_pos
                adjusted_end = self._coerce_offset(item.get("charEnd", 0)) + chunk_start_pos

                try:
                    entity = EntityInfo(
                        name=name,
                        type=entity_type,
                        value=name,
                        position=PositionInfo(
                            char_start=adjusted_start,
                            char_end=adjusted_end,
                            line=1,
                        ),
                        confidence=0.95,  # fixed confidence assigned to DeepSeek results
                        temp_id=str(uuid.uuid4())[:8],
                    )
                except Exception as e:
                    # One item rejected by the model classes must not discard
                    # the remaining entities of this chunk.
                    logger.warning(f"解析响应失败: {e}")
                    continue
                entities.append(entity)

        except Exception as e:
            logger.error(f"解析响应失败: {e}")

        return entities

    async def _extract_chunk(self, chunk: Dict[str, Any], chunk_index: int,
                             entity_types: Optional[List[str]], seen_entities: set):
        """Run NER on one chunk and de-duplicate against ``seen_entities``.

        ``seen_entities`` is updated in place with the ``(type, name)`` keys
        of newly found entities.

        Returns:
            ``(parsed, new)`` where ``parsed`` is every entity parsed from
            this chunk and ``new`` the subset not seen before, or ``None``
            when the API returned no content.
        """
        prompt = self._build_ner_prompt(chunk["text"], entity_types)
        response = await self._call_api(prompt)

        if not response:
            logger.warning(f"分块 {chunk_index} API 返回为空")
            return None

        logger.debug(f"分块 {chunk_index} API 响应: {response[:500]}...")

        parsed = self._parse_response(response, chunk["start_pos"])

        fresh = []
        for entity in parsed:
            key = (entity.type, entity.name)
            if key not in seen_entities:
                seen_entities.add(key)
                fresh.append(entity)
        return parsed, fresh

    async def extract_entities(
        self,
        text: str,
        entity_types: Optional[List[str]] = None
    ) -> "List[EntityInfo]":
        """Extract entities from ``text`` via the DeepSeek API.

        Args:
            text: Full input text; blank input yields an empty list.
            entity_types: Optional list of entity types to extract.

        Returns:
            Entities from all chunks, de-duplicated by ``(type, name)``;
            chunks whose API call failed contribute nothing.
        """
        if not text or not text.strip():
            return []

        chunks = self._split_text(text)

        all_entities = []
        seen_entities = set()

        for chunk_index, chunk in enumerate(chunks, start=1):
            logger.info(f"处理分块 {chunk_index}/{len(chunks)}: 长度={len(chunk['text'])}")

            outcome = await self._extract_chunk(chunk, chunk_index, entity_types, seen_entities)
            if outcome is None:
                continue

            parsed, fresh = outcome
            all_entities.extend(fresh)
            logger.info(f"分块 {chunk_index} 提取实体: {len(parsed)} 个")

        logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
        return all_entities

    async def extract_entities_with_progress(
        self,
        text: str,
        entity_types: Optional[List[str]] = None
    ):
        """Extract entities while reporting progress as SSE frames.

        Yields:
            SSE event strings: ``progress`` before each chunk,
            ``chunk_complete`` after each successfully processed chunk (none
            is emitted for a chunk whose API call returned nothing), and a
            final ``entities_data`` carrying the de-duplicated entities.
            Blank input yields a single ``complete`` event instead.
        """
        if not text or not text.strip():
            yield self._format_sse("complete", {"entities": [], "total_entities": 0})
            return

        chunks = self._split_text(text)
        total_chunks = len(chunks)

        all_entities = []
        seen_entities = set()

        for chunk_index, chunk in enumerate(chunks, start=1):
            logger.info(f"处理分块 {chunk_index}/{total_chunks}: 长度={len(chunk['text'])}")

            # Progress reflects the chunks finished *before* this one.
            yield self._format_sse("progress", {
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "chunk_length": len(chunk['text']),
                "total_entities": len(all_entities),
                "progress_percent": int((chunk_index - 1) / total_chunks * 100),
                "message": f"正在处理第 {chunk_index}/{total_chunks} 个文本块..."
            })

            outcome = await self._extract_chunk(chunk, chunk_index, entity_types, seen_entities)
            if outcome is None:
                continue

            parsed, fresh = outcome
            all_entities.extend(fresh)

            logger.info(f"分块 {chunk_index} 提取实体: {len(parsed)} 个, 新增: {len(fresh)} 个")

            yield self._format_sse("chunk_complete", {
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "chunk_entities": len(parsed),
                "new_entities": len(fresh),
                "total_entities": len(all_entities),
                "progress_percent": int(chunk_index / total_chunks * 100)
            })

        logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")

        # Final payload so the consumer can obtain the entity list itself.
        yield self._format_sse("entities_data", {
            "entities": [entity.model_dump(by_alias=True) for entity in all_entities],
            "total_entities": len(all_entities)
        })

    async def check_health(self) -> bool:
        """Return ``True`` when ``GET /v1/models`` answers with HTTP 200.

        Any error (network failure, timeout, ...) is reported as ``False``.
        """
        try:
            url = self._endpoint("/v1/models")
            headers = {
                "Authorization": f"Bearer {self.api_key}"
            }
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(url, headers=headers)
                return response.status_code == 200
        except Exception as e:
            # Best-effort probe: never raise, but leave a trace for debugging.
            logger.debug(f"DeepSeek 健康检查失败: {e}")
            return False
# Module-level singleton instance shared by importers of this module.
deepseek_service = DeepSeekService()
|