Quellcode durchsuchen

fix: 修复 Java 端 SSE 事件解析逻辑

- 修复 bodyToFlux 无法正确按 SSE 事件分割的问题
- 添加缓冲区累积数据并按双换行符分割事件
- 提取 parseSseEvent 方法单独处理每个事件
- 根据事件类型(progress/complete/error)分别处理
- 增加内存缓冲区大小至 10MB
何文松 vor 1 Monat
Ursprung
Commit
eaf2241957

+ 71 - 29
backend/graph-service/src/main/java/com/lingyue/graph/listener/DocumentParsedEventListener.java

@@ -160,51 +160,47 @@ public class DocumentParsedEventListener {
             // 使用 WebClient 处理 SSE
             WebClient webClient = WebClient.builder()
                     .baseUrl(nerServiceUrl)
+                    .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024))  // 10MB
                     .build();
             
             AtomicReference<Map<String, Object>> resultRef = new AtomicReference<>();
             
-            // 订阅 SSE 事件流
+            // 订阅 SSE 事件流,使用 DataBuffer 逐块读取
+            StringBuilder buffer = new StringBuilder();
+            
             webClient.post()
                     .uri("/ner/extract/stream")
                     .contentType(MediaType.APPLICATION_JSON)
+                    .accept(MediaType.TEXT_EVENT_STREAM)
                     .bodyValue(request)
                     .retrieve()
                     .bodyToFlux(String.class)
                     .timeout(Duration.ofMinutes(10))  // 10 分钟总超时
-                    .doOnNext(data -> {
-                        try {
-                            // 解析 SSE 数据
-                            if (data.startsWith("event:")) {
-                                // 跳过事件类型行
-                                return;
-                            }
-                            if (data.startsWith("data:")) {
-                                String jsonData = data.substring(5).trim();
-                                @SuppressWarnings("unchecked")
-                                Map<String, Object> eventData = objectMapper.readValue(jsonData, Map.class);
-                                
-                                // 检查是否包含进度信息
-                                if (eventData.containsKey("progress_percent")) {
-                                    int progress = (Integer) eventData.get("progress_percent");
-                                    String message = (String) eventData.getOrDefault("message", "处理中...");
-                                    log.info("NER 进度: documentId={}, progress={}%, message={}", 
-                                            documentId, progress, message);
-                                }
-                                
-                                // 检查是否是完成事件(包含 entities)
-                                if (eventData.containsKey("entities")) {
-                                    resultRef.set(eventData);
-                                    log.info("NER 流式处理完成: documentId={}", documentId);
-                                }
-                            }
-                        } catch (Exception e) {
-                            log.debug("解析 SSE 数据时出错: {}", e.getMessage());
+                    .doOnNext(chunk -> {
+                        // 累积数据到缓冲区
+                        buffer.append(chunk);
+                        
+                        // 处理完整的 SSE 事件(以双换行符分隔)
+                        String bufferStr = buffer.toString();
+                        int eventEnd;
+                        while ((eventEnd = bufferStr.indexOf("\n\n")) != -1) {
+                            String eventBlock = bufferStr.substring(0, eventEnd);
+                            bufferStr = bufferStr.substring(eventEnd + 2);
+                            buffer.setLength(0);
+                            buffer.append(bufferStr);
+                            
+                            // 解析单个 SSE 事件
+                            parseSseEvent(eventBlock, documentId, resultRef);
                         }
                     })
                     .doOnError(e -> log.error("SSE 流处理错误: documentId={}, error={}", documentId, e.getMessage()))
                     .blockLast();  // 阻塞等待完成
             
+            // 处理缓冲区中剩余的数据
+            if (buffer.length() > 0) {
+                parseSseEvent(buffer.toString(), documentId, resultRef);
+            }
+            
             return resultRef.get();
             
         } catch (Exception e) {
@@ -214,4 +210,50 @@ public class DocumentParsedEventListener {
             return callPythonNerService(documentId, text, userId);
         }
     }
+    
+    /**
+     * 解析单个 SSE 事件
+     */
+    private void parseSseEvent(String eventBlock, String documentId, AtomicReference<Map<String, Object>> resultRef) {
+        try {
+            String eventType = null;
+            String eventData = null;
+            
+            // 解析事件块
+            for (String line : eventBlock.split("\n")) {
+                if (line.startsWith("event:")) {
+                    eventType = line.substring(6).trim();
+                } else if (line.startsWith("data:")) {
+                    eventData = line.substring(5).trim();
+                }
+            }
+            
+            if (eventData == null || eventData.isEmpty()) {
+                return;
+            }
+            
+            @SuppressWarnings("unchecked")
+            Map<String, Object> data = objectMapper.readValue(eventData, Map.class);
+            
+            // 根据事件类型处理
+            if ("progress".equals(eventType) || "chunk_complete".equals(eventType)) {
+                Object progressObj = data.get("progress_percent");
+                if (progressObj != null) {
+                    int progress = progressObj instanceof Integer ? (Integer) progressObj : ((Number) progressObj).intValue();
+                    String message = (String) data.getOrDefault("message", "处理中...");
+                    log.info("NER 进度: documentId={}, progress={}%, message={}", documentId, progress, message);
+                }
+            } else if ("complete".equals(eventType)) {
+                // 完成事件,包含最终结果
+                resultRef.set(data);
+                log.info("NER 流式处理完成: documentId={}, entities={}", documentId, 
+                        data.get("entities") != null ? ((List<?>) data.get("entities")).size() : 0);
+            } else if ("error".equals(eventType)) {
+                log.error("NER 服务返回错误: documentId={}, error={}", documentId, data.get("error"));
+            }
+            
+        } catch (Exception e) {
+            log.debug("解析 SSE 事件时出错: {}", e.getMessage());
+        }
+    }
 }