TaskProgressService.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. package com.lingyue.parse.service;
  2. import com.lingyue.parse.entity.ParseTask;
  3. import com.lingyue.parse.repository.ParseTaskRepository;
  4. import lombok.RequiredArgsConstructor;
  5. import lombok.extern.slf4j.Slf4j;
  6. import org.springframework.stereotype.Service;
  7. import org.springframework.transaction.annotation.Transactional;
  8. import java.util.Date;
  9. /**
  10. * 任务进度更新服务
  11. * 供各处理阶段调用,统一更新任务状态
  12. *
  13. * @author lingyue
  14. * @since 2026-01-22
  15. */
  16. @Slf4j
  17. @Service
  18. @RequiredArgsConstructor
  19. public class TaskProgressService {
  20. private final ParseTaskRepository parseTaskRepository;
  21. // 阶段权重(用于计算总进度)
  22. private static final int WEIGHT_PARSE = 15;
  23. private static final int WEIGHT_RAG = 10;
  24. private static final int WEIGHT_STRUCTURED = 15;
  25. private static final int WEIGHT_NER = 50;
  26. private static final int WEIGHT_GRAPH = 10;
  27. private static final int TOTAL_WEIGHT = WEIGHT_PARSE + WEIGHT_RAG + WEIGHT_STRUCTURED + WEIGHT_NER + WEIGHT_GRAPH;
  28. /**
  29. * 更新解析阶段进度
  30. */
  31. @Transactional
  32. public void updateParseProgress(String documentId, String status, Integer progress) {
  33. ParseTask task = parseTaskRepository.findByDocumentId(documentId);
  34. if (task == null) {
  35. log.warn("任务不存在: documentId={}", documentId);
  36. return;
  37. }
  38. task.setParseStatus(status);
  39. task.setParseProgress(progress);
  40. task.setCurrentStep("parse");
  41. updateOverallProgress(task);
  42. parseTaskRepository.updateById(task);
  43. log.debug("更新解析进度: documentId={}, status={}, progress={}", documentId, status, progress);
  44. }
  45. /**
  46. * 更新 RAG 阶段进度
  47. */
  48. @Transactional
  49. public void updateRagProgress(String documentId, String status, Integer progress) {
  50. ParseTask task = parseTaskRepository.findByDocumentId(documentId);
  51. if (task == null) {
  52. log.warn("任务不存在: documentId={}", documentId);
  53. return;
  54. }
  55. task.setRagStatus(status);
  56. task.setRagProgress(progress);
  57. task.setCurrentStep("rag");
  58. updateOverallProgress(task);
  59. parseTaskRepository.updateById(task);
  60. log.debug("更新RAG进度: documentId={}, status={}, progress={}", documentId, status, progress);
  61. }
  62. /**
  63. * 更新结构化解析阶段进度
  64. */
  65. @Transactional
  66. public void updateStructuredProgress(String documentId, String status, Integer progress,
  67. Integer elementCount, Integer imageCount, Integer tableCount) {
  68. ParseTask task = parseTaskRepository.findByDocumentId(documentId);
  69. if (task == null) {
  70. log.warn("任务不存在: documentId={}", documentId);
  71. return;
  72. }
  73. task.setStructuredStatus(status);
  74. task.setStructuredProgress(progress);
  75. task.setStructuredElementCount(elementCount);
  76. task.setStructuredImageCount(imageCount);
  77. task.setStructuredTableCount(tableCount);
  78. task.setCurrentStep("structured");
  79. updateOverallProgress(task);
  80. parseTaskRepository.updateById(task);
  81. log.debug("更新结构化解析进度: documentId={}, status={}, elements={}", documentId, status, elementCount);
  82. }
  83. /**
  84. * 更新 NER 阶段进度
  85. */
  86. @Transactional
  87. public void updateNerProgress(String documentId, String status, Integer progress,
  88. String nerTaskId, Integer entityCount, Integer relationCount) {
  89. ParseTask task = parseTaskRepository.findByDocumentId(documentId);
  90. if (task == null) {
  91. log.warn("任务不存在: documentId={}", documentId);
  92. return;
  93. }
  94. task.setNerStatus(status);
  95. task.setNerProgress(progress);
  96. if (nerTaskId != null) {
  97. task.setNerTaskId(nerTaskId);
  98. }
  99. if (entityCount != null) {
  100. task.setNerEntityCount(entityCount);
  101. }
  102. if (relationCount != null) {
  103. task.setNerRelationCount(relationCount);
  104. }
  105. task.setCurrentStep("ner");
  106. updateOverallProgress(task);
  107. parseTaskRepository.updateById(task);
  108. log.debug("更新NER进度: documentId={}, status={}, progress={}, entities={}",
  109. documentId, status, progress, entityCount);
  110. }
  111. /**
  112. * 更新图构建阶段进度
  113. */
  114. @Transactional
  115. public void updateGraphProgress(String documentId, String status, Integer progress) {
  116. ParseTask task = parseTaskRepository.findByDocumentId(documentId);
  117. if (task == null) {
  118. log.warn("任务不存在: documentId={}", documentId);
  119. return;
  120. }
  121. task.setGraphStatus(status);
  122. task.setGraphProgress(progress);
  123. task.setCurrentStep("graph");
  124. updateOverallProgress(task);
  125. // 如果图构建完成,标记整个任务完成
  126. if ("completed".equals(status)) {
  127. task.setStatus("completed");
  128. task.setProgress(100);
  129. task.setCompletedAt(new Date());
  130. }
  131. parseTaskRepository.updateById(task);
  132. log.debug("更新图构建进度: documentId={}, status={}, progress={}", documentId, status, progress);
  133. }
  134. /**
  135. * 标记任务完成
  136. * 当所有阶段都完成后调用(由事件监听器触发)
  137. */
  138. @Transactional
  139. public void markCompleted(String documentId) {
  140. ParseTask task = parseTaskRepository.findByDocumentId(documentId);
  141. if (task == null) {
  142. log.warn("任务不存在: documentId={}", documentId);
  143. return;
  144. }
  145. // 检查是否所有阶段都已完成
  146. if (isAllCompleted(task)) {
  147. task.setStatus("completed");
  148. task.setProgress(100);
  149. task.setCurrentStep("completed");
  150. task.setCompletedAt(new Date());
  151. parseTaskRepository.updateById(task);
  152. log.info("任务标记为完成: documentId={}", documentId);
  153. } else if (isAnyFailed(task)) {
  154. // 如果有失败的阶段,标记为部分完成
  155. task.setStatus("partial");
  156. task.setCurrentStep("completed");
  157. task.setCompletedAt(new Date());
  158. parseTaskRepository.updateById(task);
  159. log.info("任务标记为部分完成(有阶段失败): documentId={}", documentId);
  160. } else {
  161. // 还有未完成的阶段,保持 processing 状态
  162. log.debug("任务尚未完全完成: documentId={}", documentId);
  163. }
  164. }
  165. /**
  166. * 标记任务失败
  167. */
  168. @Transactional
  169. public void markFailed(String documentId, String stage, String errorMessage) {
  170. ParseTask task = parseTaskRepository.findByDocumentId(documentId);
  171. if (task == null) {
  172. log.warn("任务不存在: documentId={}", documentId);
  173. return;
  174. }
  175. task.setStatus("failed");
  176. task.setErrorMessage(errorMessage);
  177. task.setCurrentStep(stage);
  178. task.setCompletedAt(new Date());
  179. // 标记对应阶段失败
  180. switch (stage) {
  181. case "parse":
  182. task.setParseStatus("failed");
  183. break;
  184. case "rag":
  185. task.setRagStatus("failed");
  186. break;
  187. case "structured":
  188. task.setStructuredStatus("failed");
  189. break;
  190. case "ner":
  191. task.setNerStatus("failed");
  192. break;
  193. case "graph":
  194. task.setGraphStatus("failed");
  195. break;
  196. }
  197. parseTaskRepository.updateById(task);
  198. log.error("任务失败: documentId={}, stage={}, error={}", documentId, stage, errorMessage);
  199. }
  200. /**
  201. * 计算并更新总体进度
  202. */
  203. private void updateOverallProgress(ParseTask task) {
  204. int totalProgress = 0;
  205. // 计算各阶段贡献的进度
  206. totalProgress += calculateStageProgress(task.getParseStatus(), task.getParseProgress(), WEIGHT_PARSE);
  207. totalProgress += calculateStageProgress(task.getRagStatus(), task.getRagProgress(), WEIGHT_RAG);
  208. totalProgress += calculateStageProgress(task.getStructuredStatus(), task.getStructuredProgress(), WEIGHT_STRUCTURED);
  209. totalProgress += calculateStageProgress(task.getNerStatus(), task.getNerProgress(), WEIGHT_NER);
  210. totalProgress += calculateStageProgress(task.getGraphStatus(), task.getGraphProgress(), WEIGHT_GRAPH);
  211. // 归一化到 0-100
  212. int overallProgress = (totalProgress * 100) / TOTAL_WEIGHT;
  213. task.setProgress(Math.min(overallProgress, 100));
  214. // 更新整体状态
  215. if (isAnyFailed(task)) {
  216. task.setStatus("failed");
  217. } else if (isAllCompleted(task)) {
  218. task.setStatus("completed");
  219. } else if (isAnyProcessing(task)) {
  220. task.setStatus("processing");
  221. }
  222. }
  223. private int calculateStageProgress(String status, Integer progress, int weight) {
  224. if ("completed".equals(status)) {
  225. return weight;
  226. } else if ("processing".equals(status) && progress != null) {
  227. return (progress * weight) / 100;
  228. }
  229. return 0;
  230. }
  231. private boolean isAnyFailed(ParseTask task) {
  232. return "failed".equals(task.getParseStatus()) ||
  233. "failed".equals(task.getRagStatus()) ||
  234. "failed".equals(task.getStructuredStatus()) ||
  235. "failed".equals(task.getNerStatus()) ||
  236. "failed".equals(task.getGraphStatus());
  237. }
  238. private boolean isAllCompleted(ParseTask task) {
  239. return "completed".equals(task.getParseStatus()) &&
  240. "completed".equals(task.getRagStatus()) &&
  241. "completed".equals(task.getStructuredStatus()) &&
  242. "completed".equals(task.getNerStatus()) &&
  243. "completed".equals(task.getGraphStatus());
  244. }
  245. private boolean isAnyProcessing(ParseTask task) {
  246. return "processing".equals(task.getParseStatus()) ||
  247. "processing".equals(task.getRagStatus()) ||
  248. "processing".equals(task.getStructuredStatus()) ||
  249. "processing".equals(task.getNerStatus()) ||
  250. "processing".equals(task.getGraphStatus());
  251. }
  252. }