feat: Python RAG Worker + NestJS 内部 API(文档解析/切片/embedding/Qdrant/候选生成)
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 22s

This commit is contained in:
WangDL 2026-05-19 22:35:12 +08:00
parent c149b96b04
commit fbdae9078f
13 changed files with 962 additions and 0 deletions

94
rag-worker/api_client.py Normal file
View File

@ -0,0 +1,94 @@
import httpx
from config import API_BASE_URL, RAG_WORKER_SECRET, WORKER_ID
# Headers attached to every internal API request: a bearer token built from
# RAG_WORKER_SECRET plus this worker's id.
# NOTE(review): config defaults RAG_WORKER_SECRET to "", which would send an
# empty bearer token — confirm the deployment always sets it.
_auth_headers = {
    "Authorization": f"Bearer {RAG_WORKER_SECRET}",
    "X-Worker-Id": WORKER_ID,
}
async def get_next_job() -> dict | None:
    """Fetch the next QUEUED import job from the internal API.

    Returns the value stored under the response's ``data`` key (falling back
    to ``job``) when the API answers 200, and ``None`` for any other status.
    """
    endpoint = f"{API_BASE_URL}/internal/rag/jobs/next"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.get(endpoint, headers=_auth_headers)
    if response.status_code != 200:
        return None
    payload = response.json()
    return payload.get("data") or payload.get("job")
async def claim_job(job_id: str) -> bool:
    """Try to claim an import job for this worker.

    Args:
        job_id: Identifier of the job to claim.

    Returns:
        True when the API accepted the request (any 2xx status) and did not
        report ``success: false``; False otherwise.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim",
            headers=_auth_headers,
            # The claim endpoint reads ``workerId`` from the request body.
            json={"workerId": WORKER_ID},
        )
    # A NestJS @Post handler without an explicit status code answers 201,
    # so comparing against 200 alone would reject every successful claim.
    if not 200 <= resp.status_code < 300:
        return False
    try:
        payload = resp.json()
    except ValueError:
        # No JSON body to inspect: fall back to the status code alone.
        return True
    if isinstance(payload, dict):
        # The flag may sit at the top level or inside a ``data`` envelope.
        inner = payload.get("data")
        body = inner if isinstance(inner, dict) else payload
        if "success" in body:
            return bool(body["success"])
    return True
async def heartbeat(job_id: str) -> bool:
    """Send a heartbeat for a job that this worker is processing.

    Args:
        job_id: Identifier of the job being kept alive.

    Returns:
        True when the API answered with any 2xx status.
    """
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}/heartbeat",
            headers=_auth_headers,
        )
    # POST routes in NestJS answer 201 unless configured otherwise, so any
    # 2xx status counts as success rather than 200 only.
    return 200 <= resp.status_code < 300
async def update_job_status(job_id: str, status: str, data: dict | None = None):
    """Report a new status for an import job.

    ``data`` carries optional extra fields that are merged into the request
    body next to ``status``. The HTTP response is not inspected.
    """
    body = {"status": status}
    if data:
        body.update(data)
    endpoint = f"{API_BASE_URL}/internal/rag/jobs/{job_id}/status"
    async with httpx.AsyncClient(timeout=30) as http:
        await http.post(endpoint, headers=_auth_headers, json=body)
async def save_chunks(chunks: list[dict]):
    """Persist a batch of KnowledgeChunk records through the internal API.

    Args:
        chunks: Chunk records to store.

    Raises:
        httpx.HTTPStatusError: If the API answers with a 4xx/5xx status.
    """
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/chunks",
            headers=_auth_headers,
            json={"chunks": chunks},
        )
    # Surface failures: a silently dropped batch would otherwise let the job
    # be reported as completed while its chunks were never stored.
    resp.raise_for_status()
async def save_candidates(
    user_id: str,
    kb_id: str,
    source_id: str,
    import_id: str,
    candidates: list[dict],
):
    """Persist generated candidate knowledge points through the internal API.

    Args:
        user_id: Owner of the import.
        kb_id: Target knowledge base id.
        source_id: Knowledge source the candidates were derived from.
        import_id: Import job id.
        candidates: Candidate records to store.

    Raises:
        httpx.HTTPStatusError: If the API answers with a 4xx/5xx status.
    """
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/candidates",
            headers=_auth_headers,
            json={
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "importId": import_id,
                "candidates": candidates,
            },
        )
    # Fail loudly instead of letting the caller believe the save succeeded.
    resp.raise_for_status()
async def get_job_detail(job_id: str) -> dict | None:
    """Fetch the detail record of a job (including its source information).

    Returns the decoded JSON body on a 200 response, otherwise ``None``.
    """
    endpoint = f"{API_BASE_URL}/internal/rag/jobs/{job_id}"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.get(endpoint, headers=_auth_headers)
    return response.json() if response.status_code == 200 else None

View File

@ -0,0 +1,103 @@
"""候选知识点生成:调用 DeepSeek 分析文本,生成 ImportCandidate"""
import json
import httpx
from config import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL
# Upper bound on candidates: clamps the expected count and truncates the
# parsed model output.
MAX_CANDIDATES = 30
# Lower bound used when estimating the expected candidate count.
MIN_CANDIDATES = 3
# The expected count is estimated as one candidate per this many characters.
CHARS_PER_CANDIDATE = 2000
# Prompt template sent as the user message; `{text}` marks where the document
# text goes.
# NOTE: the embedded JSON example contains literal `{` / `}` characters, so
# this template cannot be filled with str.format() as written.
_PROMPT = """你是一个学习助手。请分析以下文档内容,提取关键知识点。
对于每个知识点请提供
- title: 知识点标题简洁不超过 30
- summary: 一句话概述不超过 80
- content: 详细解释基于原文保持准确
- tags: 2-4 个标签
- recallQuestions: 1-2 个主动回忆问题
- difficulty: 难度评估easy/medium/hard
- confidence: 你对这个知识点重要性的置信度0.0-1.0
请以 JSON 数组格式返回每个元素是一个知识点
```json
[{
"title": "知识点标题",
"summary": "一句话概述",
"content": "详细解释...",
"tags": ["标签1", "标签2"],
"recallQuestions": ["问题1", "问题2"],
"difficulty": "medium",
"confidence": 0.85
}]
```
文档内容
{text}
"""
async def generate_candidates(text: str) -> list[dict]:
    """Ask DeepSeek to extract candidate knowledge points from ``text``.

    Args:
        text: Parsed document text. Only the first 16000 characters are sent.

    Returns:
        The candidate dicts parsed from the model reply.

    Raises:
        RuntimeError: If the DeepSeek API answers with a non-200 status.
        ValueError: If no JSON array can be extracted from the reply.
    """
    # Rough estimate of how many candidates a document of this size yields.
    text_len = len(text)
    expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))
    # The template holds a literal JSON example, so str.format() would treat
    # its braces as replacement fields and raise. Substitute the single
    # placeholder directly; this also keeps braces in the document text inert.
    prompt = _PROMPT.replace("{text}", text[:16000])  # cap the context size
    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.post(
            f"{DEEPSEEK_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
            json={
                "model": DEEPSEEK_MODEL,
                "messages": [
                    {"role": "system", "content": "你是一个专业的学习内容分析师。请始终返回有效的 JSON 数组。"},
                    {"role": "user", "content": prompt},
                ],
                "temperature": 0.3,
                "max_tokens": 4096,
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"DeepSeek API error: {resp.status_code} {resp.text}")
    data = resp.json()
    raw = data["choices"][0]["message"]["content"]
    # The reply may wrap the JSON in prose or a code fence; extract it.
    return _parse_json_response(raw, expected_count)
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
"""从 AI 回复中提取 JSON 数组"""
# 尝试直接解析
try:
candidates = json.loads(raw)
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
# 提取 ```json ... ``` 块
import re
m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL)
if m:
try:
candidates = json.loads(m.group(1))
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
# 提取 [ ... ] 块
m = re.search(r"\[.*\]", raw, re.DOTALL)
if m:
try:
candidates = json.loads(m.group(0))
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")

120
rag-worker/chunker.py Normal file
View File

@ -0,0 +1,120 @@
"""文本切片:递归字符分割 + 中文分句保护"""
import re
from config import CHUNK_SIZE, CHUNK_OVERLAP
# Sentence-boundary pattern. The single capture group makes re.split() keep
# each delimiter as its own list item; a "." between digits is not a boundary.
_CN_SENT_PATTERN = re.compile(
    r"([。!?;\n]|(?<!\d)\.(?!\d)|!\?|\?!)"
)
# Markdown heading: one to six "#" characters followed by whitespace at the
# start of a line.
_MD_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
def _split_sentences(text: str) -> list[str]:
    """Split text into sentences, keeping each delimiter at the sentence end.

    A trailing remainder without a closing delimiter is kept when it contains
    anything other than whitespace.
    """
    result = []
    pending = ""
    # filter(None, ...) drops the empty strings re.split() produces.
    for piece in filter(None, _CN_SENT_PATTERN.split(text)):
        pending += piece
        if _CN_SENT_PATTERN.match(piece):
            # The piece is a delimiter: the sentence is complete.
            result.append(pending)
            pending = ""
    if pending.strip():
        result.append(pending)
    return result
def _split_by_heading(md_text: str) -> list[dict]:
    """Split Markdown into sections at heading lines.

    Each section is ``{"sectionTitle": <heading line>, "text": <body>}``.
    Sections with a blank body are dropped. When nothing is produced, the
    whole input is returned as one untitled section.
    """
    sections = []
    title = ""
    pending = []
    for line in md_text.split("\n"):
        if _MD_HEADING.match(line) is None:
            pending.append(line)
            continue
        # A heading closes the section collected so far.
        body = "\n".join(pending).strip()
        if body:
            sections.append({"sectionTitle": title, "text": body})
        title = line.strip()
        pending = []
    tail = "\n".join(pending).strip()
    if tail:
        sections.append({"sectionTitle": title, "text": tail})
    if not sections:
        return [{"sectionTitle": "", "text": md_text}]
    return sections
def _estimate_tokens(text: str) -> int:
"""粗略估算 token 数量(中文按字符数,英文按词数)"""
cn_chars = len(re.findall(r"[一-鿿]", text))
en_words = len(re.findall(r"[a-zA-Z]+", text))
# 中文约 1.5 字符/token英文约 1 词/token
return int(cn_chars / 1.5) + en_words
def _chunk_text(text: str, section_title: str = "", page_number: int | None = None) -> list[dict]:
    """Pack sentences into chunks of about CHUNK_SIZE estimated tokens.

    When adding a sentence would exceed CHUNK_SIZE, the current window is
    emitted and the next window starts with the last ``CHUNK_OVERLAP * 2``
    characters of the previous one (a rough character-based overlap).
    Each chunk is ``{"content", "sectionTitle", "pageNumber"}``.
    """

    def _emit(body):
        return {"content": body.strip(), "sectionTitle": section_title, "pageNumber": page_number}

    pieces = []
    window = ""
    window_tokens = 0
    for sentence in _split_sentences(text):
        cost = _estimate_tokens(sentence)
        if window_tokens <= 0 or window_tokens + cost <= CHUNK_SIZE:
            # Still room, or nothing counted yet: keep accumulating.
            window += sentence
            window_tokens += cost
            continue
        pieces.append(_emit(window))
        if CHUNK_OVERLAP > 0:
            tail = window[-int(CHUNK_OVERLAP * 2):]  # rough estimate
            window = tail + sentence
            window_tokens = _estimate_tokens(tail) + cost
        else:
            window, window_tokens = sentence, cost
    if window.strip():
        pieces.append(_emit(window))
    return pieces
def chunk_document(text: str, source_type: str = "text") -> list[dict]:
    """Split a document into chunks.

    Markdown sources ("md" / "markdown") are first split by heading; any
    other source type is treated as a single untitled section. Every chunk
    carries content, sectionTitle, pageNumber, chunkIndex and chunkType
    ("table", "code" or "text").
    """
    if source_type in ("md", "markdown"):
        sections = _split_by_heading(text)
    else:
        sections = [{"sectionTitle": "", "text": text}]
    all_chunks = [
        chunk
        for sec in sections
        for chunk in _chunk_text(sec["text"], section_title=sec.get("sectionTitle", ""))
    ]
    for index, chunk in enumerate(all_chunks):
        chunk["chunkIndex"] = index
        body = chunk["content"]
        # Heuristic detection of tables and code blocks.
        if body.count("|") > 5 and "---" in body:
            kind = "table"
        elif "```" in body:
            kind = "code"
        else:
            kind = "text"
        chunk["chunkType"] = kind
    return all_chunks

29
rag-worker/config.py Normal file
View File

@ -0,0 +1,29 @@
import os
# Every setting is read from the environment once at import time and falls
# back to the default shown here.

# NestJS internal API
API_BASE_URL = os.environ.get("API_BASE_URL", "http://127.0.0.1:3000")
RAG_WORKER_SECRET = os.environ.get("RAG_WORKER_SECRET", "")

# SiliconFlow (embeddings and OCR)
SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
SILICONFLOW_BASE_URL = os.environ.get("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-m3")
EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "1024"))

# DeepSeek (candidate generation)
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
DEEPSEEK_MODEL = os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")

# Qdrant
QDRANT_URL = os.environ.get("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "zhixi_chunks")

# Chunking
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "512"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "64"))

# Worker: the default id embeds the current process id; intervals are passed
# to asyncio.sleep by the worker loops.
WORKER_ID = os.environ.get("WORKER_ID", f"worker-{os.getpid()}")
POLL_INTERVAL = int(os.environ.get("POLL_INTERVAL", "5"))
HEARTBEAT_INTERVAL = int(os.environ.get("HEARTBEAT_INTERVAL", "30"))

63
rag-worker/embedder.py Normal file
View File

@ -0,0 +1,63 @@
"""Embedding 服务:调用硅基流动 bge-m3"""
import asyncio
import httpx
from config import (
SILICONFLOW_API_KEY,
SILICONFLOW_BASE_URL,
EMBEDDING_MODEL,
EMBEDDING_DIM,
)
# Number of texts sent to the embeddings endpoint per request.
BATCH_SIZE = 50
# Extra attempts after the first failed request for a batch.
MAX_RETRIES = 2
async def embed_single(text: str) -> list[float]:
    """Embed one text and return its vector.

    Raises RuntimeError when the embeddings endpoint answers with a status
    other than 200.
    """
    auth = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}
    request_body = {"model": EMBEDDING_MODEL, "input": [text]}
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.post(
            f"{SILICONFLOW_BASE_URL}/embeddings",
            headers=auth,
            json=request_body,
        )
    if response.status_code == 200:
        return response.json()["data"][0]["embedding"]
    raise RuntimeError(f"Embedding API error: {response.status_code} {response.text}")
async def embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed many texts, in batches of BATCH_SIZE, retrying failed requests.

    Args:
        texts: Texts to embed.

    Returns:
        One vector per input text, in input order.

    Raises:
        RuntimeError: If a batch still fails after MAX_RETRIES retries, or the
            API returns a different number of vectors than texts sent.
    """
    all_embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=60) as client:
        for start in range(0, len(texts), BATCH_SIZE):
            batch = texts[start:start + BATCH_SIZE]
            last_error: Exception | None = None
            for attempt in range(MAX_RETRIES + 1):
                try:
                    resp = await client.post(
                        f"{SILICONFLOW_BASE_URL}/embeddings",
                        headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
                        json={
                            "model": EMBEDDING_MODEL,
                            "input": batch,
                        },
                    )
                except httpx.HTTPError as e:
                    # Transport-level failure (timeout, connection reset, ...).
                    last_error = e
                else:
                    if resp.status_code == 200:
                        items = resp.json()["data"]
                        # A short or long reply would misalign vectors and
                        # chunks downstream, so reject it outright.
                        if len(items) != len(batch):
                            raise RuntimeError(
                                f"Embedding API returned {len(items)} vectors for {len(batch)} inputs"
                            )
                        # Respect the per-item index when the API supplies
                        # one; the sort is stable, so replies without it keep
                        # their original order.
                        items = sorted(items, key=lambda d: d.get("index", 0))
                        all_embeddings.extend(d["embedding"] for d in items)
                        break
                    last_error = RuntimeError(f"Status {resp.status_code}")
                if attempt < MAX_RETRIES:
                    # Exponential backoff between attempts: 1s, 2s, ...
                    await asyncio.sleep(2 ** attempt)
            else:
                # Every attempt failed without reaching `break`.
                raise RuntimeError(
                    f"Embedding batch failed after {MAX_RETRIES} retries: {last_error}"
                ) from last_error
    return all_embeddings

60
rag-worker/indexer.py Normal file
View File

@ -0,0 +1,60 @@
"""Qdrant 索引服务"""
import httpx
from config import QDRANT_URL, QDRANT_COLLECTION
async def upsert_points(points: list[dict]):
    """Write a batch of points to the Qdrant collection.

    Raises RuntimeError when Qdrant answers with a status other than 200.
    """
    endpoint = f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points"
    async with httpx.AsyncClient(timeout=60) as http:
        response = await http.put(
            endpoint,
            params={"wait": "true"},
            json={"points": points},
        )
    if response.status_code == 200:
        return
    raise RuntimeError(f"Qdrant upsert failed: {response.text}")
async def search(
    vector: list[float],
    user_id: str,
    knowledge_base_id: str,
    top_k: int = 5,
) -> list[dict]:
    """Run a semantic search restricted to one user's knowledge base.

    Only points whose payload matches the user id and knowledge base id and
    has ``deleted`` set to False are considered. Returns the ``result`` list
    from the Qdrant response; raises RuntimeError on a non-200 status.
    """
    conditions = [
        {"key": "userId", "match": {"value": user_id}},
        {"key": "knowledgeBaseId", "match": {"value": knowledge_base_id}},
        {"key": "deleted", "match": {"value": False}},
    ]
    query = {
        "vector": vector,
        "filter": {"must": conditions},
        "limit": top_k,
        "with_payload": True,
    }
    endpoint = f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.post(endpoint, json=query)
    if response.status_code != 200:
        raise RuntimeError(f"Qdrant search failed: {response.text}")
    return response.json()["result"]
async def mark_deleted(source_id: str):
    """Flag every point of a source as ``deleted=true`` in Qdrant.

    Args:
        source_id: Source whose points are soft-deleted.

    Raises:
        RuntimeError: If Qdrant answers with a status other than 200.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        # Qdrant's "set payload" operation is POST .../points/payload with
        # the new values under "payload" and the selection under "filter".
        resp = await client.post(
            f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/payload",
            params={"wait": "true"},
            json={
                "payload": {"deleted": True},
                "filter": {
                    "must": [
                        {"key": "sourceId", "match": {"value": source_id}},
                    ]
                },
            },
        )
    # Without this check a failed soft-delete would go unnoticed and the
    # points would keep showing up in search results.
    if resp.status_code != 200:
        raise RuntimeError(f"Qdrant mark_deleted failed: {resp.text}")

84
rag-worker/main.py Normal file
View File

@ -0,0 +1,84 @@
"""知习 RAG Worker — 文档导入主进程"""
import asyncio
import signal
import sys
from config import WORKER_ID, POLL_INTERVAL, HEARTBEAT_INTERVAL
from api_client import get_next_job, claim_job, heartbeat, update_job_status
from pipelines.import_pipeline import run_import
# Flag polled by the worker loops; cleared by the signal handler below.
running = True


def shutdown(sig, frame):
    """Signal handler: log the received signal and ask the loops to stop."""
    global running
    print(f"[{WORKER_ID}] 收到信号 {sig},正在退出...")
    running = False


# Stop on both Ctrl-C and a termination request.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, shutdown)
async def heartbeat_loop():
    """Worker-level heartbeat placeholder.

    Simplified implementation: only sleeps in HEARTBEAT_INTERVAL steps until
    the worker is asked to stop; per-job heartbeats are sent elsewhere.
    """
    while True:
        if not running:
            break
        await asyncio.sleep(HEARTBEAT_INTERVAL)
async def work_loop():
    """Main worker loop: poll for a job, claim it, run the import.

    Runs until the module-level ``running`` flag is cleared. Every failure is
    logged and followed by a POLL_INTERVAL pause, so the loop never spins
    against the API.
    """
    print(f"[{WORKER_ID}] RAG Worker 已启动")
    while running:
        try:
            job = await get_next_job()
            if not job:
                await asyncio.sleep(POLL_INTERVAL)
                continue
            job_id = job.get("id") or job.get("jobId")
            if not job_id:
                # Malformed payload: back off instead of re-polling at once.
                await asyncio.sleep(POLL_INTERVAL)
                continue
            # Claim the job.
            claimed = await claim_job(job_id)
            if not claimed:
                # The claim was rejected; wait before polling again,
                # otherwise a persistently failing claim turns into a tight
                # request loop.
                await asyncio.sleep(POLL_INTERVAL)
                continue
            print(f"[{WORKER_ID}] 开始处理任务 {job_id}")
            # The pipeline reads job["id"]; make sure it is set even when
            # the API delivered the identifier as "jobId".
            job = {**job, "id": job_id}
            # Keep the job alive with a background heartbeat task.
            hb_task = asyncio.create_task(_per_job_heartbeat(job_id))
            try:
                await run_import(job)
                print(f"[{WORKER_ID}] 任务 {job_id} 完成")
            except Exception as e:
                print(f"[{WORKER_ID}] 任务 {job_id} 失败: {e}")
                await update_job_status(job_id, "FAILED_RETRYABLE", {
                    "errorMessage": str(e)[:500],
                })
            finally:
                hb_task.cancel()
        except Exception as e:
            # Top-level boundary: log and keep the worker alive.
            print(f"[{WORKER_ID}] 轮询异常: {e}")
            await asyncio.sleep(POLL_INTERVAL)
    print(f"[{WORKER_ID}] Worker 已停止")
async def _per_job_heartbeat(job_id: str):
    """Send a heartbeat for one job every HEARTBEAT_INTERVAL seconds.

    Runs until the worker is asked to stop or the task is cancelled.
    Heartbeats are best-effort: a failure is logged and the loop continues.
    """
    while running:
        try:
            await heartbeat(job_id)
        except Exception as e:
            # Never let a heartbeat problem abort the import, but leave a
            # trace so repeated failures can be diagnosed.
            print(f"[{WORKER_ID}] 任务 {job_id} 心跳失败: {e}")
        await asyncio.sleep(HEARTBEAT_INTERVAL)
if __name__ == "__main__":
    # Script entry point: drive the polling loop until it returns.
    asyncio.run(work_loop())

137
rag-worker/parser.py Normal file
View File

@ -0,0 +1,137 @@
"""文档解析PDF / DOCX / TXT / MD / CSV / XLSX"""
import os
import io
import base64
import httpx
from config import SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL
async def download_file(url: str, local_path: str) -> str:
    """Download ``url`` (a COS pre-signed URL) to ``local_path``.

    Follows redirects, raises on an HTTP error status, creates the parent
    directory when missing, and returns ``local_path``.
    """
    async with httpx.AsyncClient(timeout=120, follow_redirects=True) as http:
        response = await http.get(url)
        response.raise_for_status()
    target_dir = os.path.dirname(local_path)
    os.makedirs(target_dir, exist_ok=True)
    with open(local_path, "wb") as out:
        out.write(response.content)
    return local_path
def parse_txt(file_path: str) -> str:
    """Read a text file as UTF-8, replacing undecodable bytes."""
    with open(file_path, "r", encoding="utf-8", errors="replace") as handle:
        content = handle.read()
    return content
def parse_markdown(file_path: str) -> str:
    """Read a Markdown file; handled exactly like a plain text file."""
    markdown_text = parse_txt(file_path)
    return markdown_text
def parse_docx(file_path: str) -> str:
    """Extract the non-blank paragraphs of a DOCX file, separated by blank lines."""
    from docx import Document

    paragraphs = Document(file_path).paragraphs
    kept = [para.text for para in paragraphs if para.text.strip()]
    return "\n\n".join(kept)
def parse_pdf_text(file_path: str) -> str:
    """Extract the text layer of a PDF with PyMuPDF.

    Pages whose text is blank are skipped; the rest are joined with blank
    lines.
    """
    import fitz

    pdf = fitz.open(file_path)
    page_texts = [page.get_text() for page in pdf]
    pdf.close()
    return "\n\n".join(t for t in page_texts if t.strip())
def pdf_needs_ocr(file_path: str) -> bool:
    """Decide whether a PDF needs OCR.

    A PDF is treated as a scan when its text layer averages fewer than 50
    characters per page.

    Args:
        file_path: Path of the PDF file.

    Returns:
        True when the PDF looks like a scanned document.
    """
    import fitz

    doc = fitz.open(file_path)
    try:
        # Read the page count while the document is still open; a closed
        # PyMuPDF document refuses length/page-count queries.
        page_count = max(len(doc), 1)
        total_len = sum(len(page.get_text().strip()) for page in doc)
    finally:
        doc.close()
    # Fewer than 50 characters per page on average -> scanned document.
    return (total_len / page_count) < 50
async def ocr_with_siliconflow(image_bytes: bytes) -> str:
    """Recognise the text in an image with a SiliconFlow multimodal model.

    Args:
        image_bytes: Raw image bytes; sent as a base64 ``data:image/png`` URL.

    Returns:
        The text content of the model reply.

    Raises:
        RuntimeError: If the API answers with a status other than 200.
    """
    b64 = base64.b64encode(image_bytes).decode()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json={
                "model": "Qwen/Qwen3-VL-32B-Instruct",
                "messages": [{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "请识别并提取这张图片中的所有文字内容。如果有表格,请用 Markdown 表格格式输出。不要添加任何解释。"},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    ],
                }],
                "max_tokens": 4096,
            },
        )
    # Report API failures explicitly; otherwise an error body surfaces as an
    # opaque KeyError on "choices".
    if resp.status_code != 200:
        raise RuntimeError(f"OCR API error: {resp.status_code} {resp.text}")
    data = resp.json()
    return data["choices"][0]["message"]["content"]
async def parse_image_with_ocr(file_path: str) -> str:
    """Read an image file and return the text recognised by OCR."""
    with open(file_path, "rb") as handle:
        payload = handle.read()
    recognised = await ocr_with_siliconflow(payload)
    return recognised
def parse_csv(file_path: str) -> str:
    """Load a CSV file and render it as a Markdown table without the index."""
    import pandas as pd

    frame = pd.read_csv(file_path)
    table = frame.to_markdown(index=False)
    return table
def parse_xlsx(file_path: str) -> str:
    """Load an Excel workbook with pandas' defaults and render it as a Markdown table."""
    import pandas as pd

    frame = pd.read_excel(file_path)
    table = frame.to_markdown(index=False)
    return table
async def parse_document(file_path: str, mime_type: str) -> str:
    """Route a file to the parser matching its extension.

    The decision is made on the file extension only; ``mime_type`` is
    accepted but not consulted. PDFs flagged as scans whose text layer has
    fewer than 100 characters are OCR-ed page by page. Raises ValueError for
    an unsupported extension.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".txt":
        return parse_txt(file_path)
    if ext in (".md", ".markdown"):
        return parse_markdown(file_path)
    if ext == ".docx":
        return parse_docx(file_path)
    if ext == ".csv":
        return parse_csv(file_path)
    if ext == ".xlsx":
        return parse_xlsx(file_path)
    if ext in (".png", ".jpg", ".jpeg", ".webp", ".heic", ".bmp"):
        return await parse_image_with_ocr(file_path)
    if ext != ".pdf":
        raise ValueError(f"不支持的文件类型: {ext}")

    # PDF handling from here on.
    if not pdf_needs_ocr(file_path):
        return parse_pdf_text(file_path)
    # Likely a scan: try the text layer first, fall back to OCR if it is
    # (nearly) empty.
    text = parse_pdf_text(file_path)
    if len(text.strip()) >= 100:
        return text
    import fitz

    pdf = fitz.open(file_path)
    page_results = []
    for page in pdf:
        pixmap = page.get_pixmap(dpi=150)
        page_results.append(await ocr_with_siliconflow(pixmap.tobytes("png")))
    pdf.close()
    return "\n\n".join(page_results)

View File

@ -0,0 +1,127 @@
"""导入主流程:下载 → 解析 → 清洗 → 切片 → embedding → Qdrant → AI 候选"""
import os
import uuid
from parser import download_file, parse_document
from chunker import chunk_document
from embedder import embed_batch
from indexer import upsert_points
from candidate_generator import generate_candidates
from api_client import (
heartbeat as send_heartbeat,
update_job_status,
save_chunks,
save_candidates,
get_job_detail,
)
async def run_import(job: dict):
    """Run the complete import pipeline for one claimed job.

    Steps: download -> parse -> chunk -> embed -> index in Qdrant ->
    persist chunks -> generate candidate knowledge points. Progress is
    reported through ``update_job_status`` before each step.

    Args:
        job: Job payload from the internal API. Must contain ``id`` and the
            source, user and knowledge-base ids (camelCase or snake_case).

    Raises:
        ValueError: If required ids are missing or the document yields too
            little text.
        Exception: Any pipeline failure is reported as FAILED_RETRYABLE and
            then re-raised.
    """
    job_id = job["id"]
    source_id = job.get("sourceId") or job.get("source_id")
    # Use .get() for the primary key too; indexing would raise KeyError
    # before the snake_case fallback could apply.
    user_id = job.get("userId") or job.get("user_id")
    kb_id = job.get("knowledgeBaseId") or job.get("knowledge_base_id")
    if not source_id:
        raise ValueError(f"任务 {job_id} 缺少 sourceId")
    if not user_id or not kb_id:
        # These ids scope the indexed points to their owner; never index
        # without them.
        raise ValueError(f"任务 {job_id} 缺少 userId 或 knowledgeBaseId")

    # Fetch the source details from the NestJS API. The API may return
    # ``source: null``, hence the ``or {}``.
    detail = await get_job_detail(job_id) or {}
    source = detail.get("source") or {}
    mime_type = source.get("mimeType") or source.get("mime_type") or "text/plain"
    raw_name = source.get("originalFilename") or source.get("original_filename") or "unknown"
    # Keep only the final path component so an uploaded filename cannot
    # escape the per-job temp directory.
    original_filename = os.path.basename(raw_name) or "unknown"
    tmp_dir = f"/data/tmp/imports/{job_id}"
    file_path = os.path.join(tmp_dir, original_filename)
    try:
        # 1. Download
        await update_job_status(job_id, "DOWNLOADING", {"progress": 5})
        file_url = source.get("downloadUrl") or detail.get("downloadUrl", "")
        downloaded = False
        if file_url:
            await download_file(file_url, file_path)
            downloaded = True

        # 2. Parse
        await update_job_status(job_id, "PARSING", {"progress": 20})
        text = ""
        if downloaded:
            text = await parse_document(file_path, mime_type)
        # Plain-text imports have no file; take the text from the job or
        # source record. Parsing a file that was never downloaded would
        # fail before this fallback could run.
        if not text:
            text = job.get("rawText") or source.get("rawText") or ""
        if not text or len(text.strip()) < 10:
            raise ValueError("文档解析后内容过少,可能为空白或损坏文件")

        # 3. Clean (currently only reports the text length)
        await update_job_status(job_id, "CLEANING", {"progress": 40, "textLength": len(text)})

        # 4. Chunk
        await update_job_status(job_id, "CHUNKING", {"progress": 50})
        source_type = source.get("type") or "text"
        chunks = chunk_document(text, source_type)

        # 5. Embed
        await update_job_status(job_id, "EMBEDDING", {"progress": 60})
        texts = [c["content"] for c in chunks]
        vectors = await embed_batch(texts)
        if len(vectors) != len(chunks):
            # zip() would silently drop the unmatched chunks.
            raise RuntimeError(
                f"Embedding count mismatch: {len(vectors)} vectors for {len(chunks)} chunks"
            )

        # 6. Index in Qdrant
        await update_job_status(job_id, "INDEXING", {"progress": 80})
        points = []
        chunk_records = []
        for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
            # Qdrant point ids must be unsigned integers or UUIDs. Derive a
            # deterministic UUID so re-importing a source overwrites its
            # points instead of duplicating them.
            chunk_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"chunk_{source_id}_{i}"))
            points.append({
                "id": chunk_id,
                "vector": vec,
                "payload": {
                    "userId": user_id,
                    "knowledgeBaseId": kb_id,
                    "sourceId": source_id,
                    "chunkId": chunk_id,
                    "pageNumber": chunk.get("pageNumber"),
                    "sectionTitle": chunk.get("sectionTitle", ""),
                    "deleted": False,
                },
            })
            chunk_records.append({
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "content": chunk["content"],
                "chunkIndex": chunk["chunkIndex"],
                "pageNumber": chunk.get("pageNumber"),
                "sectionTitle": chunk.get("sectionTitle", ""),
                # NOTE(review): this is a character count, not a token count.
                "tokenCount": len(chunk["content"]),
                "externalVectorId": chunk_id,
                "embeddingModel": "bge-m3",
                "embeddingStatus": "COMPLETED",
                "metadataJson": {"chunkType": chunk.get("chunkType", "text")},
            })
        await upsert_points(points)
        await save_chunks(chunk_records)

        # 7. Generate candidate knowledge points
        await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
        candidates = await generate_candidates(text)
        if candidates:
            await save_candidates(user_id, kb_id, source_id, job_id, candidates)

        # 8. Done
        await update_job_status(job_id, "COMPLETED", {"progress": 100})
    except Exception as e:
        try:
            await update_job_status(job_id, "FAILED_RETRYABLE", {
                "errorCode": "WORKER_ERROR",
                "errorMessage": str(e)[:500],
            })
        except Exception as report_error:
            # Reporting is best-effort; the original failure must still
            # reach the caller.
            print(f"任务 {job_id} 状态上报失败: {report_error}")
        raise
    finally:
        # Remove the per-job temp directory.
        if os.path.exists(tmp_dir):
            import shutil
            shutil.rmtree(tmp_dir, ignore_errors=True)

View File

@ -0,0 +1,8 @@
httpx>=0.27
pydantic>=2.0
pymupdf>=1.24
python-docx>=1.1
markdown>=3.5
pandas>=2.0
# pandas.DataFrame.to_markdown (used by the CSV/XLSX parsers) requires tabulate
tabulate>=0.9
openpyxl>=3.1
Pillow>=10.0

View File

@ -28,6 +28,7 @@ import { FilesModule } from './modules/files/files.module';
import { WaitlistModule } from './modules/waitlist/waitlist.module'; import { WaitlistModule } from './modules/waitlist/waitlist.module';
import { KnowledgeSourceModule } from './modules/knowledge-source/knowledge-source.module'; import { KnowledgeSourceModule } from './modules/knowledge-source/knowledge-source.module';
import { ImportCandidateModule } from './modules/import-candidate/import-candidate.module'; import { ImportCandidateModule } from './modules/import-candidate/import-candidate.module';
import { RagModule } from './modules/rag/rag.module';
import { JwtAuthGuard } from './common/guards/jwt-auth.guard'; import { JwtAuthGuard } from './common/guards/jwt-auth.guard';
import { RolesGuard } from './common/guards/roles.guard'; import { RolesGuard } from './common/guards/roles.guard';
@ -85,6 +86,7 @@ import appleConfig from './config/apple.config';
KnowledgeSourceModule, KnowledgeSourceModule,
ImportCandidateModule, ImportCandidateModule,
DocumentImportModule, DocumentImportModule,
RagModule,
LearningSessionModule, LearningSessionModule,
ActiveRecallModule, ActiveRecallModule,
AiAnalysisModule, AiAnalysisModule,

View File

@ -0,0 +1,124 @@
import { Controller, Get, Post, Body, Param } from '@nestjs/common';
import { ApiTags } from '@nestjs/swagger';
import { DocumentImportRepository } from '../document-import/document-import.repository';
import { KnowledgeSourceRepository } from '../knowledge-source/knowledge-source.repository';
import { ImportCandidateRepository } from '../import-candidate/import-candidate.repository';
import { PrismaService } from '../../infrastructure/database/prisma.service';
/**
 * Internal endpoints consumed by the RAG worker: job polling, claiming,
 * heartbeats, status updates, and persistence of chunks and candidates.
 */
@ApiTags('internal-rag')
@Controller('internal/rag')
export class InternalRagController {
  constructor(
    private readonly importRepo: DocumentImportRepository,
    private readonly sourceRepo: KnowledgeSourceRepository,
    private readonly candidateRepo: ImportCandidateRepository,
    private readonly prisma: PrismaService,
  ) {}

  /** Returns the next job for a worker, or `{ job: null }` when there is none. */
  @Get('jobs/next')
  async getNextJob() {
    // NOTE(review): meant as a lookup that does not claim ("query first, do
    // not claim"), yet it calls claimNext('') — confirm the repository method
    // has no claiming side effect when given an empty worker id.
    const job = await this.importRepo.claimNext('');
    if (!job) return { job: null };
    return {
      job: {
        id: job.id,
        userId: job.userId,
        knowledgeBaseId: job.knowledgeBaseId,
        sourceId: job.sourceId,
        fileId: job.fileId,
        sourceType: job.sourceType,
        sourceName: job.sourceName,
        rawText: job.rawText,
        status: job.status,
      },
    };
  }

  /** Returns one job together with its knowledge source, when it has one. */
  @Get('jobs/:id')
  async getJobDetail(@Param('id') id: string) {
    const job = await this.importRepo.findById(id);
    if (!job) return { job: null };
    const source = job.sourceId
      ? await this.sourceRepo.findById(job.sourceId)
      : null;
    // NOTE(review): no download URL is returned here; a consumer that needs
    // the file has only `originalObjectKey` to work with.
    return {
      job: {
        id: job.id,
        userId: job.userId,
        knowledgeBaseId: job.knowledgeBaseId,
        sourceId: job.sourceId,
        fileId: job.fileId,
        rawText: job.rawText,
        status: job.status,
      },
      source: source
        ? {
            id: source.id,
            type: source.type,
            originalFilename: source.originalFilename,
            mimeType: source.mimeType,
            sizeBytes: Number(source.sizeBytes),
            originalObjectKey: source.originalObjectKey,
          }
        : null,
    };
  }

  /** Claims a job for a worker; `success` tells whether a row was updated. */
  @Post('jobs/:id/claim')
  async claimJob(@Param('id') id: string, @Body() body: { workerId?: string }) {
    // The request may arrive without a JSON payload, in which case `body`
    // can be undefined; optional chaining avoids a TypeError.
    const workerId = body?.workerId || 'unknown';
    const result = await this.importRepo.claim(id, workerId);
    return { success: result.count > 0 };
  }

  /** Records a heartbeat for a job. */
  @Post('jobs/:id/heartbeat')
  async heartbeat(@Param('id') id: string) {
    await this.importRepo.heartbeat(id);
    return { success: true };
  }

  /** Updates the status of a job; the status doubles as the current step. */
  @Post('jobs/:id/status')
  async updateStatus(
    @Param('id') id: string,
    @Body() body: { status: string; progress?: number; errorCode?: string; errorMessage?: string },
  ) {
    await this.importRepo.updateStatus(id, body.status, {
      step: body.status,
      progress: body.progress,
      errorCode: body.errorCode,
      errorMessage: body.errorMessage,
    });
    return { success: true };
  }

  /** Bulk-inserts knowledge chunks. */
  @Post('chunks')
  async saveChunks(@Body() body: { chunks: any[] }) {
    const chunks = body?.chunks || [];
    if (chunks.length > 0) {
      // NOTE(review): the records go to Prisma unvalidated — confirm the
      // sender's field names match the KnowledgeChunk model.
      await this.prisma.knowledgeChunk.createMany({ data: chunks });
    }
    return { success: true, count: chunks.length };
  }

  /** Stores the candidate knowledge points generated for an import. */
  @Post('candidates')
  async saveCandidates(
    @Body() body: {
      userId: string;
      knowledgeBaseId: string;
      sourceId: string;
      importId: string;
      candidates: any[];
    },
  ) {
    await this.candidateRepo.createMany(
      body.userId,
      body.knowledgeBaseId,
      body.sourceId,
      body.importId,
      body.candidates || [],
    );
    return { success: true, count: body.candidates?.length || 0 };
  }
}

View File

@ -0,0 +1,11 @@
import { Module } from '@nestjs/common';
import { InternalRagController } from './internal-rag.controller';
import { DocumentImportModule } from '../document-import/document-import.module';
import { KnowledgeSourceModule } from '../knowledge-source/knowledge-source.module';
import { ImportCandidateModule } from '../import-candidate/import-candidate.module';
/** Registers the internal RAG controller and imports the feature modules it depends on. */
@Module({
  imports: [
    DocumentImportModule,
    KnowledgeSourceModule,
    ImportCandidateModule,
  ],
  controllers: [InternalRagController],
})
export class RagModule {}