async def _send(
    method: str,
    path: str,
    *,
    timeout: float,
    payload: dict | None = None,
) -> "httpx.Response":
    """Send one authenticated request to the internal RAG API.

    Args:
        method: HTTP verb, e.g. ``"GET"`` or ``"POST"``.
        path: Path below ``API_BASE_URL``, starting with ``/``.
        timeout: Request timeout in seconds.
        payload: Optional JSON body.

    Returns:
        The fully-read response. Status handling is left to the caller.
    """
    async with httpx.AsyncClient(timeout=timeout) as client:
        return await client.request(
            method,
            f"{API_BASE_URL}{path}",
            headers=_auth_headers,
            json=payload,
        )


async def get_next_job() -> dict | None:
    """Fetch the next QUEUED import job, or ``None`` when there is none."""
    resp = await _send("GET", "/internal/rag/jobs/next", timeout=30)
    if resp.status_code == 200:
        data = resp.json()
        return data.get("data") or data.get("job")
    return None


async def claim_job(job_id: str) -> bool:
    """Claim a job for this worker.

    Returns:
        ``True`` only when the API answered with a 2xx status and did not
        report ``success: false`` in its body.
    """
    resp = await _send(
        "POST",
        f"/internal/rag/jobs/{job_id}/claim",
        timeout=30,
        # The controller reads the worker id from the JSON body.
        payload={"workerId": WORKER_ID},
    )
    # NestJS answers POST handlers with 201 unless @HttpCode says otherwise,
    # so comparing against 200 would reject every successful claim.
    if not resp.is_success:
        return False
    try:
        body = resp.json()
    except ValueError:
        return True
    if isinstance(body, dict):
        return bool(body.get("success", True))
    return True


async def heartbeat(job_id: str) -> bool:
    """Send a liveness heartbeat for a job; ``True`` on any 2xx answer."""
    resp = await _send("POST", f"/internal/rag/jobs/{job_id}/heartbeat", timeout=10)
    return resp.is_success


async def update_job_status(job_id: str, status: str, data: dict | None = None):
    """Report a new status (plus optional extra fields) for an import job.

    Raises:
        httpx.HTTPStatusError: If the API rejects the update.
    """
    resp = await _send(
        "POST",
        f"/internal/rag/jobs/{job_id}/status",
        timeout=30,
        payload={"status": status, **(data or {})},
    )
    resp.raise_for_status()


async def save_chunks(chunks: list[dict]):
    """Persist KnowledgeChunk records in one batch.

    Raises:
        httpx.HTTPStatusError: If the API rejects the batch; otherwise a job
            could be marked completed without any stored chunks.
    """
    resp = await _send(
        "POST",
        "/internal/rag/chunks",
        timeout=60,
        payload={"chunks": chunks},
    )
    resp.raise_for_status()


async def save_candidates(
    user_id: str,
    kb_id: str,
    source_id: str,
    import_id: str,
    candidates: list[dict],
):
    """Persist the generated candidate knowledge points for an import.

    Raises:
        httpx.HTTPStatusError: If the API rejects the request.
    """
    resp = await _send(
        "POST",
        "/internal/rag/candidates",
        timeout=60,
        payload={
            "userId": user_id,
            "knowledgeBaseId": kb_id,
            "sourceId": source_id,
            "importId": import_id,
            "candidates": candidates,
        },
    )
    resp.raise_for_status()


async def get_job_detail(job_id: str) -> dict | None:
    """Fetch job details (including source info), or ``None`` on non-200."""
    resp = await _send("GET", f"/internal/rag/jobs/{job_id}", timeout=30)
    if resp.status_code == 200:
        return resp.json()
    return None
MAX_CANDIDATES = 30
MIN_CANDIDATES = 3
CHARS_PER_CANDIDATE = 2000


async def generate_candidates(text: str) -> list[dict]:
    """Ask DeepSeek to extract candidate knowledge points from ``text``.

    Args:
        text: Full document text; only the first 16000 characters are sent.

    Returns:
        At most ``MAX_CANDIDATES`` candidate dicts parsed from the reply.

    Raises:
        RuntimeError: If the API answers with a non-200 status.
        ValueError: If no JSON array can be extracted from the reply.
    """
    # Rough estimate of how many candidates a text of this size should yield.
    text_len = len(text)
    expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))

    # The template contains a literal JSON example full of braces, so
    # str.format() would raise on it; substitute the single placeholder instead.
    prompt = _PROMPT.replace("{text}", text[:16000])  # cap the context length

    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.post(
            f"{DEEPSEEK_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
            json={
                "model": DEEPSEEK_MODEL,
                "messages": [
                    {"role": "system", "content": "你是一个专业的学习内容分析师。请始终返回有效的 JSON 数组。"},
                    {"role": "user", "content": prompt},
                ],
                "temperature": 0.3,
                "max_tokens": 4096,
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"DeepSeek API error: {resp.status_code} {resp.text}")

    data = resp.json()
    raw = data["choices"][0]["message"]["content"]

    return _parse_json_response(raw, expected_count)


def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
    """Extract a JSON array of candidate dicts from a model reply.

    Tries, in order: the whole reply, the first fenced code block, and the
    widest ``[...]`` span. Non-dict list items are dropped because callers
    treat every candidate as a mapping.

    Args:
        raw: Raw reply text from the model.
        expected_count: Accepted for interface compatibility; not used to
            limit the result.

    Returns:
        Up to ``MAX_CANDIDATES`` dicts.

    Raises:
        ValueError: If none of the strategies yields a JSON list.
    """
    import re

    reply = raw or ""
    attempts = [reply]

    fenced = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", reply, re.DOTALL)
    if fenced:
        attempts.append(fenced.group(1))

    bracketed = re.search(r"\[.*\]", reply, re.DOTALL)
    if bracketed:
        attempts.append(bracketed.group(0))

    for snippet in attempts:
        try:
            parsed = json.loads(snippet)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list):
            return [item for item in parsed if isinstance(item, dict)][:MAX_CANDIDATES]

    raise ValueError(f"无法解析 AI 候选知识点回复: {reply[:500]}")
list[str]: + """按中文标点分句,保留标点在句尾""" + parts = _CN_SENT_PATTERN.split(text) + sentences = [] + buf = "" + for p in parts: + if not p: + continue + buf += p + if _CN_SENT_PATTERN.match(p): + sentences.append(buf) + buf = "" + if buf.strip(): + sentences.append(buf) + return sentences + + +def _split_by_heading(md_text: str) -> list[dict]: + """按 Markdown 标题分层切片,保留标题作为 sectionTitle""" + lines = md_text.split("\n") + chunks = [] + current_title = "" + current_text = "" + + for line in lines: + m = _MD_HEADING.match(line) + if m: + # 保存前一段 + if current_text.strip(): + chunks.append({"sectionTitle": current_title, "text": current_text.strip()}) + current_title = line.strip() + current_text = "" + else: + current_text += line + "\n" + + if current_text.strip(): + chunks.append({"sectionTitle": current_title, "text": current_text.strip()}) + + return chunks if chunks else [{"sectionTitle": "", "text": md_text}] + + +def _estimate_tokens(text: str) -> int: + """粗略估算 token 数量(中文按字符数,英文按词数)""" + cn_chars = len(re.findall(r"[一-鿿]", text)) + en_words = len(re.findall(r"[a-zA-Z]+", text)) + # 中文约 1.5 字符/token,英文约 1 词/token + return int(cn_chars / 1.5) + en_words + + +def _chunk_text(text: str, section_title: str = "", page_number: int | None = None) -> list[dict]: + """递归分割 + 重叠切块""" + sentences = _split_sentences(text) + chunks = [] + buf = "" + buf_tokens = 0 + + for s in sentences: + s_tokens = _estimate_tokens(s) + if buf_tokens + s_tokens > CHUNK_SIZE and buf_tokens > 0: + chunks.append({"content": buf.strip(), "sectionTitle": section_title, "pageNumber": page_number}) + # 重叠:保留最后 overlap tokens + if CHUNK_OVERLAP > 0: + overlap_text = buf[-int(CHUNK_OVERLAP * 2):] # 粗略估算 + buf = overlap_text + s + buf_tokens = _estimate_tokens(overlap_text) + s_tokens + else: + buf = s + buf_tokens = s_tokens + else: + buf += s + buf_tokens += s_tokens + + if buf.strip(): + chunks.append({"content": buf.strip(), "sectionTitle": section_title, "pageNumber": page_number}) + + return chunks + 
def chunk_document(text: str, source_type: str = "text") -> list[dict]:
    """Split a document into chunks and label each one.

    Markdown sources (``"md"`` / ``"markdown"``) are first split by heading;
    every other source type is treated as a single untitled section.

    Returns:
        Chunk dicts with ``content``, ``sectionTitle``, ``pageNumber``,
        ``chunkIndex`` and ``chunkType`` (``"table"``, ``"code"`` or ``"text"``).
    """
    is_markdown = source_type in ("md", "markdown")
    sections = _split_by_heading(text) if is_markdown else [{"sectionTitle": "", "text": text}]

    all_chunks = [
        piece
        for section in sections
        for piece in _chunk_text(section["text"], section_title=section.get("sectionTitle", ""))
    ]

    for index, piece in enumerate(all_chunks):
        piece["chunkIndex"] = index
        body = piece["content"]
        # Heuristics: pipe-heavy text with a separator row is a table;
        # anything containing a code fence is code.
        if body.count("|") > 5 and "---" in body:
            piece["chunkType"] = "table"
        elif "```" in body:
            piece["chunkType"] = "code"
        else:
            piece["chunkType"] = "text"

    return all_chunks
async def embed_single(text: str) -> list[float]:
    """Embed a single text and return its vector.

    Raises:
        RuntimeError: If the embedding API answers with a non-200 status.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/embeddings",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json={
                "model": EMBEDDING_MODEL,
                "input": [text],
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"Embedding API error: {resp.status_code} {resp.text}")
    data = resp.json()
    return data["data"][0]["embedding"]


async def _embed_with_retry(client, batch: list[str]) -> list[list[float]]:
    """Embed one batch, retrying transient failures with exponential backoff."""
    last_error = "unknown error"
    for attempt in range(MAX_RETRIES + 1):
        if attempt:
            await asyncio.sleep(2 ** (attempt - 1))  # 1s, 2s, ...
        try:
            resp = await client.post(
                f"{SILICONFLOW_BASE_URL}/embeddings",
                headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
                json={
                    "model": EMBEDDING_MODEL,
                    "input": batch,
                },
            )
        except httpx.TransportError as exc:
            last_error = f"{type(exc).__name__}: {exc}"
            continue

        if resp.status_code == 200:
            try:
                rows = resp.json()["data"]
                # Sort by the per-item index when present so vectors line up
                # with the inputs; without it the response order is kept.
                rows = sorted(rows, key=lambda row: row.get("index", 0))
                vectors = [row["embedding"] for row in rows]
            except (KeyError, TypeError, ValueError, AttributeError) as exc:
                raise RuntimeError(f"Embedding batch failed: malformed response ({exc})") from exc
            if len(vectors) != len(batch):
                raise RuntimeError(
                    f"Embedding batch failed: expected {len(batch)} vectors, got {len(vectors)}"
                )
            return vectors

        last_error = f"Status {resp.status_code} {resp.text[:200]}"
        if resp.status_code != 429 and resp.status_code < 500:
            break  # other client errors will not be fixed by retrying

    raise RuntimeError(f"Embedding batch failed: {last_error}")


async def embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed many texts, in batches of ``BATCH_SIZE``, with retries.

    Returns:
        One vector per input text, in input order. Empty input gives ``[]``.

    Raises:
        RuntimeError: If a batch still fails after ``MAX_RETRIES`` retries,
            is rejected with a non-retryable status, or returns a malformed
            or incomplete response.
    """
    if not texts:
        return []

    all_embeddings: list[list[float]] = []
    # One client for the whole run so connections are reused across batches.
    async with httpx.AsyncClient(timeout=60) as client:
        for start in range(0, len(texts), BATCH_SIZE):
            batch = texts[start:start + BATCH_SIZE]
            all_embeddings.extend(await _embed_with_retry(client, batch))
    return all_embeddings
async def upsert_points(points: list[dict]):
    """Write points to the Qdrant collection and wait for the write.

    Raises:
        RuntimeError: If Qdrant answers with a non-200 status.
    """
    if not points:
        return  # nothing to write
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.put(
            f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points",
            params={"wait": "true"},
            json={"points": points},
        )
    if resp.status_code != 200:
        raise RuntimeError(f"Qdrant upsert failed: {resp.text}")


async def search(
    vector: list[float],
    user_id: str,
    knowledge_base_id: str,
    top_k: int = 5,
) -> list[dict]:
    """Run a semantic search limited to one user's knowledge base.

    Points whose payload has ``deleted`` set to true are filtered out.

    Returns:
        The ``result`` list of the Qdrant response, payloads included.

    Raises:
        RuntimeError: If Qdrant answers with a non-200 status.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
            json={
                "vector": vector,
                "filter": {
                    "must": [
                        {"key": "userId", "match": {"value": user_id}},
                        {"key": "knowledgeBaseId", "match": {"value": knowledge_base_id}},
                        {"key": "deleted", "match": {"value": False}},
                    ],
                },
                "limit": top_k,
                "with_payload": True,
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"Qdrant search failed: {resp.text}")
    return resp.json()["result"]


async def mark_deleted(source_id: str):
    """Set ``deleted=true`` on every point that belongs to ``source_id``.

    Uses Qdrant's set-payload endpoint (``POST .../points/payload``) with a
    filter selector.

    Raises:
        RuntimeError: If Qdrant answers with a non-200 status.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/payload",
            params={"wait": "true"},
            json={
                "payload": {"deleted": True},
                "filter": {
                    "must": [
                        {"key": "sourceId", "match": {"value": source_id}},
                    ]
                },
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"Qdrant mark_deleted failed: {resp.text}")
def shutdown(sig, frame):
    """Signal handler: ask the loops to stop after the current iteration."""
    global running
    print(f"[{WORKER_ID}] 收到信号 {sig},正在退出...")
    running = False


signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)


async def heartbeat_loop():
    """Worker-level heartbeat placeholder.

    Currently it only sleeps; the real reporting is done per job in
    ``_per_job_heartbeat``.
    """
    while running:
        await asyncio.sleep(HEARTBEAT_INTERVAL)


async def work_loop():
    """Main loop: poll for a job, claim it, run the import, repeat."""
    print(f"[{WORKER_ID}] RAG Worker 已启动")

    while running:
        try:
            job = await get_next_job()
            if not job:
                await asyncio.sleep(POLL_INTERVAL)
                continue

            job_id = job.get("id") or job.get("jobId")
            if not job_id or not await claim_job(job_id):
                # Back off before polling again; retrying immediately would
                # hammer the API while the same job keeps being returned.
                await asyncio.sleep(POLL_INTERVAL)
                continue

            print(f"[{WORKER_ID}] 开始处理任务 {job_id}")

            # Keep the claim alive in the background while the import runs.
            hb_task = asyncio.create_task(_per_job_heartbeat(job_id))

            try:
                await run_import(job)
                print(f"[{WORKER_ID}] 任务 {job_id} 完成")
            except Exception as e:
                print(f"[{WORKER_ID}] 任务 {job_id} 失败: {e}")
                await update_job_status(job_id, "FAILED_RETRYABLE", {
                    "errorMessage": str(e)[:500],
                })
            finally:
                hb_task.cancel()
                # Wait for the cancellation so the task does not linger.
                await asyncio.gather(hb_task, return_exceptions=True)

        except Exception as e:
            print(f"[{WORKER_ID}] 轮询异常: {e}")
            await asyncio.sleep(POLL_INTERVAL)

    print(f"[{WORKER_ID}] Worker 已停止")


async def _per_job_heartbeat(job_id: str):
    """Send a heartbeat for one job every ``HEARTBEAT_INTERVAL`` seconds.

    Heartbeats are best-effort: a failure is logged and the loop continues.
    """
    while running:
        try:
            await heartbeat(job_id)
        except Exception as e:
            print(f"[{WORKER_ID}] 任务 {job_id} 心跳失败: {e}")
        await asyncio.sleep(HEARTBEAT_INTERVAL)


if __name__ == "__main__":
    asyncio.run(work_loop())
_IMAGE_MIME_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".webp": "image/webp",
    ".heic": "image/heic",
    ".bmp": "image/bmp",
}


async def download_file(url: str, local_path: str) -> str:
    """Download a file from a pre-signed URL to ``local_path``.

    The body is streamed to disk instead of being held in memory.

    Returns:
        ``local_path``.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    parent = os.path.dirname(local_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
        async with client.stream("GET", url) as resp:
            resp.raise_for_status()
            with open(local_path, "wb") as f:
                async for piece in resp.aiter_bytes():
                    f.write(piece)
    return local_path


def parse_txt(file_path: str) -> str:
    """Read a UTF-8 text file, replacing undecodable bytes."""
    with open(file_path, "r", encoding="utf-8", errors="replace") as f:
        return f.read()


def parse_markdown(file_path: str) -> str:
    """Read a Markdown file as plain text."""
    return parse_txt(file_path)


def parse_docx(file_path: str) -> str:
    """Return the non-empty paragraphs of a DOCX file, blank-line separated."""
    from docx import Document
    doc = Document(file_path)
    return "\n\n".join(p.text for p in doc.paragraphs if p.text.strip())


def parse_pdf_text(file_path: str) -> str:
    """Extract the text layer of a PDF with PyMuPDF, skipping empty pages."""
    import fitz
    with fitz.open(file_path) as doc:
        pages = [text for text in (page.get_text() for page in doc) if text.strip()]
    return "\n\n".join(pages)


def pdf_needs_ocr(file_path: str) -> bool:
    """Tell whether a PDF looks scanned (fewer than 50 characters per page)."""
    import fitz
    with fitz.open(file_path) as doc:
        # Read everything before the document is closed.
        page_count = max(doc.page_count, 1)
        total_len = sum(len(page.get_text().strip()) for page in doc)
    return (total_len / page_count) < 50


async def ocr_with_siliconflow(image_bytes: bytes, mime_type: str = "image/png") -> str:
    """Run OCR on an image with the SiliconFlow multimodal model.

    Args:
        image_bytes: Encoded image data.
        mime_type: Media type placed in the data URL.

    Raises:
        RuntimeError: If the API answers with a non-200 status.
    """
    b64 = base64.b64encode(image_bytes).decode()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json={
                "model": "Qwen/Qwen3-VL-32B-Instruct",
                "messages": [{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "请识别并提取这张图片中的所有文字内容。如果有表格,请用 Markdown 表格格式输出。不要添加任何解释。"},
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}},
                    ],
                }],
                "max_tokens": 4096,
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"OCR API error: {resp.status_code} {resp.text}")
    data = resp.json()
    return data["choices"][0]["message"]["content"]


async def parse_image_with_ocr(file_path: str) -> str:
    """OCR an image file, labelling the upload by its file extension."""
    ext = os.path.splitext(file_path)[1].lower()
    with open(file_path, "rb") as f:
        image_bytes = f.read()
    return await ocr_with_siliconflow(image_bytes, _IMAGE_MIME_TYPES.get(ext, "image/png"))


def _frame_to_markdown(df) -> str:
    """Render a DataFrame as a Markdown pipe table.

    Built by hand so the optional ``tabulate`` package is not required.
    """
    import pandas as pd

    def cell(value) -> str:
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        # Keep every cell on one line and escape the column delimiter.
        return " ".join(str(value).splitlines()).replace("|", "\\|")

    if len(df.columns) == 0:
        return ""
    header = "| " + " | ".join(cell(col) for col in df.columns) + " |"
    divider = "| " + " | ".join("---" for _ in df.columns) + " |"
    rows = [
        "| " + " | ".join(cell(v) for v in row) + " |"
        for row in df.itertuples(index=False, name=None)
    ]
    return "\n".join([header, divider, *rows])


def parse_csv(file_path: str) -> str:
    """Convert a CSV file into a Markdown table."""
    import pandas as pd
    return _frame_to_markdown(pd.read_csv(file_path))


def parse_xlsx(file_path: str) -> str:
    """Convert an Excel workbook into Markdown tables.

    A single-sheet workbook yields just its table. With several sheets each
    table is preceded by a ``## <sheet name>`` heading.
    """
    import pandas as pd
    sheets = pd.read_excel(file_path, sheet_name=None)
    if len(sheets) == 1:
        return _frame_to_markdown(next(iter(sheets.values())))
    return "\n\n".join(
        f"## {name}\n\n{_frame_to_markdown(frame)}" for name, frame in sheets.items()
    )


async def _ocr_pdf_pages(file_path: str, dpi: int = 150) -> str:
    """Render every PDF page to PNG and OCR it, joining the page texts."""
    import fitz
    results = []
    with fitz.open(file_path) as doc:
        for page in doc:
            img_bytes = page.get_pixmap(dpi=dpi).tobytes("png")
            results.append(await ocr_with_siliconflow(img_bytes))
    return "\n\n".join(results)


async def parse_document(file_path: str, mime_type: str) -> str:
    """Route a file to the matching parser based on its extension.

    Args:
        file_path: Local path of the downloaded file.
        mime_type: Accepted for interface compatibility; routing currently
            uses only the file extension.

    Raises:
        ValueError: If the extension is not supported.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext == ".txt":
        return parse_txt(file_path)
    if ext in (".md", ".markdown"):
        return parse_markdown(file_path)
    if ext == ".docx":
        return parse_docx(file_path)
    if ext == ".csv":
        return parse_csv(file_path)
    if ext == ".xlsx":
        return parse_xlsx(file_path)
    if ext == ".pdf":
        text = parse_pdf_text(file_path)
        # Only fall back to page-by-page OCR when the PDF looks scanned and
        # the text layer is practically empty.
        if pdf_needs_ocr(file_path) and len(text.strip()) < 100:
            return await _ocr_pdf_pages(file_path)
        return text
    if ext in _IMAGE_MIME_TYPES:
        return await parse_image_with_ocr(file_path)
    raise ValueError(f"不支持的文件类型: {ext}")
def _pick(mapping: dict, *keys: str, default=None):
    """Return the first truthy value found under ``keys``, else ``default``."""
    for key in keys:
        value = mapping.get(key)
        if value:
            return value
    return default


async def run_import(job: dict):
    """Run the full import pipeline for one claimed job.

    Steps: download, parse, chunk, embed, index in Qdrant, store chunk
    records, generate candidate knowledge points. Progress and failures are
    reported through ``update_job_status``.

    Args:
        job: Job dict as returned by the internal API (camelCase or
            snake_case keys).

    Raises:
        ValueError: If the job lacks an id or source id, or the document
            yields too little text or no chunks.
        RuntimeError: If the number of vectors does not match the chunks.
        Exception: Any pipeline error is re-raised after being reported as
            ``FAILED_RETRYABLE``.
    """
    job_id = _pick(job, "id", "jobId")
    if not job_id:
        raise ValueError("任务缺少 id")
    source_id = _pick(job, "sourceId", "source_id")
    user_id = _pick(job, "userId", "user_id")
    kb_id = _pick(job, "knowledgeBaseId", "knowledge_base_id")

    if not source_id:
        raise ValueError(f"任务 {job_id} 缺少 sourceId")

    # Fetch source details from the API; "source" may be missing or null.
    detail = await get_job_detail(job_id) or {}
    source = detail.get("source") or {}
    mime_type = _pick(source, "mimeType", "mime_type", default="text/plain")
    original_filename = _pick(source, "originalFilename", "original_filename", default="unknown")
    # Keep only the final path component so an uploaded name cannot point
    # outside the job's temp directory.
    safe_name = os.path.basename(str(original_filename).replace("\\", "/")) or "unknown"

    tmp_dir = f"/data/tmp/imports/{job_id}"
    file_path = os.path.join(tmp_dir, safe_name)

    try:
        # 1. Download
        await update_job_status(job_id, "DOWNLOADING", {"progress": 5})
        file_url = source.get("downloadUrl") or detail.get("downloadUrl", "")
        if file_url:
            await download_file(file_url, file_path)

        # 2. Parse. Without a downloaded file (raw-text import) there is
        # nothing to parse, so the raw text is used instead.
        await update_job_status(job_id, "PARSING", {"progress": 20})
        text = await parse_document(file_path, mime_type) if file_url else ""
        if not text:
            text = (
                job.get("rawText")
                or (detail.get("job") or {}).get("rawText")
                or source.get("rawText")
                or ""
            )

        if not text or len(text.strip()) < 10:
            raise ValueError("文档解析后内容过少,可能为空白或损坏文件")

        # 3. Clean (no transformation yet; only the status is reported)
        await update_job_status(job_id, "CLEANING", {"progress": 40, "textLength": len(text)})

        # 4. Chunk
        await update_job_status(job_id, "CHUNKING", {"progress": 50})
        source_type = source.get("type") or "text"
        chunks = chunk_document(text, source_type)
        if not chunks:
            raise ValueError("文档解析后内容过少,可能为空白或损坏文件")

        # 5. Embed
        await update_job_status(job_id, "EMBEDDING", {"progress": 60})
        vectors = await embed_batch([c["content"] for c in chunks])
        if len(vectors) != len(chunks):
            raise RuntimeError(
                f"Embedding count mismatch: {len(vectors)} vectors for {len(chunks)} chunks"
            )

        # 6. Index in Qdrant
        await update_job_status(job_id, "INDEXING", {"progress": 80})
        points = []
        chunk_records = []
        for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
            chunk_id = f"chunk_{source_id}_{i}"
            # Qdrant only accepts unsigned integers or UUIDs as point ids;
            # derive a stable UUID so re-imports overwrite the same points.
            point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, chunk_id))
            points.append({
                "id": point_id,
                "vector": vec,
                "payload": {
                    "userId": user_id,
                    "knowledgeBaseId": kb_id,
                    "sourceId": source_id,
                    "chunkId": chunk_id,
                    "pageNumber": chunk.get("pageNumber"),
                    "sectionTitle": chunk.get("sectionTitle", ""),
                    "deleted": False,
                },
            })
            chunk_records.append({
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "content": chunk["content"],
                "chunkIndex": chunk["chunkIndex"],
                "pageNumber": chunk.get("pageNumber"),
                "sectionTitle": chunk.get("sectionTitle", ""),
                # NOTE: this is a character count, not a token count.
                "tokenCount": len(chunk["content"]),
                "externalVectorId": point_id,
                "embeddingModel": "bge-m3",
                "embeddingStatus": "COMPLETED",
                "metadataJson": {"chunkType": chunk.get("chunkType", "text")},
            })

        await upsert_points(points)
        await save_chunks(chunk_records)

        # 7. Generate candidate knowledge points
        await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
        candidates = await generate_candidates(text)
        if candidates:
            await save_candidates(user_id, kb_id, source_id, job_id, candidates)

        # 8. Done
        await update_job_status(job_id, "COMPLETED", {"progress": 100})

    except Exception as e:
        # Reporting is best-effort: a failure here must not replace the
        # original error that the caller needs to see.
        try:
            await update_job_status(job_id, "FAILED_RETRYABLE", {
                "errorCode": "WORKER_ERROR",
                "errorMessage": str(e)[:500],
            })
        except Exception as report_err:
            print(f"任务 {job_id} 状态上报失败: {report_err}")
        raise

    finally:
        # Remove the job's temp directory (a missing directory is ignored).
        import shutil
        shutil.rmtree(tmp_dir, ignore_errors=True)
/**
 * Internal HTTP API used by the RAG worker to poll, claim and update
 * document-import jobs and to store chunks and candidate knowledge points.
 *
 * NOTE(review): no guard or secret check is visible on this controller;
 * confirm the worker's Bearer secret is validated elsewhere.
 * NOTE(review): the @Post handlers below carry no @HttpCode, so presumably
 * Nest's default POST status (201) applies; confirm clients do not expect 200.
 */
@ApiTags('internal-rag')
@Controller('internal/rag')
export class InternalRagController {
  constructor(
    private readonly importRepo: DocumentImportRepository,
    private readonly sourceRepo: KnowledgeSourceRepository,
    private readonly candidateRepo: ImportCandidateRepository,
    private readonly prisma: PrismaService,
  ) {}

  /**
   * Returns the next job for a worker, or `{ job: null }` when none is found.
   *
   * NOTE(review): the inline comment says "look up only", but the repository
   * method called is `claimNext` with an empty worker id; confirm it does not
   * already claim the job.
   */
  @Get('jobs/next')
  async getNextJob() {
    const job = await this.importRepo.claimNext(''); // look up first, do not claim
    if (!job) return { job: null };
    return {
      job: {
        id: job.id,
        userId: job.userId,
        knowledgeBaseId: job.knowledgeBaseId,
        sourceId: job.sourceId,
        fileId: job.fileId,
        sourceType: job.sourceType,
        sourceName: job.sourceName,
        rawText: job.rawText,
        status: job.status,
      },
    };
  }

  /**
   * Returns a job together with its knowledge source (or `source: null`).
   *
   * NOTE(review): `downloadUrl` is declared but never assigned or returned,
   * while the worker's import pipeline reads `source.downloadUrl` /
   * `downloadUrl` from this response to fetch the file.
   */
  @Get('jobs/:id')
  async getJobDetail(@Param('id') id: string) {
    const job = await this.importRepo.findById(id);
    if (!job) return { job: null };

    let source = null;
    let downloadUrl = null;
    if (job.sourceId) {
      source = await this.sourceRepo.findById(job.sourceId);
    }

    return {
      job: {
        id: job.id,
        userId: job.userId,
        knowledgeBaseId: job.knowledgeBaseId,
        sourceId: job.sourceId,
        fileId: job.fileId,
        rawText: job.rawText,
        status: job.status,
      },
      source: source ? {
        id: source.id,
        type: source.type,
        originalFilename: source.originalFilename,
        mimeType: source.mimeType,
        sizeBytes: Number(source.sizeBytes),
        originalObjectKey: source.originalObjectKey,
      } : null,
    };
  }

  /**
   * Claims a job for a worker; `success` is true when a row was updated.
   * The worker id is read from the JSON body only (falls back to 'unknown');
   * request headers are not consulted.
   */
  @Post('jobs/:id/claim')
  async claimJob(@Param('id') id: string, @Body() body: { workerId?: string }) {
    const workerId = body.workerId || 'unknown';
    const result = await this.importRepo.claim(id, workerId);
    return { success: result.count > 0 };
  }

  /** Records a heartbeat for the job. */
  @Post('jobs/:id/heartbeat')
  async heartbeat(@Param('id') id: string) {
    await this.importRepo.heartbeat(id);
    return { success: true };
  }

  /**
   * Updates the job status; the status string is also stored as the step.
   * Only `progress`, `errorCode` and `errorMessage` are forwarded; any other
   * body fields are ignored.
   */
  @Post('jobs/:id/status')
  async updateStatus(
    @Param('id') id: string,
    @Body() body: { status: string; progress?: number; errorCode?: string; errorMessage?: string },
  ) {
    await this.importRepo.updateStatus(id, body.status, {
      step: body.status,
      progress: body.progress,
      errorCode: body.errorCode,
      errorMessage: body.errorMessage,
    });
    return { success: true };
  }

  /**
   * Bulk-inserts knowledge chunks.
   *
   * NOTE(review): chunk objects are passed to `createMany` without
   * validation; confirm the worker is the only caller.
   */
  @Post('chunks')
  async saveChunks(@Body() body: { chunks: any[] }) {
    const chunks = body.chunks || [];
    if (chunks.length > 0) {
      await this.prisma.knowledgeChunk.createMany({ data: chunks });
    }
    return { success: true, count: chunks.length };
  }

  /** Stores the candidate knowledge points generated for an import. */
  @Post('candidates')
  async saveCandidates(
    @Body() body: {
      userId: string;
      knowledgeBaseId: string;
      sourceId: string;
      importId: string;
      candidates: any[];
    },
  ) {
    await this.candidateRepo.createMany(
      body.userId,
      body.knowledgeBaseId,
      body.sourceId,
      body.importId,
      body.candidates || [],
    );
    return { success: true, count: body.candidates?.length || 0 };
  }
}
/**
 * Wires the internal RAG controller to the document-import,
 * knowledge-source and import-candidate modules.
 *
 * NOTE(review): InternalRagController also injects PrismaService, which is
 * neither imported nor provided here; presumably it comes from a global
 * module. Confirm.
 */
@Module({
  imports: [DocumentImportModule, KnowledgeSourceModule, ImportCandidateModule],
  controllers: [InternalRagController],
})
export class RagModule {}