feat: Python RAG Worker + NestJS 内部 API(文档解析/切片/embedding/Qdrant/候选生成)
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 22s
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 22s
This commit is contained in:
parent
c149b96b04
commit
fbdae9078f
94
rag-worker/api_client.py
Normal file
94
rag-worker/api_client.py
Normal file
@ -0,0 +1,94 @@
|
||||
import httpx
|
||||
from config import API_BASE_URL, RAG_WORKER_SECRET, WORKER_ID
|
||||
|
||||
# Headers attached to every internal API request made by this module: the
# shared worker secret as a bearer token, plus this worker's identifier.
_auth_headers = {
    "Authorization": f"Bearer {RAG_WORKER_SECRET}",
    "X-Worker-Id": WORKER_ID,
}
|
||||
|
||||
|
||||
async def get_next_job() -> dict | None:
    """Fetch the next queued import job from the internal API.

    Returns:
        The value stored under the ``data`` key of the JSON response (or,
        when that is falsy, under the ``job`` key). ``None`` is returned
        when the API answers with any status other than 200.
    """
    url = f"{API_BASE_URL}/internal/rag/jobs/next"
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(url, headers=_auth_headers)
        if resp.status_code != 200:
            return None
        payload = resp.json()
        return payload.get("data") or payload.get("job")
|
||||
|
||||
|
||||
async def claim_job(job_id: str) -> bool:
    """Try to claim an import job for this worker.

    The worker id is sent in the JSON body (the claim endpoint reads
    ``workerId`` from the body) in addition to the ``X-Worker-Id`` header.

    Args:
        job_id: Identifier of the job to claim.

    Returns:
        ``True`` when the API accepted the request and did not report
        ``success: false``; ``False`` otherwise.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim",
            headers=_auth_headers,
            json={"workerId": WORKER_ID},
        )

    # NestJS answers POST handlers with 201 by default, so accept any 2xx
    # instead of comparing against 200 only.
    if not resp.is_success:
        return False

    try:
        payload = resp.json()
    except ValueError:
        # 2xx without a JSON body: treat the claim as accepted.
        return True

    if isinstance(payload, dict):
        # Tolerate an optional {"data": {...}} envelope around the result.
        inner = payload.get("data")
        result = inner if isinstance(inner, dict) else payload
        # The endpoint reports a lost race as {"success": false} with a 2xx status.
        return bool(result.get("success", True))
    return True
|
||||
|
||||
|
||||
async def heartbeat(job_id: str) -> bool:
    """Report that this worker is still processing ``job_id``.

    Args:
        job_id: Identifier of the job being processed.

    Returns:
        ``True`` when the API answered with a 2xx status, ``False`` otherwise.
    """
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}/heartbeat",
            headers=_auth_headers,
        )
    # NestJS POST handlers reply 201 by default; a strict ``== 200`` check
    # would report every successful heartbeat as a failure.
    return resp.is_success
|
||||
|
||||
|
||||
async def update_job_status(job_id: str, status: str, data: dict | None = None):
    """Post a status update for an import job.

    Args:
        job_id: Identifier of the job to update.
        status: New status value.
        data: Optional extra fields merged into the request body; a key named
            ``status`` in ``data`` overrides the ``status`` argument.

    The response is not inspected.
    """
    body = {"status": status}
    if data:
        body.update(data)

    endpoint = f"{API_BASE_URL}/internal/rag/jobs/{job_id}/status"
    async with httpx.AsyncClient(timeout=30) as client:
        await client.post(endpoint, headers=_auth_headers, json=body)
|
||||
|
||||
|
||||
async def save_chunks(chunks: list[dict]):
    """Persist a batch of KnowledgeChunk records through the internal API.

    Args:
        chunks: Chunk records to store.

    Raises:
        httpx.HTTPStatusError: If the API answers with a 4xx/5xx status, so a
            failed write is not mistaken for success by the caller.
    """
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/chunks",
            headers=_auth_headers,
            json={"chunks": chunks},
        )
    resp.raise_for_status()
|
||||
|
||||
|
||||
async def save_candidates(
    user_id: str,
    kb_id: str,
    source_id: str,
    import_id: str,
    candidates: list[dict],
):
    """Persist generated candidate knowledge points through the internal API.

    Args:
        user_id: Sent as ``userId``.
        kb_id: Sent as ``knowledgeBaseId``.
        source_id: Sent as ``sourceId``.
        import_id: Sent as ``importId``.
        candidates: Candidate dicts to store.

    Raises:
        httpx.HTTPStatusError: If the API answers with a 4xx/5xx status, so a
            failed write is not mistaken for success by the caller.
    """
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/candidates",
            headers=_auth_headers,
            json={
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "importId": import_id,
                "candidates": candidates,
            },
        )
    resp.raise_for_status()
|
||||
|
||||
|
||||
async def get_job_detail(job_id: str) -> dict | None:
    """Fetch the detail record of a job (including its source information).

    Returns:
        The decoded JSON body on a 200 response, otherwise ``None``.
    """
    endpoint = f"{API_BASE_URL}/internal/rag/jobs/{job_id}"
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(endpoint, headers=_auth_headers)
        return resp.json() if resp.status_code == 200 else None
|
||||
103
rag-worker/candidate_generator.py
Normal file
103
rag-worker/candidate_generator.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""候选知识点生成:调用 DeepSeek 分析文本,生成 ImportCandidate"""
|
||||
|
||||
import json
|
||||
import httpx
|
||||
from config import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL
|
||||
|
||||
# Upper bound on the number of candidates kept from a single AI reply.
MAX_CANDIDATES = 30
# Lower bound used when estimating how many candidates a document should yield.
MIN_CANDIDATES = 3
# Source characters per expected candidate, used for the same estimate.
CHARS_PER_CANDIDATE = 2000
|
||||
|
||||
_PROMPT = """你是一个学习助手。请分析以下文档内容,提取关键知识点。
|
||||
|
||||
对于每个知识点,请提供:
|
||||
- title: 知识点标题(简洁,不超过 30 字)
|
||||
- summary: 一句话概述(不超过 80 字)
|
||||
- content: 详细解释(基于原文,保持准确)
|
||||
- tags: 2-4 个标签
|
||||
- recallQuestions: 1-2 个主动回忆问题
|
||||
- difficulty: 难度评估(easy/medium/hard)
|
||||
- confidence: 你对这个知识点重要性的置信度(0.0-1.0)
|
||||
|
||||
请以 JSON 数组格式返回,每个元素是一个知识点:
|
||||
```json
|
||||
[{
|
||||
"title": "知识点标题",
|
||||
"summary": "一句话概述",
|
||||
"content": "详细解释...",
|
||||
"tags": ["标签1", "标签2"],
|
||||
"recallQuestions": ["问题1?", "问题2?"],
|
||||
"difficulty": "medium",
|
||||
"confidence": 0.85
|
||||
}]
|
||||
```
|
||||
|
||||
文档内容:
|
||||
{text}
|
||||
"""
|
||||
|
||||
|
||||
async def generate_candidates(text: str) -> list[dict]:
    """Ask the DeepSeek chat API to extract candidate knowledge points.

    Args:
        text: Full document text; only the first 16000 characters are sent.

    Returns:
        The list parsed from the model reply (see ``_parse_json_response``).

    Raises:
        RuntimeError: If the API answers with a non-200 status.
        ValueError: If no JSON array can be extracted from the reply.
    """
    # Rough estimate of how many candidates a document of this size should yield.
    expected_count = max(
        MIN_CANDIDATES,
        min(MAX_CANDIDATES, len(text) // CHARS_PER_CANDIDATE),
    )

    # The template embeds a literal JSON example, whose braces str.format()
    # would try to interpret as replacement fields (raising KeyError/ValueError).
    # Substitute the placeholder directly instead.
    prompt = _PROMPT.replace("{text}", text[:16000])  # cap the context length

    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.post(
            f"{DEEPSEEK_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
            json={
                "model": DEEPSEEK_MODEL,
                "messages": [
                    {"role": "system", "content": "你是一个专业的学习内容分析师。请始终返回有效的 JSON 数组。"},
                    {"role": "user", "content": prompt},
                ],
                "temperature": 0.3,
                "max_tokens": 4096,
            },
        )

    if resp.status_code != 200:
        raise RuntimeError(f"DeepSeek API error: {resp.status_code} {resp.text}")

    data = resp.json()
    raw = data["choices"][0]["message"]["content"]

    # Extract the JSON array from the (possibly decorated) reply.
    return _parse_json_response(raw, expected_count)
|
||||
|
||||
|
||||
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
    """Extract a JSON array from an AI reply.

    Three strategies are tried in order: parse the whole reply, parse the
    contents of the first fenced code block, parse the span from the first
    ``[`` to the last ``]``. The first strategy that yields a list wins, and
    the list is truncated to ``MAX_CANDIDATES`` items.

    Args:
        raw: Reply text from the model.
        expected_count: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If none of the strategies yields a JSON list.
    """
    import re

    def _load_list(payload):
        """Return ``payload`` decoded as a truncated list, or ``None``."""
        try:
            decoded = json.loads(payload)
        except json.JSONDecodeError:
            return None
        if not isinstance(decoded, list):
            return None
        return decoded[:MAX_CANDIDATES]

    # 1. The reply is already pure JSON.
    parsed = _load_list(raw)
    if parsed is not None:
        return parsed

    # 2. A fenced ```json ... ``` block.
    fenced = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL)
    if fenced:
        parsed = _load_list(fenced.group(1))
        if parsed is not None:
            return parsed

    # 3. Anything between the first "[" and the last "]".
    bracketed = re.search(r"\[.*\]", raw, re.DOTALL)
    if bracketed:
        parsed = _load_list(bracketed.group(0))
        if parsed is not None:
            return parsed

    raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")
|
||||
120
rag-worker/chunker.py
Normal file
120
rag-worker/chunker.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""文本切片:递归字符分割 + 中文分句保护"""
|
||||
|
||||
import re
|
||||
from config import CHUNK_SIZE, CHUNK_OVERLAP
|
||||
|
||||
|
||||
# Sentence-boundary pattern. The single capture group makes re.split() keep
# each delimiter in its result. A "." only counts as a boundary when it is not
# directly preceded or directly followed by a digit.
_CN_SENT_PATTERN = re.compile(
    r"([。!?;\n]|(?<!\d)\.(?!\d)|!\?|\?!)"
)
# Markdown heading marker: one to six "#" at the start of a line, then whitespace.
_MD_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
|
||||
|
||||
|
||||
def _split_sentences(text: str) -> list[str]:
    """Split ``text`` into sentences, keeping each delimiter at the sentence end.

    Text after the last delimiter is returned as a final sentence unless it is
    whitespace only.
    """
    sentences = []
    pending = []
    for piece in filter(None, _CN_SENT_PATTERN.split(text)):
        pending.append(piece)
        # A piece that is itself a delimiter closes the current sentence.
        if _CN_SENT_PATTERN.match(piece):
            sentences.append("".join(pending))
            pending = []

    tail = "".join(pending)
    if tail.strip():
        sentences.append(tail)
    return sentences
|
||||
|
||||
|
||||
def _split_by_heading(md_text: str) -> list[dict]:
    """Split Markdown into sections at heading lines.

    Each section is ``{"sectionTitle": <stripped heading line>, "text": <body>}``.
    Text before the first heading gets an empty title, and sections whose body
    is blank are dropped. When no section survives, the whole input is
    returned as a single untitled section.
    """
    sections = []
    title = ""
    body_lines = []

    for line in md_text.split("\n"):
        if _MD_HEADING.match(line) is None:
            body_lines.append(line + "\n")
            continue
        # A heading closes the section collected so far.
        body = "".join(body_lines).strip()
        if body:
            sections.append({"sectionTitle": title, "text": body})
        title = line.strip()
        body_lines = []

    body = "".join(body_lines).strip()
    if body:
        sections.append({"sectionTitle": title, "text": body})

    return sections or [{"sectionTitle": "", "text": md_text}]
|
||||
|
||||
|
||||
def _estimate_tokens(text: str) -> int:
    """Roughly estimate the token count of ``text``.

    Characters in the range matched by ``[一-鿿]`` count as one token per 1.5
    characters (truncated); each run of ASCII letters counts as one token.
    Everything else (digits, punctuation, whitespace) contributes nothing.
    """
    han_count = sum(1 for _ in re.finditer(r"[一-鿿]", text))
    word_count = sum(1 for _ in re.finditer(r"[a-zA-Z]+", text))
    return word_count + int(han_count / 1.5)
|
||||
|
||||
|
||||
def _chunk_text(text: str, section_title: str = "", page_number: int | None = None) -> list[dict]:
    """Pack sentences into chunks of about ``CHUNK_SIZE`` estimated tokens.

    Sentences are appended to the current chunk until adding the next one
    would exceed ``CHUNK_SIZE``. The chunk is then emitted and the next chunk
    starts with the last ``CHUNK_OVERLAP * 2`` characters of the previous one
    (when ``CHUNK_OVERLAP`` is positive). A single sentence is never split, so
    a chunk may exceed ``CHUNK_SIZE``.

    Returns:
        Dicts with the keys ``content``, ``sectionTitle`` and ``pageNumber``.
    """

    def _record(body):
        return {"content": body.strip(), "sectionTitle": section_title, "pageNumber": page_number}

    pieces = []
    window = ""
    window_tokens = 0

    for sentence in _split_sentences(text):
        cost = _estimate_tokens(sentence)

        # Keep filling while the budget allows it, or while nothing counted yet.
        if window_tokens <= 0 or window_tokens + cost <= CHUNK_SIZE:
            window += sentence
            window_tokens += cost
            continue

        pieces.append(_record(window))
        if CHUNK_OVERLAP > 0:
            # Rough estimate: carry over about two characters per overlap token.
            carried = window[-int(CHUNK_OVERLAP * 2):]
            window = carried + sentence
            window_tokens = _estimate_tokens(carried) + cost
        else:
            window, window_tokens = sentence, cost

    if window.strip():
        pieces.append(_record(window))

    return pieces
|
||||
|
||||
|
||||
def chunk_document(text: str, source_type: str = "text") -> list[dict]:
    """Split a document into chunks.

    Markdown sources (``"md"`` / ``"markdown"``) are first split by heading;
    every other source type is treated as one untitled section.

    Returns:
        Chunk dicts with the keys ``content``, ``sectionTitle``, ``pageNumber``,
        ``chunkIndex`` (position in the returned list) and ``chunkType``
        (``"table"``, ``"code"`` or ``"text"``).
    """
    sections = (
        _split_by_heading(text)
        if source_type in ("md", "markdown")
        else [{"sectionTitle": "", "text": text}]
    )

    pieces = [
        piece
        for section in sections
        for piece in _chunk_text(section["text"], section_title=section.get("sectionTitle", ""))
    ]

    for index, piece in enumerate(pieces):
        piece["chunkIndex"] = index
        body = piece["content"]
        # Heuristic detection of tables and code blocks.
        if body.count("|") > 5 and "---" in body:
            kind = "table"
        elif "```" in body:
            kind = "code"
        else:
            kind = "text"
        piece["chunkType"] = kind

    return pieces
|
||||
29
rag-worker/config.py
Normal file
29
rag-worker/config.py
Normal file
@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
# All settings are read from environment variables once, at import time.

# NestJS internal API
API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:3000")
RAG_WORKER_SECRET = os.getenv("RAG_WORKER_SECRET", "")

# SiliconFlow
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_BASE_URL = os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3")
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024"))

# DeepSeek
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")

# Qdrant
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "zhixi_chunks")

# Chunking
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "512"))
CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "64"))

# Worker
# The default worker id embeds the process id of the importing process.
WORKER_ID = os.getenv("WORKER_ID", f"worker-{os.getpid()}")
POLL_INTERVAL = int(os.getenv("POLL_INTERVAL", "5"))
HEARTBEAT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "30"))
|
||||
63
rag-worker/embedder.py
Normal file
63
rag-worker/embedder.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Embedding 服务:调用硅基流动 bge-m3"""
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
from config import (
|
||||
SILICONFLOW_API_KEY,
|
||||
SILICONFLOW_BASE_URL,
|
||||
EMBEDDING_MODEL,
|
||||
EMBEDDING_DIM,
|
||||
)
|
||||
|
||||
# Number of texts sent per embedding request.
BATCH_SIZE = 50
# Number of additional attempts made after a failed batch request.
MAX_RETRIES = 2
|
||||
|
||||
|
||||
async def embed_single(text: str) -> list[float]:
    """Embed a single text with the configured embedding model.

    Raises:
        RuntimeError: If the embeddings endpoint answers with a non-200 status.
    """
    auth = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}
    request_body = {
        "model": EMBEDDING_MODEL,
        "input": [text],
    }
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/embeddings",
            headers=auth,
            json=request_body,
        )
        if resp.status_code == 200:
            return resp.json()["data"][0]["embedding"]
        raise RuntimeError(f"Embedding API error: {resp.status_code} {resp.text}")
|
||||
|
||||
|
||||
async def _embed_once(batch: list[str]) -> list[list[float]]:
    """Send one embeddings request and return one vector per input, in input order."""
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/embeddings",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json={
                "model": EMBEDDING_MODEL,
                "input": batch,
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"Status {resp.status_code}")

    items = resp.json()["data"]
    if len(items) != len(batch):
        # A short reply would silently misalign vectors and chunks downstream.
        raise RuntimeError(f"expected {len(batch)} embeddings, got {len(items)}")
    # Order by the reported index when present; sorted() is stable otherwise.
    ordered = sorted(items, key=lambda item: item.get("index", 0))
    return [item["embedding"] for item in ordered]


async def embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed ``texts`` in batches of ``BATCH_SIZE`` with retries.

    Each batch is attempted up to ``MAX_RETRIES + 1`` times, sleeping
    ``2 ** attempt`` seconds after each failed attempt except the last.

    Returns:
        One embedding per input text, in input order.

    Raises:
        RuntimeError: If a batch still fails on its final attempt.
    """
    all_embeddings = []

    for start in range(0, len(texts), BATCH_SIZE):
        batch = texts[start:start + BATCH_SIZE]
        for attempt in range(MAX_RETRIES + 1):
            try:
                vectors = await _embed_once(batch)
            except Exception as exc:
                if attempt == MAX_RETRIES:
                    raise RuntimeError(
                        f"Embedding batch failed after {MAX_RETRIES} retries: {exc}"
                    ) from exc
                await asyncio.sleep(2 ** attempt)
            else:
                all_embeddings.extend(vectors)
                break

    return all_embeddings
|
||||
60
rag-worker/indexer.py
Normal file
60
rag-worker/indexer.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""Qdrant 索引服务"""
|
||||
|
||||
import httpx
|
||||
from config import QDRANT_URL, QDRANT_COLLECTION
|
||||
|
||||
|
||||
async def upsert_points(points: list[dict]):
    """Write a batch of points to the Qdrant collection (with ``wait=true``).

    Raises:
        RuntimeError: If Qdrant answers with a non-200 status.
    """
    endpoint = f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points"
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.put(endpoint, params={"wait": "true"}, json={"points": points})
        if resp.status_code == 200:
            return
        raise RuntimeError(f"Qdrant upsert failed: {resp.text}")
|
||||
|
||||
|
||||
async def search(
    vector: list[float],
    user_id: str,
    knowledge_base_id: str,
    top_k: int = 5,
) -> list[dict]:
    """Run a vector search in the Qdrant collection.

    Hits are restricted to points whose payload matches ``user_id`` and
    ``knowledge_base_id`` and whose ``deleted`` flag is ``False``.

    Returns:
        The ``result`` list of the Qdrant response (payloads included).

    Raises:
        RuntimeError: If Qdrant answers with a non-200 status.
    """
    conditions = [
        {"key": "userId", "match": {"value": user_id}},
        {"key": "knowledgeBaseId", "match": {"value": knowledge_base_id}},
        {"key": "deleted", "match": {"value": False}},
    ]
    query = {
        "vector": vector,
        "filter": {"must": conditions},
        "limit": top_k,
        "with_payload": True,
    }
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
            json=query,
        )
        if resp.status_code == 200:
            return resp.json()["result"]
        raise RuntimeError(f"Qdrant search failed: {resp.text}")
|
||||
|
||||
|
||||
async def mark_deleted(source_id: str):
    """Set ``deleted=true`` on the payload of every point of a source.

    Uses Qdrant's set-payload endpoint (``POST .../points/payload``) with a
    filter on ``sourceId``.

    Raises:
        RuntimeError: If Qdrant answers with a non-200 status.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/payload",
            params={"wait": "true"},
            json={
                "payload": {"deleted": True},
                "filter": {
                    "must": [
                        {"key": "sourceId", "match": {"value": source_id}},
                    ]
                },
            },
        )
    if resp.status_code != 200:
        raise RuntimeError(f"Qdrant mark_deleted failed: {resp.text}")
|
||||
84
rag-worker/main.py
Normal file
84
rag-worker/main.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""知习 RAG Worker — 文档导入主进程"""
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
from config import WORKER_ID, POLL_INTERVAL, HEARTBEAT_INTERVAL
|
||||
from api_client import get_next_job, claim_job, heartbeat, update_job_status
|
||||
from pipelines.import_pipeline import run_import
|
||||
|
||||
# Module-level run flag; cleared by the signal handler so the loops in this
# module stop at their next check.
running = True


def shutdown(sig, frame):
    """Signal handler: log the received signal and request a graceful stop."""
    global running
    print(f"[{WORKER_ID}] 收到信号 {sig},正在退出...")
    running = False


# Install the handler for SIGINT and SIGTERM at import time.
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
|
||||
|
||||
|
||||
async def heartbeat_loop():
    """Idle placeholder for a worker-level heartbeat.

    Sleeps ``HEARTBEAT_INTERVAL`` seconds per iteration for as long as the
    module-level ``running`` flag is set; it reports nothing itself.
    """
    # Simplified implementation: worker-level heartbeat, which can later be
    # extended to a per-job heartbeat.
    while True:
        if not running:
            break
        await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
|
||||
async def work_loop():
    """Main worker loop: poll for a job, claim it, run the import.

    Runs until the module-level ``running`` flag is cleared. Whenever there is
    nothing to process (no job, a job without an id, or a claim that was not
    granted) the loop sleeps ``POLL_INTERVAL`` seconds before polling again.
    """
    print(f"[{WORKER_ID}] RAG Worker 已启动")

    while running:
        try:
            job = await get_next_job()
            if not job:
                await asyncio.sleep(POLL_INTERVAL)
                continue

            job_id = job.get("id") or job.get("jobId")
            if not job_id:
                # Back off instead of re-polling immediately for the same
                # malformed job.
                await asyncio.sleep(POLL_INTERVAL)
                continue

            # Claim the job; back off when the claim is not granted so a
            # contested or unclaimable job cannot cause a tight request loop.
            claimed = await claim_job(job_id)
            if not claimed:
                await asyncio.sleep(POLL_INTERVAL)
                continue

            print(f"[{WORKER_ID}] 开始处理任务 {job_id}")

            # Report heartbeats in the background while the import runs.
            hb_task = asyncio.create_task(_per_job_heartbeat(job_id))

            try:
                await run_import(job)
                print(f"[{WORKER_ID}] 任务 {job_id} 完成")
            except Exception as e:
                print(f"[{WORKER_ID}] 任务 {job_id} 失败: {e}")
                await update_job_status(job_id, "FAILED_RETRYABLE", {
                    "errorMessage": str(e)[:500],
                })
            finally:
                hb_task.cancel()
                # Wait for the cancellation to finish so the task does not
                # outlive its job; gather() collects the CancelledError.
                await asyncio.gather(hb_task, return_exceptions=True)

        except Exception as e:
            print(f"[{WORKER_ID}] 轮询异常: {e}")
            await asyncio.sleep(POLL_INTERVAL)

    print(f"[{WORKER_ID}] Worker 已停止")
|
||||
|
||||
|
||||
async def _per_job_heartbeat(job_id: str):
    """Send a heartbeat for ``job_id`` every ``HEARTBEAT_INTERVAL`` seconds.

    Runs while the module-level ``running`` flag is set. Heartbeats are best
    effort: a failed request is logged and the loop keeps going.
    """
    while running:
        try:
            await heartbeat(job_id)
        except Exception as e:
            # Best effort: keep the job running, but leave a trace of the failure.
            print(f"[{WORKER_ID}] 任务 {job_id} 心跳失败: {e}")
        await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the polling loop until it returns.
    asyncio.run(work_loop())
|
||||
137
rag-worker/parser.py
Normal file
137
rag-worker/parser.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""文档解析:PDF / DOCX / TXT / MD / CSV / XLSX"""
|
||||
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import httpx
|
||||
from config import SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL
|
||||
|
||||
|
||||
async def download_file(url: str, local_path: str) -> str:
    """Download ``url`` (following redirects) to ``local_path``.

    The response body is streamed to disk chunk by chunk instead of being
    buffered in memory first.

    Args:
        url: Download URL (a pre-signed COS URL according to the original note).
        local_path: Destination path; its parent directory is created if needed.

    Returns:
        ``local_path``.

    Raises:
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
    """
    parent = os.path.dirname(local_path)
    if parent:
        # os.makedirs("") raises, so skip it for a bare filename.
        os.makedirs(parent, exist_ok=True)

    async with httpx.AsyncClient(timeout=120, follow_redirects=True) as client:
        async with client.stream("GET", url) as resp:
            resp.raise_for_status()
            with open(local_path, "wb") as f:
                async for chunk in resp.aiter_bytes():
                    f.write(chunk)
    return local_path
|
||||
|
||||
|
||||
def parse_txt(file_path: str) -> str:
    """Read a text file as UTF-8, replacing undecodable bytes with U+FFFD."""
    handle = open(file_path, "r", encoding="utf-8", errors="replace")
    with handle:
        contents = handle.read()
    return contents
|
||||
|
||||
|
||||
def parse_markdown(file_path: str) -> str:
    """Return the raw Markdown source; the file is read exactly like plain text."""
    markdown_source = parse_txt(file_path)
    return markdown_source
|
||||
|
||||
|
||||
def parse_docx(file_path: str) -> str:
    """Return the non-blank paragraphs of a DOCX file, separated by blank lines."""
    from docx import Document

    kept = []
    for paragraph in Document(file_path).paragraphs:
        if paragraph.text.strip():
            kept.append(paragraph.text)
    return "\n\n".join(kept)
|
||||
|
||||
|
||||
def parse_pdf_text(file_path: str) -> str:
    """Extract the text layer of a PDF with PyMuPDF.

    Returns:
        The text of every page that has non-blank text, joined by blank lines.
    """
    import fitz

    doc = fitz.open(file_path)
    try:
        pages = []
        for page in doc:
            text = page.get_text()
            if text.strip():
                pages.append(text)
    finally:
        # Close the document even when text extraction raises.
        doc.close()
    return "\n\n".join(pages)
|
||||
|
||||
|
||||
def pdf_needs_ocr(file_path: str) -> bool:
    """Tell whether a PDF looks like a scan (empty or nearly empty text layer).

    Returns:
        ``True`` when the text layer averages fewer than 50 characters per page.
    """
    import fitz

    doc = fitz.open(file_path)
    try:
        # Read the page count while the document is still open; querying it
        # after close() is not valid.
        page_count = max(len(doc), 1)
        total_len = sum(len(page.get_text().strip()) for page in doc)
    finally:
        doc.close()
    # Fewer than 50 characters per page on average -> treat as a scanned document.
    return (total_len / page_count) < 50
|
||||
|
||||
|
||||
def _sniff_image_mime(image_bytes: bytes) -> str:
    """Guess the image MIME type from its magic bytes (falls back to PNG)."""
    if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if image_bytes.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        return "image/webp"
    if image_bytes.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if image_bytes.startswith(b"BM"):
        return "image/bmp"
    return "image/png"


async def ocr_with_siliconflow(image_bytes: bytes) -> str:
    """Run OCR on an image with the SiliconFlow multimodal chat endpoint.

    Args:
        image_bytes: Encoded image data; sent inline as a base64 data URL whose
            MIME type is detected from the leading bytes.

    Returns:
        The text content of the first choice in the model reply.

    Raises:
        RuntimeError: If the API answers with a non-200 status.
    """
    b64 = base64.b64encode(image_bytes).decode()
    mime = _sniff_image_mime(image_bytes)
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json={
                "model": "Qwen/Qwen3-VL-32B-Instruct",
                "messages": [{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "请识别并提取这张图片中的所有文字内容。如果有表格,请用 Markdown 表格格式输出。不要添加任何解释。"},
                        {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}},
                    ],
                }],
                "max_tokens": 4096,
            },
        )
    if resp.status_code != 200:
        # Surface the API error instead of a KeyError on the missing "choices".
        raise RuntimeError(f"OCR API error: {resp.status_code} {resp.text}")
    data = resp.json()
    return data["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
async def parse_image_with_ocr(file_path: str) -> str:
    """Read an image file from disk and return the text recognised by OCR."""
    handle = open(file_path, "rb")
    with handle:
        payload = handle.read()
    recognised = await ocr_with_siliconflow(payload)
    return recognised
|
||||
|
||||
|
||||
def parse_csv(file_path: str) -> str:
    """Load a CSV file with pandas and render it as a Markdown table (no index)."""
    import pandas as pd

    frame = pd.read_csv(file_path)
    markdown_table = frame.to_markdown(index=False)
    return markdown_table
|
||||
|
||||
|
||||
def parse_xlsx(file_path: str) -> str:
    """Load an Excel workbook with ``pandas.read_excel`` (default arguments)
    and render the resulting frame as a Markdown table (no index)."""
    import pandas as pd

    frame = pd.read_excel(file_path)
    markdown_table = frame.to_markdown(index=False)
    return markdown_table
|
||||
|
||||
|
||||
async def parse_document(file_path: str, mime_type: str) -> str:
    """Route a file to the matching parser, based on its extension.

    Args:
        file_path: Path of the local file.
        mime_type: Accepted for interface compatibility; routing currently
            relies on the file extension only.

    Returns:
        The extracted text (Markdown tables for CSV/XLSX, OCR output for
        images and scanned PDFs).

    Raises:
        ValueError: If the extension is not supported.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext in (".txt",):
        return parse_txt(file_path)
    elif ext in (".md", ".markdown"):
        return parse_markdown(file_path)
    elif ext in (".docx",):
        return parse_docx(file_path)
    elif ext in (".csv",):
        return parse_csv(file_path)
    elif ext in (".xlsx",):
        return parse_xlsx(file_path)
    elif ext in (".pdf",):
        if not pdf_needs_ocr(file_path):
            return parse_pdf_text(file_path)

        # Probably a scan: try the text layer first and only fall back to the
        # multimodal model when it is (nearly) empty.
        text = parse_pdf_text(file_path)
        if len(text.strip()) >= 100:
            return text

        # Fully scanned document: OCR it page by page.
        import fitz

        doc = fitz.open(file_path)
        try:
            results = []
            for page in doc:
                pix = page.get_pixmap(dpi=150)
                img_bytes = pix.tobytes("png")
                results.append(await ocr_with_siliconflow(img_bytes))
        finally:
            # Close the document even when an OCR request fails mid-way.
            doc.close()
        return "\n\n".join(results)
    elif ext in (".png", ".jpg", ".jpeg", ".webp", ".heic", ".bmp"):
        return await parse_image_with_ocr(file_path)
    else:
        raise ValueError(f"不支持的文件类型: {ext}")
|
||||
127
rag-worker/pipelines/import_pipeline.py
Normal file
127
rag-worker/pipelines/import_pipeline.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""导入主流程:下载 → 解析 → 清洗 → 切片 → embedding → Qdrant → AI 候选"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from parser import download_file, parse_document
|
||||
from chunker import chunk_document
|
||||
from embedder import embed_batch
|
||||
from indexer import upsert_points
|
||||
from candidate_generator import generate_candidates
|
||||
from api_client import (
|
||||
heartbeat as send_heartbeat,
|
||||
update_job_status,
|
||||
save_chunks,
|
||||
save_candidates,
|
||||
get_job_detail,
|
||||
)
|
||||
|
||||
|
||||
async def run_import(job: dict):
    """Run the full import pipeline for one job.

    Steps: download -> parse -> clean -> chunk -> embed -> index in Qdrant ->
    generate candidate knowledge points. Progress is reported through
    ``update_job_status`` before each step.

    Args:
        job: Job dict as returned by the internal API (camelCase keys, with
            snake_case keys accepted as a fallback).

    Raises:
        ValueError: If the job has no id or source id, or the parsed text is
            (nearly) empty.
        RuntimeError: If the number of embeddings does not match the chunks.
        Exception: Any error from a pipeline step is re-raised after the job
            has been reported as ``FAILED_RETRYABLE``.
    """
    job_id = job.get("id") or job.get("jobId")
    if not job_id:
        raise ValueError("导入任务缺少 id")
    source_id = job.get("sourceId") or job.get("source_id")
    # Use .get() for every key so that the snake_case fallback is reachable.
    user_id = job.get("userId") or job.get("user_id")
    kb_id = job.get("knowledgeBaseId") or job.get("knowledge_base_id")

    if not source_id:
        raise ValueError(f"任务 {job_id} 缺少 sourceId")

    # Fetch the source details from the NestJS API. The API may report the
    # source as null, so normalise it to an empty dict.
    detail = await get_job_detail(job_id) or {}
    source = detail.get("source") or {}
    mime_type = source.get("mimeType") or source.get("mime_type") or "text/plain"
    original_filename = source.get("originalFilename") or source.get("original_filename") or "unknown"
    # Keep only the final path component so a crafted filename cannot escape
    # the job's temporary directory.
    safe_filename = os.path.basename(original_filename.replace("\\", "/")) or "unknown"

    tmp_dir = f"/data/tmp/imports/{job_id}"
    file_path = os.path.join(tmp_dir, safe_filename)

    try:
        # 1. Download the file.
        await update_job_status(job_id, "DOWNLOADING", {"progress": 5})
        file_url = source.get("downloadUrl") or detail.get("downloadUrl", "")
        if file_url:
            await download_file(file_url, file_path)

        # 2. Parse. Only parse when a file was actually downloaded; plain-text
        # imports have no local file and use the raw text below instead.
        await update_job_status(job_id, "PARSING", {"progress": 20})
        text = ""
        if os.path.exists(file_path):
            text = await parse_document(file_path, mime_type)
        if not text:
            text = job.get("rawText") or source.get("rawText") or ""

        if not text or len(text.strip()) < 10:
            raise ValueError("文档解析后内容过少,可能为空白或损坏文件")

        # 3. Clean (status report only; no transformation is applied here).
        await update_job_status(job_id, "CLEANING", {"progress": 40, "textLength": len(text)})

        # 4. Chunk.
        await update_job_status(job_id, "CHUNKING", {"progress": 50})
        source_type = source.get("type") or "text"
        chunks = chunk_document(text, source_type)

        # 5. Embed.
        await update_job_status(job_id, "EMBEDDING", {"progress": 60})
        texts = [c["content"] for c in chunks]
        vectors = await embed_batch(texts)
        if len(vectors) != len(chunks):
            # zip() below would silently drop the unmatched chunks.
            raise RuntimeError(
                f"embedding count mismatch: {len(vectors)} vectors for {len(chunks)} chunks"
            )

        # 6. Index in Qdrant.
        await update_job_status(job_id, "INDEXING", {"progress": 80})
        points = []
        chunk_records = []
        for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
            # Qdrant point ids must be unsigned integers or UUIDs. Derive a
            # deterministic UUID so that re-importing a source overwrites its
            # previous points instead of duplicating them.
            chunk_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"chunk_{source_id}_{i}"))
            points.append({
                "id": chunk_id,
                "vector": vec,
                "payload": {
                    "userId": user_id,
                    "knowledgeBaseId": kb_id,
                    "sourceId": source_id,
                    "chunkId": chunk_id,
                    "pageNumber": chunk.get("pageNumber"),
                    "sectionTitle": chunk.get("sectionTitle", ""),
                    "deleted": False,
                },
            })
            chunk_records.append({
                "userId": user_id,
                "knowledgeBaseId": kb_id,
                "sourceId": source_id,
                "content": chunk["content"],
                "chunkIndex": chunk["chunkIndex"],
                "pageNumber": chunk.get("pageNumber"),
                "sectionTitle": chunk.get("sectionTitle", ""),
                # NOTE: this is the character count of the chunk, not a token count.
                "tokenCount": len(chunk["content"]),
                "externalVectorId": chunk_id,
                "embeddingModel": "bge-m3",
                "embeddingStatus": "COMPLETED",
                "metadataJson": {"chunkType": chunk.get("chunkType", "text")},
            })

        await upsert_points(points)
        await save_chunks(chunk_records)

        # 7. Generate candidate knowledge points.
        await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
        candidates = await generate_candidates(text)
        if candidates:
            await save_candidates(user_id, kb_id, source_id, job_id, candidates)

        # 8. Done.
        await update_job_status(job_id, "COMPLETED", {"progress": 100})

    except Exception as e:
        await update_job_status(job_id, "FAILED_RETRYABLE", {
            "errorCode": "WORKER_ERROR",
            "errorMessage": str(e)[:500],
        })
        raise

    finally:
        # Remove the job's temporary files.
        if os.path.exists(tmp_dir):
            import shutil
            shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
8
rag-worker/requirements.txt
Normal file
8
rag-worker/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
||||
httpx>=0.27
|
||||
pydantic>=2.0
|
||||
pymupdf>=1.24
|
||||
python-docx>=1.1
|
||||
markdown>=3.5
|
||||
pandas>=2.0
|
||||
openpyxl>=3.1
|
||||
Pillow>=10.0
# Required by pandas.DataFrame.to_markdown (used by the CSV/XLSX parsers)
tabulate>=0.9
|
||||
@ -28,6 +28,7 @@ import { FilesModule } from './modules/files/files.module';
|
||||
import { WaitlistModule } from './modules/waitlist/waitlist.module';
|
||||
import { KnowledgeSourceModule } from './modules/knowledge-source/knowledge-source.module';
|
||||
import { ImportCandidateModule } from './modules/import-candidate/import-candidate.module';
|
||||
import { RagModule } from './modules/rag/rag.module';
|
||||
|
||||
import { JwtAuthGuard } from './common/guards/jwt-auth.guard';
|
||||
import { RolesGuard } from './common/guards/roles.guard';
|
||||
@ -85,6 +86,7 @@ import appleConfig from './config/apple.config';
|
||||
KnowledgeSourceModule,
|
||||
ImportCandidateModule,
|
||||
DocumentImportModule,
|
||||
RagModule,
|
||||
LearningSessionModule,
|
||||
ActiveRecallModule,
|
||||
AiAnalysisModule,
|
||||
|
||||
124
src/modules/rag/internal-rag.controller.ts
Normal file
124
src/modules/rag/internal-rag.controller.ts
Normal file
@ -0,0 +1,124 @@
|
||||
import { Controller, Get, Post, Body, Param } from '@nestjs/common';
|
||||
import { ApiTags } from '@nestjs/swagger';
|
||||
import { DocumentImportRepository } from '../document-import/document-import.repository';
|
||||
import { KnowledgeSourceRepository } from '../knowledge-source/knowledge-source.repository';
|
||||
import { ImportCandidateRepository } from '../import-candidate/import-candidate.repository';
|
||||
import { PrismaService } from '../../infrastructure/database/prisma.service';
|
||||
|
||||
/**
 * Internal HTTP endpoints under `internal/rag`: job polling, claiming,
 * heartbeats, status updates, and persistence of chunks and candidates.
 */
@ApiTags('internal-rag')
@Controller('internal/rag')
export class InternalRagController {
  constructor(
    private readonly importRepo: DocumentImportRepository,
    private readonly sourceRepo: KnowledgeSourceRepository,
    private readonly candidateRepo: ImportCandidateRepository,
    private readonly prisma: PrismaService,
  ) {}

  /** Returns the job produced by `importRepo.claimNext('')`, or `{ job: null }`. */
  @Get('jobs/next')
  async getNextJob() {
    // NOTE(review): the original inline comment said "query first, do not
    // claim", yet the repository method invoked is `claimNext` — confirm that
    // it does not already claim the job.
    const job = await this.importRepo.claimNext('');
    if (!job) return { job: null };
    return {
      job: {
        id: job.id,
        userId: job.userId,
        knowledgeBaseId: job.knowledgeBaseId,
        sourceId: job.sourceId,
        fileId: job.fileId,
        sourceType: job.sourceType,
        sourceName: job.sourceName,
        rawText: job.rawText,
        status: job.status,
      },
    };
  }

  /** Returns a job together with its knowledge source (or `null` when it has none). */
  @Get('jobs/:id')
  async getJobDetail(@Param('id') id: string) {
    const job = await this.importRepo.findById(id);
    if (!job) return { job: null };

    const source = job.sourceId
      ? await this.sourceRepo.findById(job.sourceId)
      : null;

    return {
      job: {
        id: job.id,
        userId: job.userId,
        knowledgeBaseId: job.knowledgeBaseId,
        sourceId: job.sourceId,
        fileId: job.fileId,
        rawText: job.rawText,
        status: job.status,
      },
      source: source ? {
        id: source.id,
        type: source.type,
        originalFilename: source.originalFilename,
        mimeType: source.mimeType,
        sizeBytes: Number(source.sizeBytes),
        originalObjectKey: source.originalObjectKey,
      } : null,
    };
  }

  /** Claims a job for the worker named in the body; `success` is false when nothing was updated. */
  @Post('jobs/:id/claim')
  async claimJob(@Param('id') id: string, @Body() body: { workerId?: string }) {
    // The request may arrive without a body, so access it null-safely.
    const workerId = body?.workerId || 'unknown';
    const result = await this.importRepo.claim(id, workerId);
    return { success: result.count > 0 };
  }

  /** Records a heartbeat for a job. */
  @Post('jobs/:id/heartbeat')
  async heartbeat(@Param('id') id: string) {
    await this.importRepo.heartbeat(id);
    return { success: true };
  }

  /** Updates a job's status; the status value is also stored as the current step. */
  @Post('jobs/:id/status')
  async updateStatus(
    @Param('id') id: string,
    @Body() body: { status: string; progress?: number; errorCode?: string; errorMessage?: string },
  ) {
    await this.importRepo.updateStatus(id, body.status, {
      step: body.status,
      progress: body.progress,
      errorCode: body.errorCode,
      errorMessage: body.errorMessage,
    });
    return { success: true };
  }

  /** Bulk-inserts knowledge chunks; an empty or missing list is a no-op. */
  @Post('chunks')
  async saveChunks(@Body() body: { chunks: any[] }) {
    const chunks = body?.chunks || [];
    if (chunks.length > 0) {
      await this.prisma.knowledgeChunk.createMany({ data: chunks });
    }
    return { success: true, count: chunks.length };
  }

  /** Stores the candidate knowledge points generated for an import. */
  @Post('candidates')
  async saveCandidates(
    @Body() body: {
      userId: string;
      knowledgeBaseId: string;
      sourceId: string;
      importId: string;
      candidates: any[];
    },
  ) {
    await this.candidateRepo.createMany(
      body.userId,
      body.knowledgeBaseId,
      body.sourceId,
      body.importId,
      body.candidates || [],
    );
    return { success: true, count: body.candidates?.length || 0 };
  }
}
|
||||
11
src/modules/rag/rag.module.ts
Normal file
11
src/modules/rag/rag.module.ts
Normal file
@ -0,0 +1,11 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { InternalRagController } from './internal-rag.controller';
|
||||
import { DocumentImportModule } from '../document-import/document-import.module';
|
||||
import { KnowledgeSourceModule } from '../knowledge-source/knowledge-source.module';
|
||||
import { ImportCandidateModule } from '../import-candidate/import-candidate.module';
|
||||
|
||||
/**
 * Registers the internal RAG controller and imports the document-import,
 * knowledge-source and import-candidate modules.
 */
@Module({
  imports: [DocumentImportModule, KnowledgeSourceModule, ImportCandidateModule],
  controllers: [InternalRagController],
})
export class RagModule {}
|
||||
Loading…
x
Reference in New Issue
Block a user