api-server/rag-worker/pipelines/import_pipeline.py
WangDL c9882c8d04
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 15s
add rerank module + bug fixes from e2e test
- New reranker.py: SiliconFlow bge-reranker-v2-m3 integration
- config.py: add RERANK_MODEL
- api_client.py: fix get_next_job/claim_job/get_job_detail unwrapping
- candidate_generator.py: fix .format() conflict with JSON braces
- import_pipeline.py: fix file existence check + UUID point IDs
- Add .gitignore for __pycache__

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 16:05:09 +08:00

134 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""导入主流程:下载 → 解析 → 清洗 → 切片 → embedding → Qdrant → AI 候选"""
import os
import shutil
import uuid

from parser import download_file, parse_document
from chunker import chunk_document
from embedder import embed_batch
from indexer import upsert_points
from candidate_generator import generate_candidates
from api_client import (
    heartbeat as send_heartbeat,
    update_job_status,
    save_chunks,
    save_candidates,
    get_job_detail,
)
def _safe_local_filename(name):
    """Reduce a stored filename to a bare basename that stays inside the temp dir.

    The name comes from the source record (presumably the name the file was
    uploaded with), so it may contain directory parts such as ``../`` or be an
    absolute path. ``os.path.join`` would let either escape the per-job temp
    directory. Returns "unknown" when no usable basename remains.
    """
    # Normalise Windows-style separators so basename strips them as well.
    base = os.path.basename(str(name).replace("\\", "/"))
    if base in ("", ".", ".."):
        return "unknown"
    return base


def _build_index_records(chunks, vectors, user_id, kb_id, source_id):
    """Pair each chunk with its vector; build Qdrant points and chunk records.

    Each chunk gets one freshly generated UUID. It is used as the Qdrant point
    id and the payload's ``chunkId``, and is stored as ``externalVectorId`` on
    the chunk record, which links the two stores.

    Returns:
        ``(points, chunk_records)``: two lists, one entry per (chunk, vector) pair.
    """
    points = []
    chunk_records = []
    for chunk, vec in zip(chunks, vectors):
        chunk_id = str(uuid.uuid4())
        page_number = chunk.get("pageNumber")
        section_title = chunk.get("sectionTitle", "")
        points.append({
            "id": chunk_id,
            "vector": vec,
            "payload": {
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "chunkId": chunk_id,
                "pageNumber": page_number,
                "sectionTitle": section_title,
                "deleted": False,
            },
        })
        chunk_records.append({
            "userId": user_id,
            "knowledgeBaseId": kb_id,
            "sourceId": source_id,
            "content": chunk["content"],
            "chunkIndex": chunk["chunkIndex"],
            "pageNumber": page_number,
            "sectionTitle": section_title,
            # NOTE(review): this is a character count, not a model token count.
            # Confirm the API accepts that.
            "tokenCount": len(chunk["content"]),
            "externalVectorId": chunk_id,
            # NOTE(review): hard-coded; it must match the model embed_batch
            # actually uses. Confirm.
            "embeddingModel": "bge-m3",
            "embeddingStatus": "COMPLETED",
            "metadataJson": {"chunkType": chunk.get("chunkType", "text")},
        })
    return points, chunk_records


async def run_import(job: dict):
    """Run the full document import pipeline for one job.

    Stages, each reported through ``update_job_status`` with a progress value:
    download -> parse -> clean -> chunk -> embed -> index into Qdrant and save
    chunks -> generate AI candidate knowledge points -> complete.

    Args:
        job: job dict. Both camelCase (``sourceId``) and snake_case
            (``source_id``) keys are accepted for the ids read here.

    Raises:
        ValueError: the job has no source id, parsing yields almost no text,
            or the embedder returns a different number of vectors than chunks.
        Exception: any error raised inside the pipeline is re-raised after the
            job has been marked ``FAILED_RETRYABLE``. The temp directory is
            always removed.
    """
    job_id = job["id"]
    source_id = job.get("sourceId") or job.get("source_id")
    # Use .get for the camelCase keys too. Indexing with job["userId"] raised
    # KeyError for snake_case-only jobs, so the fallback was unreachable.
    user_id = job.get("userId") or job.get("user_id")
    kb_id = job.get("knowledgeBaseId") or job.get("knowledge_base_id")
    if not source_id:
        raise ValueError(f"任务 {job_id} 缺少 sourceId")

    # Fetch the source details (from NestJS). `or {}` also covers keys that
    # are present but null.
    detail = await get_job_detail(job_id) or {}
    source = detail.get("source") or {}
    mime_type = source.get("mimeType") or source.get("mime_type") or "text/plain"
    original_filename = (
        source.get("originalFilename") or source.get("original_filename") or "unknown"
    )

    tmp_dir = f"/data/tmp/imports/{job_id}"
    file_path = os.path.join(tmp_dir, _safe_local_filename(original_filename))
    try:
        # 1. Download (only when a URL is available).
        await update_job_status(job_id, "DOWNLOADING", {"progress": 5})
        file_url = source.get("downloadUrl") or detail.get("downloadUrl")
        if file_url:
            await download_file(file_url, file_path)

        # 2. Parse
        await update_job_status(job_id, "PARSING", {"progress": 20})
        text = ""
        if os.path.exists(file_path):
            text = await parse_document(file_path, mime_type)
        # Nothing parsed from a local file (e.g. a plain-text import): fall back
        # to raw text carried by the job, the source, or the job detail.
        if not text:
            detail_job = detail.get("job") or {}
            text = (
                job.get("rawText")
                or source.get("rawText")
                or detail_job.get("rawText")
                or ""
            )
        if not text or len(text.strip()) < 10:
            raise ValueError("文档解析后内容过少,可能为空白或损坏文件")

        # 3. Clean. NOTE(review): no cleaning transform is applied to `text` in
        # this function; only the status is reported.
        await update_job_status(job_id, "CLEANING", {"progress": 40, "textLength": len(text)})

        # 4. Chunk
        await update_job_status(job_id, "CHUNKING", {"progress": 50})
        source_type = source.get("type") or "text"
        chunks = chunk_document(text, source_type)

        # 5. Embed
        await update_job_status(job_id, "EMBEDDING", {"progress": 60})
        vectors = list(await embed_batch([c["content"] for c in chunks]))
        # zip() would silently drop chunks if fewer vectors came back, leaving
        # them neither indexed nor saved. Fail loudly instead.
        if len(vectors) != len(chunks):
            raise ValueError(
                f"embedding 数量({len(vectors)})与切片数量({len(chunks)})不一致"
            )

        # 6. Index into Qdrant, then persist the chunk records.
        await update_job_status(job_id, "INDEXING", {"progress": 80})
        points, chunk_records = _build_index_records(
            chunks, vectors, user_id, kb_id, source_id
        )
        await upsert_points(points)
        await save_chunks(chunk_records)

        # 7. Generate candidate knowledge points. Non-fatal: a failure here is
        # logged and the import still completes.
        await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
        try:
            candidates = await generate_candidates(text)
            if candidates:
                await save_candidates(user_id, kb_id, source_id, job_id, candidates)
        except Exception as e:
            print(f"[worker] 候选知识点生成失败(非致命): {e}")

        # 8. Done
        await update_job_status(job_id, "COMPLETED", {"progress": 100})
    except Exception as e:
        try:
            await update_job_status(job_id, "FAILED_RETRYABLE", {
                "errorCode": "WORKER_ERROR",
                "errorMessage": str(e)[:500],
            })
        except Exception as report_err:
            # Don't let a failed status report replace the real pipeline error.
            print(f"[worker] 任务 {job_id} 失败状态上报失败: {report_err}")
        raise
    finally:
        # Remove downloaded temp files. ignore_errors also tolerates a missing dir.
        shutil.rmtree(tmp_dir, ignore_errors=True)