|
|
@@ -1,6 +1,5 @@
|
|
|
package com.lingyue.graph.listener;
|
|
|
|
|
|
-import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import com.lingyue.common.event.DocumentParsedEvent;
|
|
|
import com.lingyue.graph.service.GraphNerService;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
@@ -11,11 +10,8 @@ import org.springframework.http.*;
|
|
|
import org.springframework.scheduling.annotation.Async;
|
|
|
import org.springframework.stereotype.Component;
|
|
|
import org.springframework.web.client.RestTemplate;
|
|
|
-import org.springframework.web.reactive.function.client.WebClient;
|
|
|
|
|
|
-import java.time.Duration;
|
|
|
import java.util.*;
|
|
|
-import java.util.concurrent.atomic.AtomicReference;
|
|
|
|
|
|
/**
|
|
|
* 文档解析完成事件监听器
|
|
|
@@ -31,7 +27,6 @@ public class DocumentParsedEventListener {
|
|
|
|
|
|
private final GraphNerService graphNerService;
|
|
|
private final RestTemplate restTemplate;
|
|
|
- private final ObjectMapper objectMapper;
|
|
|
|
|
|
@Value("${ner.auto-extract.enabled:true}")
|
|
|
private boolean nerAutoExtractEnabled;
|
|
|
@@ -39,8 +34,14 @@ public class DocumentParsedEventListener {
|
|
|
@Value("${ner.python-service.url:http://localhost:8001}")
|
|
|
private String nerServiceUrl;
|
|
|
|
|
|
- @Value("${ner.python-service.use-stream:true}")
|
|
|
- private boolean useStreamApi;
|
|
|
+ @Value("${ner.python-service.use-async:true}")
|
|
|
+ private boolean useAsyncApi;
|
|
|
+
|
|
|
+ @Value("${ner.python-service.poll-interval:3000}")
|
|
|
+ private long pollInterval; // 轮询间隔(毫秒)
|
|
|
+
|
|
|
+ @Value("${ner.python-service.max-wait-time:600000}")
|
|
|
+ private long maxWaitTime; // 最大等待时间(毫秒)
|
|
|
|
|
|
/**
|
|
|
* 处理文档解析完成事件
|
|
|
@@ -74,10 +75,10 @@ public class DocumentParsedEventListener {
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- // 2. 调用 Python NER 服务(根据配置选择流式或普通 API)
|
|
|
+ // 2. 调用 Python NER 服务(根据配置选择异步轮询或同步 API)
|
|
|
Map<String, Object> nerResponse;
|
|
|
- if (useStreamApi) {
|
|
|
- nerResponse = callPythonNerServiceWithStream(documentId, text, userId);
|
|
|
+ if (useAsyncApi) {
|
|
|
+ nerResponse = callPythonNerServiceAsync(documentId, text, userId);
|
|
|
} else {
|
|
|
nerResponse = callPythonNerService(documentId, text, userId);
|
|
|
}
|
|
|
@@ -113,7 +114,7 @@ public class DocumentParsedEventListener {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 调用 Python NER 服务(普通 REST API)
|
|
|
+ * 调用 Python NER 服务(同步 REST API)
|
|
|
*/
|
|
|
private Map<String, Object> callPythonNerService(String documentId, String text, String userId) {
|
|
|
try {
|
|
|
@@ -147,146 +148,114 @@ public class DocumentParsedEventListener {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 调用 Python NER 服务(SSE 流式 API,带进度反馈)
|
|
|
+ * 调用 Python NER 服务(异步 + 轮询模式)
|
|
|
+ *
|
|
|
+ * 流程:
|
|
|
+ * 1. 提交异步任务,立即获得 task_id
|
|
|
+ * 2. 定期轮询任务状态
|
|
|
+ * 3. 任务完成后获取结果
|
|
|
*/
|
|
|
- private Map<String, Object> callPythonNerServiceWithStream(String documentId, String text, String userId) {
|
|
|
+ private Map<String, Object> callPythonNerServiceAsync(String documentId, String text, String userId) {
|
|
|
try {
|
|
|
+ // 1. 提交异步任务
|
|
|
+ String submitUrl = nerServiceUrl + "/ner/extract/async";
|
|
|
+
|
|
|
Map<String, Object> request = new HashMap<>();
|
|
|
request.put("documentId", documentId);
|
|
|
request.put("text", text);
|
|
|
request.put("userId", userId);
|
|
|
request.put("extractRelations", true);
|
|
|
|
|
|
- // 使用 WebClient 处理 SSE
|
|
|
- WebClient webClient = WebClient.builder()
|
|
|
- .baseUrl(nerServiceUrl)
|
|
|
- .codecs(configurer -> configurer.defaultCodecs().maxInMemorySize(10 * 1024 * 1024)) // 10MB
|
|
|
- .build();
|
|
|
-
|
|
|
- AtomicReference<Map<String, Object>> resultRef = new AtomicReference<>();
|
|
|
- StringBuilder buffer = new StringBuilder();
|
|
|
-
|
|
|
- log.info("开始 SSE 流式 NER 请求: documentId={}", documentId);
|
|
|
+ HttpHeaders headers = new HttpHeaders();
|
|
|
+ headers.setContentType(MediaType.APPLICATION_JSON);
|
|
|
|
|
|
- webClient.post()
|
|
|
- .uri("/ner/extract/stream")
|
|
|
- .contentType(MediaType.APPLICATION_JSON)
|
|
|
- .accept(MediaType.TEXT_EVENT_STREAM)
|
|
|
- .bodyValue(request)
|
|
|
- .retrieve()
|
|
|
- .bodyToFlux(org.springframework.core.io.buffer.DataBuffer.class)
|
|
|
- .timeout(Duration.ofMinutes(10)) // 10 分钟总超时
|
|
|
- .doOnNext(dataBuffer -> {
|
|
|
- // 从 DataBuffer 读取字符串
|
|
|
- byte[] bytes = new byte[dataBuffer.readableByteCount()];
|
|
|
- dataBuffer.read(bytes);
|
|
|
- org.springframework.core.io.buffer.DataBufferUtils.release(dataBuffer);
|
|
|
- String chunk = new String(bytes, java.nio.charset.StandardCharsets.UTF_8);
|
|
|
-
|
|
|
- log.debug("收到 SSE 数据块: length={}", chunk.length());
|
|
|
-
|
|
|
- // 累积数据到缓冲区
|
|
|
- buffer.append(chunk);
|
|
|
-
|
|
|
- // 处理完整的 SSE 事件(以双换行符分隔)
|
|
|
- String bufferStr = buffer.toString();
|
|
|
- int eventEnd;
|
|
|
- while ((eventEnd = bufferStr.indexOf("\n\n")) != -1) {
|
|
|
- String eventBlock = bufferStr.substring(0, eventEnd);
|
|
|
- bufferStr = bufferStr.substring(eventEnd + 2);
|
|
|
- buffer.setLength(0);
|
|
|
- buffer.append(bufferStr);
|
|
|
-
|
|
|
- // 解析单个 SSE 事件
|
|
|
- parseSseEvent(eventBlock, documentId, resultRef);
|
|
|
- }
|
|
|
- })
|
|
|
- .doOnComplete(() -> log.info("SSE 流完成: documentId={}", documentId))
|
|
|
- .doOnError(e -> log.error("SSE 流处理错误: documentId={}, error={}", documentId, e.getMessage(), e))
|
|
|
- .blockLast(); // 阻塞等待完成
|
|
|
+ HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
|
|
|
|
|
|
- // 处理缓冲区中剩余的数据
|
|
|
- if (buffer.length() > 0) {
|
|
|
- log.debug("处理剩余缓冲区数据: length={}", buffer.length());
|
|
|
- parseSseEvent(buffer.toString(), documentId, resultRef);
|
|
|
- }
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ ResponseEntity<Map<String, Object>> submitResponse = restTemplate.exchange(
|
|
|
+ submitUrl, HttpMethod.POST, entity,
|
|
|
+ (Class<Map<String, Object>>) (Class<?>) Map.class);
|
|
|
|
|
|
- Map<String, Object> result = resultRef.get();
|
|
|
- if (result == null) {
|
|
|
- log.warn("SSE 处理完成但未获取到结果,回退到普通 API: documentId={}", documentId);
|
|
|
- return callPythonNerService(documentId, text, userId);
|
|
|
+ if (!submitResponse.getStatusCode().is2xxSuccessful() || submitResponse.getBody() == null) {
|
|
|
+ log.error("提交异步 NER 任务失败: documentId={}", documentId);
|
|
|
+ return null;
|
|
|
}
|
|
|
|
|
|
- return result;
|
|
|
-
|
|
|
- } catch (Exception e) {
|
|
|
- log.error("调用 Python NER SSE 服务失败: documentId={}, error={}", documentId, e.getMessage(), e);
|
|
|
- // 回退到普通 API
|
|
|
- log.info("回退到普通 NER API: documentId={}", documentId);
|
|
|
- return callPythonNerService(documentId, text, userId);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 解析单个 SSE 事件
|
|
|
- */
|
|
|
- private void parseSseEvent(String eventBlock, String documentId, AtomicReference<Map<String, Object>> resultRef) {
|
|
|
- try {
|
|
|
- if (eventBlock == null || eventBlock.trim().isEmpty()) {
|
|
|
- return;
|
|
|
- }
|
|
|
+ String taskId = (String) submitResponse.getBody().get("task_id");
|
|
|
+ log.info("异步 NER 任务已提交: documentId={}, taskId={}", documentId, taskId);
|
|
|
|
|
|
- String eventType = null;
|
|
|
- String eventData = null;
|
|
|
+ // 2. 轮询任务状态
|
|
|
+ String statusUrl = nerServiceUrl + "/ner/task/" + taskId;
|
|
|
+ long startTime = System.currentTimeMillis();
|
|
|
+ int lastProgress = -1;
|
|
|
|
|
|
- // 解析事件块
|
|
|
- for (String line : eventBlock.split("\n")) {
|
|
|
- line = line.trim();
|
|
|
- if (line.startsWith("event:")) {
|
|
|
- eventType = line.substring(6).trim();
|
|
|
- } else if (line.startsWith("data:")) {
|
|
|
- eventData = line.substring(5).trim();
|
|
|
+ while (System.currentTimeMillis() - startTime < maxWaitTime) {
|
|
|
+ try {
|
|
|
+ Thread.sleep(pollInterval);
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ Thread.currentThread().interrupt();
|
|
|
+ log.warn("轮询被中断: taskId={}", taskId);
|
|
|
+ break;
|
|
|
}
|
|
|
- }
|
|
|
-
|
|
|
- log.debug("解析 SSE 事件: type={}, dataLength={}", eventType, eventData != null ? eventData.length() : 0);
|
|
|
-
|
|
|
- if (eventData == null || eventData.isEmpty()) {
|
|
|
- log.debug("SSE 事件数据为空,跳过: type={}", eventType);
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- Map<String, Object> data = objectMapper.readValue(eventData, Map.class);
|
|
|
-
|
|
|
- // 根据事件类型处理
|
|
|
- if ("start".equals(eventType)) {
|
|
|
- log.info("NER 开始处理: documentId={}", documentId);
|
|
|
- } else if ("progress".equals(eventType) || "chunk_complete".equals(eventType)) {
|
|
|
- Object progressObj = data.get("progress_percent");
|
|
|
- if (progressObj != null) {
|
|
|
- int progress = progressObj instanceof Integer ? (Integer) progressObj : ((Number) progressObj).intValue();
|
|
|
- String message = (String) data.getOrDefault("message", "处理中...");
|
|
|
- log.info("NER 进度: documentId={}, progress={}%, message={}", documentId, progress, message);
|
|
|
+
|
|
|
+ try {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ ResponseEntity<Map<String, Object>> statusResponse = restTemplate.exchange(
|
|
|
+ statusUrl, HttpMethod.GET, null,
|
|
|
+ (Class<Map<String, Object>>) (Class<?>) Map.class);
|
|
|
+
|
|
|
+ if (!statusResponse.getStatusCode().is2xxSuccessful() || statusResponse.getBody() == null) {
|
|
|
+ log.warn("查询任务状态失败: taskId={}", taskId);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ Map<String, Object> taskStatus = statusResponse.getBody();
|
|
|
+ String status = (String) taskStatus.get("status");
|
|
|
+ int progress = taskStatus.get("progress") != null ?
|
|
|
+ ((Number) taskStatus.get("progress")).intValue() : 0;
|
|
|
+ String message = (String) taskStatus.get("message");
|
|
|
+
|
|
|
+ // 只在进度变化时打印日志
|
|
|
+ if (progress != lastProgress) {
|
|
|
+ log.info("NER 进度: documentId={}, taskId={}, status={}, progress={}%, message={}",
|
|
|
+ documentId, taskId, status, progress, message);
|
|
|
+ lastProgress = progress;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 检查任务是否完成
|
|
|
+ if ("completed".equals(status)) {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ Map<String, Object> result = (Map<String, Object>) taskStatus.get("result");
|
|
|
+ log.info("异步 NER 任务完成: documentId={}, taskId={}", documentId, taskId);
|
|
|
+
|
|
|
+ // 删除任务(释放服务端内存)
|
|
|
+ try {
|
|
|
+ restTemplate.delete(statusUrl);
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.debug("删除任务失败(可忽略): taskId={}", taskId);
|
|
|
+ }
|
|
|
+
|
|
|
+ return result;
|
|
|
+ } else if ("failed".equals(status)) {
|
|
|
+ String error = (String) taskStatus.get("error");
|
|
|
+ log.error("异步 NER 任务失败: documentId={}, taskId={}, error={}", documentId, taskId, error);
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.warn("轮询任务状态异常: taskId={}, error={}", taskId, e.getMessage());
|
|
|
}
|
|
|
- } else if ("entities_data".equals(eventType)) {
|
|
|
- log.debug("收到 entities_data 事件,实体数: {}", data.get("total_entities"));
|
|
|
- } else if ("complete".equals(eventType)) {
|
|
|
- // 完成事件,包含最终结果
|
|
|
- resultRef.set(data);
|
|
|
- int entityCount = data.get("entities") != null ? ((List<?>) data.get("entities")).size() : 0;
|
|
|
- log.info("NER 流式处理完成: documentId={}, entities={}, success={}",
|
|
|
- documentId, entityCount, data.get("success"));
|
|
|
- } else if ("error".equals(eventType)) {
|
|
|
- log.error("NER 服务返回错误: documentId={}, error={}", documentId, data.get("error"));
|
|
|
- } else {
|
|
|
- log.debug("未知 SSE 事件类型: {}", eventType);
|
|
|
}
|
|
|
|
|
|
+ log.error("异步 NER 任务超时: documentId={}, taskId={}, maxWaitTime={}ms",
|
|
|
+ documentId, taskId, maxWaitTime);
|
|
|
+ return null;
|
|
|
+
|
|
|
} catch (Exception e) {
|
|
|
- log.warn("解析 SSE 事件时出错: eventBlock={}, error={}",
|
|
|
- eventBlock.length() > 200 ? eventBlock.substring(0, 200) + "..." : eventBlock,
|
|
|
- e.getMessage());
|
|
|
+ log.error("调用异步 NER 服务失败: documentId={}, error={}", documentId, e.getMessage(), e);
|
|
|
+ // 回退到同步 API
|
|
|
+ log.info("回退到同步 NER API: documentId={}", documentId);
|
|
|
+ return callPythonNerService(documentId, text, userId);
|
|
|
}
|
|
|
}
|
|
|
}
|