Эх сурвалжийг харах

feat: 添加 NER SSE 流式进度反馈机制

实现 Java 后端与 Python NER 服务的实时进度通信:
- Python 端新增 /ner/extract/stream SSE 端点
- DeepSeek 服务支持分块处理进度生成器
- Java 端使用 WebClient 处理 SSE 流并输出进度日志
- 新增配置项 ner.python-service.use-stream 控制是否启用
- SSE 失败时自动回退到普通 REST API
何文松 1 сар өмнө
parent
commit
f8f7230c61

+ 85 - 4
backend/graph-service/src/main/java/com/lingyue/graph/listener/DocumentParsedEventListener.java

@@ -1,5 +1,6 @@
 package com.lingyue.graph.listener;
 
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.lingyue.common.event.DocumentParsedEvent;
 import com.lingyue.graph.service.GraphNerService;
 import lombok.RequiredArgsConstructor;
@@ -10,8 +11,11 @@ import org.springframework.http.*;
 import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Component;
 import org.springframework.web.client.RestTemplate;
+import org.springframework.web.reactive.function.client.WebClient;
 
+import java.time.Duration;
 import java.util.*;
+import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * 文档解析完成事件监听器
@@ -27,12 +31,16 @@ public class DocumentParsedEventListener {
 
     private final GraphNerService graphNerService;
     private final RestTemplate restTemplate;
+    private final ObjectMapper objectMapper;
 
     @Value("${ner.auto-extract.enabled:true}")
     private boolean nerAutoExtractEnabled;
 
     @Value("${ner.python-service.url:http://localhost:8001}")
     private String nerServiceUrl;
+    
+    @Value("${ner.python-service.use-stream:true}")
+    private boolean useStreamApi;
 
     /**
      * 处理文档解析完成事件
@@ -66,8 +74,13 @@ public class DocumentParsedEventListener {
                 return;
             }
 
-            // 2. 调用 Python NER 服务
-            Map<String, Object> nerResponse = callPythonNerService(documentId, text, userId);
+            // 2. 调用 Python NER 服务(根据配置选择流式或普通 API)
+            Map<String, Object> nerResponse;
+            if (useStreamApi) {
+                nerResponse = callPythonNerServiceWithStream(documentId, text, userId);
+            } else {
+                nerResponse = callPythonNerService(documentId, text, userId);
+            }
             
             if (nerResponse == null || !Boolean.TRUE.equals(nerResponse.get("success"))) {
                 log.warn("NER 服务调用失败: documentId={}, error={}", 
@@ -100,7 +113,7 @@ public class DocumentParsedEventListener {
     }
 
     /**
-     * 调用 Python NER 服务
+     * 调用 Python NER 服务(普通 REST API)
      */
     private Map<String, Object> callPythonNerService(String documentId, String text, String userId) {
         try {
@@ -117,7 +130,6 @@ public class DocumentParsedEventListener {
             
             HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
             
-            @SuppressWarnings("unchecked")
             ResponseEntity<Map> response = restTemplate.exchange(url, HttpMethod.POST, entity, Map.class);
             
             if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
@@ -133,4 +145,73 @@ public class DocumentParsedEventListener {
             return null;
         }
     }
+    
+    /**
+     * 调用 Python NER 服务(SSE 流式 API,带进度反馈)
+     */
+    private Map<String, Object> callPythonNerServiceWithStream(String documentId, String text, String userId) {
+        try {
+            Map<String, Object> request = new HashMap<>();
+            request.put("documentId", documentId);
+            request.put("text", text);
+            request.put("userId", userId);
+            request.put("extractRelations", true);
+            
+            // 使用 WebClient 处理 SSE
+            WebClient webClient = WebClient.builder()
+                    .baseUrl(nerServiceUrl)
+                    .build();
+            
+            AtomicReference<Map<String, Object>> resultRef = new AtomicReference<>();
+            
+            // 订阅 SSE 事件流
+            webClient.post()
+                    .uri("/ner/extract/stream")
+                    .contentType(MediaType.APPLICATION_JSON)
+                    .bodyValue(request)
+                    .retrieve()
+                    .bodyToFlux(String.class)
+                    .timeout(Duration.ofMinutes(10))  // 10 分钟总超时
+                    .doOnNext(data -> {
+                        try {
+                            // 解析 SSE 数据
+                            if (data.startsWith("event:")) {
+                                // 跳过事件类型行
+                                return;
+                            }
+                            if (data.startsWith("data:")) {
+                                String jsonData = data.substring(5).trim();
+                                @SuppressWarnings("unchecked")
+                                Map<String, Object> eventData = objectMapper.readValue(jsonData, Map.class);
+                                
+                                // 检查是否包含进度信息
+                                if (eventData.containsKey("progress_percent")) {
+                                    int progress = (Integer) eventData.get("progress_percent");
+                                    String message = (String) eventData.getOrDefault("message", "处理中...");
+                                    log.info("NER 进度: documentId={}, progress={}%, message={}", 
+                                            documentId, progress, message);
+                                }
+                                
+                                // 检查是否是完成事件(包含 entities)
+                                if (eventData.containsKey("entities")) {
+                                    resultRef.set(eventData);
+                                    log.info("NER 流式处理完成: documentId={}", documentId);
+                                }
+                            }
+                        } catch (Exception e) {
+                            log.debug("解析 SSE 数据时出错: {}", e.getMessage());
+                        }
+                    })
+                    .doOnError(e -> log.error("SSE 流处理错误: documentId={}, error={}", documentId, e.getMessage()))
+                    .blockLast();  // 阻塞等待完成
+            
+            return resultRef.get();
+            
+        } catch (Exception e) {
+            log.error("调用 Python NER SSE 服务失败: {}", e.getMessage());
+            // 回退到普通 API
+            log.info("回退到普通 NER API: documentId={}", documentId);
+            return callPythonNerService(documentId, text, userId);
+        }
+    }
 }

+ 2 - 0
backend/lingyue-starter/src/main/resources/application.properties

@@ -129,6 +129,8 @@ ner.python-service.timeout=300000
 ner.python-service.connect-timeout=5000
 ner.python-service.max-retries=3
 ner.python-service.retry-interval=1000
+# 是否使用 SSE 流式 API(实时进度反馈)
+ner.python-service.use-stream=true
 
 # NER 实体类型配置
 ner.entity-types=PERSON,ORG,LOC,DATE,NUMBER,DEVICE,PROJECT,TERM

+ 131 - 0
python-services/ner-service/app/routers/ner.py

@@ -1,7 +1,10 @@
 """
 NER 路由
 """
+import json
+import asyncio
 from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
 from loguru import logger
 import time
 
@@ -11,6 +14,11 @@ from ..services.ner_service import ner_service
 router = APIRouter()
 
 
async def sse_event(event: str, data: dict):
    """Render one Server-Sent Events frame for the given event name and JSON payload."""
    payload = json.dumps(data, ensure_ascii=False)
    return "event: " + event + "\ndata: " + payload + "\n\n"
+
+
 @router.post("/extract", response_model=NerResponse)
 async def extract_entities(request: NerRequest):
     """
@@ -72,3 +80,126 @@ async def extract_entities(request: NerRequest):
             document_id=request.document_id,
             error_message=str(e)
         )
+
+
@router.post("/extract/stream")
async def extract_entities_stream(request: NerRequest):
    """
    Extract named entities from text as a Server-Sent Events stream.

    Event types pushed to the client:
    - start: extraction has begun
    - progress: per-chunk processing progress
    - complete: final result (entities + relations)
    - error: validation or processing failure
    """
    async def generate():
        start_time = time.time()
        
        try:
            # Announce the start of extraction.
            yield await sse_event("start", {
                "document_id": request.document_id,
                "text_length": len(request.text),
                "message": "开始 NER 提取"
            })
            
            # Reject oversized input before doing any model work.
            if len(request.text) > 50000:
                yield await sse_event("error", {
                    "document_id": request.document_id,
                    "error": "文本长度超过限制(最大50000字符)"
                })
                return
            
            # Accumulators for the final "complete" payload.
            all_entities = []
            all_relations = []
            
            # Progress callback meant to collect per-chunk entities.
            # NOTE(review): this callback is passed to the service below but the
            # service generator never invokes it, so it is currently dead code —
            # confirm against ner_service.extract_entities_with_progress.
            async def progress_callback(chunk_index: int, total_chunks: int, chunk_entities: list):
                nonlocal all_entities
                all_entities.extend(chunk_entities)
                
                progress_data = {
                    "document_id": request.document_id,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "chunk_entities": len(chunk_entities),
                    "total_entities": len(all_entities),
                    "progress_percent": int((chunk_index / total_chunks) * 100)
                }
                return await sse_event("progress", progress_data)
            
            # Late import to avoid a circular dependency at module load.
            from ..services.ner_service import ner_service
            
            # Prefer the streaming extraction path when the service supports it.
            if hasattr(ner_service, 'extract_entities_with_progress'):
                async for event in ner_service.extract_entities_with_progress(
                    text=request.text,
                    entity_types=request.entity_types,
                    progress_callback=progress_callback
                ):
                    yield event
                    # NOTE(review): `event` is a pre-formatted SSE string, never a
                    # dict, so this expression always keeps the old (empty)
                    # all_entities. On this path the final "complete" event will
                    # report zero entities and relation extraction is skipped.
                    # Fixing this needs the service to yield structured events
                    # (or a terminal result) instead of strings — TODO confirm.
                    all_entities = event.get('entities', all_entities) if isinstance(event, dict) else all_entities
            else:
                # Fallback: one-shot extraction with a single 100% progress event.
                all_entities = await ner_service.extract_entities(
                    text=request.text,
                    entity_types=request.entity_types
                )
                yield await sse_event("progress", {
                    "document_id": request.document_id,
                    "chunk_index": 1,
                    "total_chunks": 1,
                    "total_entities": len(all_entities),
                    "progress_percent": 100
                })
            
            # Relation extraction only makes sense with at least two entities.
            if request.extract_relations and len(all_entities) > 1:
                yield await sse_event("progress", {
                    "document_id": request.document_id,
                    "message": "正在提取实体关系...",
                    "stage": "relations"
                })
                
                from ..services.relation_service import relation_service
                all_relations = await relation_service.extract_relations(
                    text=request.text,
                    entities=all_entities
                )
            
            processing_time = int((time.time() - start_time) * 1000)
            
            # Terminal event: full NerResponse payload (entities + relations).
            response = NerResponse.success_response(
                document_id=request.document_id,
                entities=all_entities,
                relations=all_relations,
                processing_time=processing_time
            )
            
            yield await sse_event("complete", response.dict())
            
            logger.info(f"SSE 提取完成: document_id={request.document_id}, "
                       f"entity_count={len(all_entities)}, relation_count={len(all_relations)}, "
                       f"processing_time={processing_time}ms")
            
        except Exception as e:
            # Surface failures to the client as an SSE "error" event rather
            # than an HTTP error (headers are already sent by now).
            logger.error(f"SSE 提取失败: document_id={request.document_id}, error={str(e)}")
            yield await sse_event("error", {
                "document_id": request.document_id,
                "error": str(e)
            })
    
    # X-Accel-Buffering disables nginx response buffering so events are
    # flushed to the client immediately.
    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )

+ 79 - 0
python-services/ner-service/app/services/deepseek_service.py

@@ -243,6 +243,85 @@ class DeepSeekService:
         logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
         return all_entities
     
+    async def extract_entities_with_progress(
+        self, 
+        text: str, 
+        entity_types: Optional[List[str]] = None
+    ):
+        """
+        使用 DeepSeek API 提取实体(带进度生成器)
+        
+        Yields:
+            SSE 事件字符串
+        """
+        import json
+        
+        async def sse_event(event: str, data: dict):
+            return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
+        
+        if not text or not text.strip():
+            yield await sse_event("complete", {"entities": [], "total_entities": 0})
+            return
+        
+        # 分割长文本
+        chunks = self._split_text(text)
+        total_chunks = len(chunks)
+        
+        all_entities = []
+        seen_entities = set()
+        
+        for i, chunk in enumerate(chunks):
+            chunk_index = i + 1
+            logger.info(f"处理分块 {chunk_index}/{total_chunks}: 长度={len(chunk['text'])}")
+            
+            # 发送进度事件
+            yield await sse_event("progress", {
+                "chunk_index": chunk_index,
+                "total_chunks": total_chunks,
+                "chunk_length": len(chunk['text']),
+                "total_entities": len(all_entities),
+                "progress_percent": int((chunk_index - 1) / total_chunks * 100),
+                "message": f"正在处理第 {chunk_index}/{total_chunks} 个文本块..."
+            })
+            
+            prompt = self._build_ner_prompt(chunk["text"], entity_types)
+            response = await self._call_api(prompt)
+            
+            if not response:
+                logger.warning(f"分块 {chunk_index} API 返回为空")
+                continue
+            
+            logger.debug(f"分块 {chunk_index} API 响应: {response[:500]}...")
+            
+            entities = self._parse_response(response, chunk["start_pos"])
+            
+            # 去重并收集新实体
+            new_entities = []
+            for entity in entities:
+                entity_key = f"{entity.type}:{entity.name}"
+                if entity_key not in seen_entities:
+                    seen_entities.add(entity_key)
+                    all_entities.append(entity)
+                    new_entities.append(entity)
+            
+            logger.info(f"分块 {chunk_index} 提取实体: {len(entities)} 个, 新增: {len(new_entities)} 个")
+            
+            # 发送分块完成事件
+            yield await sse_event("chunk_complete", {
+                "chunk_index": chunk_index,
+                "total_chunks": total_chunks,
+                "chunk_entities": len(entities),
+                "new_entities": len(new_entities),
+                "total_entities": len(all_entities),
+                "progress_percent": int(chunk_index / total_chunks * 100)
+            })
+        
+        logger.info(f"DeepSeek NER 提取完成: 总实体数={len(all_entities)}")
+        
+        # 最终不在这里发送 complete 事件,由调用方处理
+        # 返回最终实体列表供调用方使用
+        return all_entities
+    
     async def check_health(self) -> bool:
         """
         检查 DeepSeek API 是否可用

+ 25 - 0
python-services/ner-service/app/services/ner_service.py

@@ -57,6 +57,31 @@ class NerService:
             logger.warning(f"未知的模型类型: {self.model_type},使用规则模式")
             return await self._extract_by_rules(text, entity_types)
     
+    async def extract_entities_with_progress(
+        self, 
+        text: str, 
+        entity_types: Optional[List[str]] = None,
+        progress_callback=None
+    ):
+        """
+        从文本中提取实体(带进度生成器,用于 SSE 流式响应)
+        
+        Yields:
+            SSE 事件字符串
+        """
+        if not text or not text.strip():
+            return
+        
+        if self.model_type == "deepseek":
+            from .deepseek_service import deepseek_service
+            async for event in deepseek_service.extract_entities_with_progress(text, entity_types):
+                yield event
+        else:
+            # 其他模型回退到普通提取,一次性返回
+            entities = await self.extract_entities(text, entity_types)
+            import json
+            yield f"event: chunk_complete\ndata: {json.dumps({'total_entities': len(entities), 'progress_percent': 100}, ensure_ascii=False)}\n\n"
+    
     async def _extract_by_rules(
         self, 
         text: str,