package com.lingyue.parse.service; import com.lingyue.parse.entity.ParseTask; import com.lingyue.parse.repository.ParseTaskRepository; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.util.Date; /** * 任务进度更新服务 * 供各处理阶段调用,统一更新任务状态 * * @author lingyue * @since 2026-01-22 */ @Slf4j @Service @RequiredArgsConstructor public class TaskProgressService { private final ParseTaskRepository parseTaskRepository; // 阶段权重(用于计算总进度) private static final int WEIGHT_PARSE = 15; private static final int WEIGHT_RAG = 10; private static final int WEIGHT_STRUCTURED = 15; private static final int WEIGHT_NER = 50; private static final int WEIGHT_GRAPH = 10; private static final int TOTAL_WEIGHT = WEIGHT_PARSE + WEIGHT_RAG + WEIGHT_STRUCTURED + WEIGHT_NER + WEIGHT_GRAPH; /** * 更新解析阶段进度 */ @Transactional public void updateParseProgress(String documentId, String status, Integer progress) { ParseTask task = parseTaskRepository.findByDocumentId(documentId); if (task == null) { log.warn("任务不存在: documentId={}", documentId); return; } task.setParseStatus(status); task.setParseProgress(progress); task.setCurrentStep("parse"); updateOverallProgress(task); parseTaskRepository.updateById(task); log.debug("更新解析进度: documentId={}, status={}, progress={}", documentId, status, progress); } /** * 更新 RAG 阶段进度 */ @Transactional public void updateRagProgress(String documentId, String status, Integer progress) { ParseTask task = parseTaskRepository.findByDocumentId(documentId); if (task == null) { log.warn("任务不存在: documentId={}", documentId); return; } task.setRagStatus(status); task.setRagProgress(progress); task.setCurrentStep("rag"); updateOverallProgress(task); parseTaskRepository.updateById(task); log.debug("更新RAG进度: documentId={}, status={}, progress={}", documentId, status, progress); } /** * 更新结构化解析阶段进度 */ @Transactional public void updateStructuredProgress(String documentId, String status, Integer progress, Integer elementCount, Integer imageCount, Integer tableCount) { ParseTask task = parseTaskRepository.findByDocumentId(documentId); if (task == null) { log.warn("任务不存在: documentId={}", documentId); return; } task.setStructuredStatus(status); task.setStructuredProgress(progress); task.setStructuredElementCount(elementCount); task.setStructuredImageCount(imageCount); task.setStructuredTableCount(tableCount); task.setCurrentStep("structured"); updateOverallProgress(task); parseTaskRepository.updateById(task); log.debug("更新结构化解析进度: documentId={}, status={}, elements={}", documentId, status, elementCount); } /** * 更新 NER 阶段进度 */ @Transactional public void updateNerProgress(String documentId, String status, Integer progress, String nerTaskId, Integer entityCount, Integer relationCount) { ParseTask task = parseTaskRepository.findByDocumentId(documentId); if (task == null) { log.warn("任务不存在: documentId={}", documentId); return; } task.setNerStatus(status); task.setNerProgress(progress); if (nerTaskId != null) { task.setNerTaskId(nerTaskId); } if (entityCount != null) { task.setNerEntityCount(entityCount); } if (relationCount != null) { task.setNerRelationCount(relationCount); } task.setCurrentStep("ner"); updateOverallProgress(task); parseTaskRepository.updateById(task); log.debug("更新NER进度: documentId={}, status={}, progress={}, entities={}", documentId, status, progress, entityCount); } /** * 更新图构建阶段进度 */ @Transactional public void updateGraphProgress(String documentId, String status, Integer progress) { ParseTask task = parseTaskRepository.findByDocumentId(documentId); if (task == null) { log.warn("任务不存在: documentId={}", documentId); return; } task.setGraphStatus(status); task.setGraphProgress(progress); task.setCurrentStep("graph"); updateOverallProgress(task); // 如果图构建完成,标记整个任务完成 if ("completed".equals(status)) { task.setStatus("completed"); task.setProgress(100); task.setCompletedAt(new Date()); } parseTaskRepository.updateById(task); log.debug("更新图构建进度: documentId={}, status={}, progress={}", documentId, status, progress); } /** * 标记任务完成 * 当所有阶段都完成后调用(由事件监听器触发) */ @Transactional public void markCompleted(String documentId) { ParseTask task = parseTaskRepository.findByDocumentId(documentId); if (task == null) { log.warn("任务不存在: documentId={}", documentId); return; } // 检查是否所有阶段都已完成 if (isAllCompleted(task)) { task.setStatus("completed"); task.setProgress(100); task.setCurrentStep("completed"); task.setCompletedAt(new Date()); parseTaskRepository.updateById(task); log.info("任务标记为完成: documentId={}", documentId); } else if (isAnyFailed(task)) { // 如果有失败的阶段,标记为部分完成 task.setStatus("partial"); task.setCurrentStep("completed"); task.setCompletedAt(new Date()); parseTaskRepository.updateById(task); log.info("任务标记为部分完成(有阶段失败): documentId={}", documentId); } else { // 还有未完成的阶段,保持 processing 状态 log.debug("任务尚未完全完成: documentId={}", documentId); } } /** * 标记任务失败 */ @Transactional public void markFailed(String documentId, String stage, String errorMessage) { ParseTask task = parseTaskRepository.findByDocumentId(documentId); if (task == null) { log.warn("任务不存在: documentId={}", documentId); return; } task.setStatus("failed"); task.setErrorMessage(errorMessage); task.setCurrentStep(stage); task.setCompletedAt(new Date()); // 标记对应阶段失败 switch (stage) { case "parse": task.setParseStatus("failed"); break; case "rag": task.setRagStatus("failed"); break; case "structured": task.setStructuredStatus("failed"); break; case "ner": task.setNerStatus("failed"); break; case "graph": task.setGraphStatus("failed"); break; } parseTaskRepository.updateById(task); log.error("任务失败: documentId={}, stage={}, error={}", documentId, stage, errorMessage); } /** * 计算并更新总体进度 */ private void updateOverallProgress(ParseTask task) { int totalProgress = 0; // 计算各阶段贡献的进度 totalProgress += calculateStageProgress(task.getParseStatus(), task.getParseProgress(), WEIGHT_PARSE); totalProgress += calculateStageProgress(task.getRagStatus(), task.getRagProgress(), WEIGHT_RAG); totalProgress += calculateStageProgress(task.getStructuredStatus(), task.getStructuredProgress(), WEIGHT_STRUCTURED); totalProgress += calculateStageProgress(task.getNerStatus(), task.getNerProgress(), WEIGHT_NER); totalProgress += calculateStageProgress(task.getGraphStatus(), task.getGraphProgress(), WEIGHT_GRAPH); // 归一化到 0-100 int overallProgress = (totalProgress * 100) / TOTAL_WEIGHT; task.setProgress(Math.min(overallProgress, 100)); // 更新整体状态 if (isAnyFailed(task)) { task.setStatus("failed"); } else if (isAllCompleted(task)) { task.setStatus("completed"); } else if (isAnyProcessing(task)) { task.setStatus("processing"); } } private int calculateStageProgress(String status, Integer progress, int weight) { if ("completed".equals(status)) { return weight; } else if ("processing".equals(status) && progress != null) { return (progress * weight) / 100; } return 0; } private boolean isAnyFailed(ParseTask task) { return "failed".equals(task.getParseStatus()) || "failed".equals(task.getRagStatus()) || "failed".equals(task.getStructuredStatus()) || "failed".equals(task.getNerStatus()) || "failed".equals(task.getGraphStatus()); } private boolean isAllCompleted(ParseTask task) { return "completed".equals(task.getParseStatus()) && "completed".equals(task.getRagStatus()) && "completed".equals(task.getStructuredStatus()) && "completed".equals(task.getNerStatus()) && "completed".equals(task.getGraphStatus()); } private boolean isAnyProcessing(ParseTask task) { return "processing".equals(task.getParseStatus()) || "processing".equals(task.getRagStatus()) || "processing".equals(task.getStructuredStatus()) || "processing".equals(task.getNerStatus()) || "processing".equals(task.getGraphStatus()); } }