|
@@ -116,33 +116,33 @@ async def extract_entities_stream(request: NerRequest):
|
|
|
all_entities = []
|
|
all_entities = []
|
|
|
all_relations = []
|
|
all_relations = []
|
|
|
|
|
|
|
|
- # 创建进度回调
|
|
|
|
|
- async def progress_callback(chunk_index: int, total_chunks: int, chunk_entities: list):
|
|
|
|
|
- nonlocal all_entities
|
|
|
|
|
- all_entities.extend(chunk_entities)
|
|
|
|
|
-
|
|
|
|
|
- progress_data = {
|
|
|
|
|
- "document_id": request.document_id,
|
|
|
|
|
- "chunk_index": chunk_index,
|
|
|
|
|
- "total_chunks": total_chunks,
|
|
|
|
|
- "chunk_entities": len(chunk_entities),
|
|
|
|
|
- "total_entities": len(all_entities),
|
|
|
|
|
- "progress_percent": int((chunk_index / total_chunks) * 100)
|
|
|
|
|
- }
|
|
|
|
|
- return await sse_event("progress", progress_data)
|
|
|
|
|
-
|
|
|
|
|
# 调用带进度的 NER 服务
|
|
# 调用带进度的 NER 服务
|
|
|
from ..services.ner_service import ner_service
|
|
from ..services.ner_service import ner_service
|
|
|
|
|
|
|
|
# 检查是否支持流式提取
|
|
# 检查是否支持流式提取
|
|
|
if hasattr(ner_service, 'extract_entities_with_progress'):
|
|
if hasattr(ner_service, 'extract_entities_with_progress'):
|
|
|
- async for event in ner_service.extract_entities_with_progress(
|
|
|
|
|
|
|
+ async for event_str in ner_service.extract_entities_with_progress(
|
|
|
text=request.text,
|
|
text=request.text,
|
|
|
- entity_types=request.entity_types,
|
|
|
|
|
- progress_callback=progress_callback
|
|
|
|
|
|
|
+ entity_types=request.entity_types
|
|
|
):
|
|
):
|
|
|
- yield event
|
|
|
|
|
- all_entities = event.get('entities', all_entities) if isinstance(event, dict) else all_entities
|
|
|
|
|
|
|
+ # 转发进度事件
|
|
|
|
|
+ yield event_str
|
|
|
|
|
+
|
|
|
|
|
+ # 解析事件获取实体数据
|
|
|
|
|
+ if "entities_data" in event_str:
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 从 SSE 格式中提取 JSON 数据
|
|
|
|
|
+ lines = event_str.strip().split('\n')
|
|
|
|
|
+ for line in lines:
|
|
|
|
|
+ if line.startswith('data:'):
|
|
|
|
|
+ data = json.loads(line[5:].strip())
|
|
|
|
|
+ if 'entities' in data:
|
|
|
|
|
+ # 将字典转换回 EntityInfo 对象
|
|
|
|
|
+ all_entities = [
|
|
|
|
|
+ EntityInfo(**e) for e in data['entities']
|
|
|
|
|
+ ]
|
|
|
|
|
+ except Exception as parse_err:
|
|
|
|
|
+ logger.warning(f"解析实体数据事件失败: {parse_err}")
|
|
|
else:
|
|
else:
|
|
|
# 回退到普通提取
|
|
# 回退到普通提取
|
|
|
all_entities = await ner_service.extract_entities(
|
|
all_entities = await ner_service.extract_entities(
|