| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- package com.lingyue.parse.service;
- import com.lingyue.parse.entity.ParseTask;
- import com.lingyue.parse.repository.ParseTaskRepository;
- import lombok.RequiredArgsConstructor;
- import lombok.extern.slf4j.Slf4j;
- import org.springframework.stereotype.Service;
- import org.springframework.transaction.annotation.Transactional;
- import java.util.Date;
- /**
- * 任务进度更新服务
- * 供各处理阶段调用,统一更新任务状态
- *
- * @author lingyue
- * @since 2026-01-22
- */
- @Slf4j
- @Service
- @RequiredArgsConstructor
- public class TaskProgressService {
- private final ParseTaskRepository parseTaskRepository;
-
- // 阶段权重(用于计算总进度)
- private static final int WEIGHT_PARSE = 15;
- private static final int WEIGHT_RAG = 10;
- private static final int WEIGHT_STRUCTURED = 15;
- private static final int WEIGHT_NER = 50;
- private static final int WEIGHT_GRAPH = 10;
- private static final int TOTAL_WEIGHT = WEIGHT_PARSE + WEIGHT_RAG + WEIGHT_STRUCTURED + WEIGHT_NER + WEIGHT_GRAPH;
- /**
- * 更新解析阶段进度
- */
- @Transactional
- public void updateParseProgress(String documentId, String status, Integer progress) {
- ParseTask task = parseTaskRepository.findByDocumentId(documentId);
- if (task == null) {
- log.warn("任务不存在: documentId={}", documentId);
- return;
- }
-
- task.setParseStatus(status);
- task.setParseProgress(progress);
- task.setCurrentStep("parse");
- updateOverallProgress(task);
-
- parseTaskRepository.updateById(task);
- log.debug("更新解析进度: documentId={}, status={}, progress={}", documentId, status, progress);
- }
- /**
- * 更新 RAG 阶段进度
- */
- @Transactional
- public void updateRagProgress(String documentId, String status, Integer progress) {
- ParseTask task = parseTaskRepository.findByDocumentId(documentId);
- if (task == null) {
- log.warn("任务不存在: documentId={}", documentId);
- return;
- }
-
- task.setRagStatus(status);
- task.setRagProgress(progress);
- task.setCurrentStep("rag");
- updateOverallProgress(task);
-
- parseTaskRepository.updateById(task);
- log.debug("更新RAG进度: documentId={}, status={}, progress={}", documentId, status, progress);
- }
- /**
- * 更新结构化解析阶段进度
- */
- @Transactional
- public void updateStructuredProgress(String documentId, String status, Integer progress,
- Integer elementCount, Integer imageCount, Integer tableCount) {
- ParseTask task = parseTaskRepository.findByDocumentId(documentId);
- if (task == null) {
- log.warn("任务不存在: documentId={}", documentId);
- return;
- }
-
- task.setStructuredStatus(status);
- task.setStructuredProgress(progress);
- task.setStructuredElementCount(elementCount);
- task.setStructuredImageCount(imageCount);
- task.setStructuredTableCount(tableCount);
- task.setCurrentStep("structured");
- updateOverallProgress(task);
-
- parseTaskRepository.updateById(task);
- log.debug("更新结构化解析进度: documentId={}, status={}, elements={}", documentId, status, elementCount);
- }
- /**
- * 更新 NER 阶段进度
- */
- @Transactional
- public void updateNerProgress(String documentId, String status, Integer progress,
- String nerTaskId, Integer entityCount, Integer relationCount) {
- ParseTask task = parseTaskRepository.findByDocumentId(documentId);
- if (task == null) {
- log.warn("任务不存在: documentId={}", documentId);
- return;
- }
-
- task.setNerStatus(status);
- task.setNerProgress(progress);
- if (nerTaskId != null) {
- task.setNerTaskId(nerTaskId);
- }
- if (entityCount != null) {
- task.setNerEntityCount(entityCount);
- }
- if (relationCount != null) {
- task.setNerRelationCount(relationCount);
- }
- task.setCurrentStep("ner");
- updateOverallProgress(task);
-
- parseTaskRepository.updateById(task);
- log.debug("更新NER进度: documentId={}, status={}, progress={}, entities={}",
- documentId, status, progress, entityCount);
- }
- /**
- * 更新图构建阶段进度
- */
- @Transactional
- public void updateGraphProgress(String documentId, String status, Integer progress) {
- ParseTask task = parseTaskRepository.findByDocumentId(documentId);
- if (task == null) {
- log.warn("任务不存在: documentId={}", documentId);
- return;
- }
-
- task.setGraphStatus(status);
- task.setGraphProgress(progress);
- task.setCurrentStep("graph");
- updateOverallProgress(task);
-
- // 如果图构建完成,标记整个任务完成
- if ("completed".equals(status)) {
- task.setStatus("completed");
- task.setProgress(100);
- task.setCompletedAt(new Date());
- }
-
- parseTaskRepository.updateById(task);
- log.debug("更新图构建进度: documentId={}, status={}, progress={}", documentId, status, progress);
- }
- /**
- * 标记任务完成
- * 当所有阶段都完成后调用(由事件监听器触发)
- */
- @Transactional
- public void markCompleted(String documentId) {
- ParseTask task = parseTaskRepository.findByDocumentId(documentId);
- if (task == null) {
- log.warn("任务不存在: documentId={}", documentId);
- return;
- }
-
- // 检查是否所有阶段都已完成
- if (isAllCompleted(task)) {
- task.setStatus("completed");
- task.setProgress(100);
- task.setCurrentStep("completed");
- task.setCompletedAt(new Date());
- parseTaskRepository.updateById(task);
- log.info("任务标记为完成: documentId={}", documentId);
- } else if (isAnyFailed(task)) {
- // 如果有失败的阶段,标记为部分完成
- task.setStatus("partial");
- task.setCurrentStep("completed");
- task.setCompletedAt(new Date());
- parseTaskRepository.updateById(task);
- log.info("任务标记为部分完成(有阶段失败): documentId={}", documentId);
- } else {
- // 还有未完成的阶段,保持 processing 状态
- log.debug("任务尚未完全完成: documentId={}", documentId);
- }
- }
-
- /**
- * 标记任务失败
- */
- @Transactional
- public void markFailed(String documentId, String stage, String errorMessage) {
- ParseTask task = parseTaskRepository.findByDocumentId(documentId);
- if (task == null) {
- log.warn("任务不存在: documentId={}", documentId);
- return;
- }
-
- task.setStatus("failed");
- task.setErrorMessage(errorMessage);
- task.setCurrentStep(stage);
- task.setCompletedAt(new Date());
-
- // 标记对应阶段失败
- switch (stage) {
- case "parse":
- task.setParseStatus("failed");
- break;
- case "rag":
- task.setRagStatus("failed");
- break;
- case "structured":
- task.setStructuredStatus("failed");
- break;
- case "ner":
- task.setNerStatus("failed");
- break;
- case "graph":
- task.setGraphStatus("failed");
- break;
- }
-
- parseTaskRepository.updateById(task);
- log.error("任务失败: documentId={}, stage={}, error={}", documentId, stage, errorMessage);
- }
- /**
- * 计算并更新总体进度
- */
- private void updateOverallProgress(ParseTask task) {
- int totalProgress = 0;
-
- // 计算各阶段贡献的进度
- totalProgress += calculateStageProgress(task.getParseStatus(), task.getParseProgress(), WEIGHT_PARSE);
- totalProgress += calculateStageProgress(task.getRagStatus(), task.getRagProgress(), WEIGHT_RAG);
- totalProgress += calculateStageProgress(task.getStructuredStatus(), task.getStructuredProgress(), WEIGHT_STRUCTURED);
- totalProgress += calculateStageProgress(task.getNerStatus(), task.getNerProgress(), WEIGHT_NER);
- totalProgress += calculateStageProgress(task.getGraphStatus(), task.getGraphProgress(), WEIGHT_GRAPH);
-
- // 归一化到 0-100
- int overallProgress = (totalProgress * 100) / TOTAL_WEIGHT;
- task.setProgress(Math.min(overallProgress, 100));
-
- // 更新整体状态
- if (isAnyFailed(task)) {
- task.setStatus("failed");
- } else if (isAllCompleted(task)) {
- task.setStatus("completed");
- } else if (isAnyProcessing(task)) {
- task.setStatus("processing");
- }
- }
-
- private int calculateStageProgress(String status, Integer progress, int weight) {
- if ("completed".equals(status)) {
- return weight;
- } else if ("processing".equals(status) && progress != null) {
- return (progress * weight) / 100;
- }
- return 0;
- }
-
- private boolean isAnyFailed(ParseTask task) {
- return "failed".equals(task.getParseStatus()) ||
- "failed".equals(task.getRagStatus()) ||
- "failed".equals(task.getStructuredStatus()) ||
- "failed".equals(task.getNerStatus()) ||
- "failed".equals(task.getGraphStatus());
- }
-
- private boolean isAllCompleted(ParseTask task) {
- return "completed".equals(task.getParseStatus()) &&
- "completed".equals(task.getRagStatus()) &&
- "completed".equals(task.getStructuredStatus()) &&
- "completed".equals(task.getNerStatus()) &&
- "completed".equals(task.getGraphStatus());
- }
-
- private boolean isAnyProcessing(ParseTask task) {
- return "processing".equals(task.getParseStatus()) ||
- "processing".equals(task.getRagStatus()) ||
- "processing".equals(task.getStructuredStatus()) ||
- "processing".equals(task.getNerStatus()) ||
- "processing".equals(task.getGraphStatus());
- }
- }
|