add rerank module + bug fixes from e2e test
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 15s

- New reranker.py: SiliconFlow bge-reranker-v2-m3 integration
- config.py: add RERANK_MODEL
- api_client.py: fix get_next_job/claim_job/get_job_detail unwrapping
- candidate_generator.py: fix .format() conflict with JSON braces
- import_pipeline.py: fix file existence check + UUID point IDs
- Add .gitignore for __pycache__

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
WangDL 2026-05-20 16:05:09 +08:00
parent 1947a0c0d5
commit c9882c8d04
7 changed files with 107 additions and 34 deletions

1
rag-worker/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
__pycache__/

View File

@ -16,7 +16,9 @@ async def get_next_job() -> dict | None:
) )
if resp.status_code == 200: if resp.status_code == 200:
data = resp.json() data = resp.json()
return data.get("data") or data.get("job") result = data.get("data") or data
if isinstance(result, dict):
return result.get("job")
return None return None
@ -26,8 +28,9 @@ async def claim_job(job_id: str) -> bool:
resp = await client.post( resp = await client.post(
f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim", f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim",
headers=_auth_headers, headers=_auth_headers,
json={"workerId": WORKER_ID},
) )
return resp.status_code == 200 return resp.status_code in (200, 201)
async def heartbeat(job_id: str) -> bool: async def heartbeat(job_id: str) -> bool:
@ -90,5 +93,6 @@ async def get_job_detail(job_id: str) -> dict | None:
headers=_auth_headers, headers=_auth_headers,
) )
if resp.status_code == 200: if resp.status_code == 200:
return resp.json() data = resp.json()
return data.get("data") or data
return None return None

View File

@ -43,7 +43,7 @@ async def generate_candidates(text: str) -> list[dict]:
text_len = len(text) text_len = len(text)
expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE)) expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))
prompt = _PROMPT.format(text=text[:16000]) # 限制上下文长度 prompt = _PROMPT.replace("{text}", text[:16000]) # 限制上下文长度
async with httpx.AsyncClient(timeout=120) as client: async with httpx.AsyncClient(timeout=120) as client:
resp = await client.post( resp = await client.post(
@ -71,7 +71,31 @@ async def generate_candidates(text: str) -> list[dict]:
def _parse_json_response(raw: str, expected_count: int) -> list[dict]: def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
"""从 AI 回复中提取 JSON 数组""" """从 AI 回复中提取 JSON 数组"""
# 尝试直接解析 import re
# 1. 提取 ```json ... ``` 块
m = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
if m:
inner = m.group(1).strip()
try:
candidates = json.loads(inner)
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
# 2. 提取 [ ... ] 块(从第一个 [ 到最后一个 ])
start = raw.find("[")
end = raw.rfind("]")
if start != -1 and end != -1 and end > start:
try:
candidates = json.loads(raw[start:end + 1])
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
# 3. 直接解析整个回复
try: try:
candidates = json.loads(raw) candidates = json.loads(raw)
if isinstance(candidates, list): if isinstance(candidates, list):
@ -79,25 +103,4 @@ def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
# 提取 ```json ... ``` 块
import re
m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL)
if m:
try:
candidates = json.loads(m.group(1))
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
# 提取 [ ... ] 块
m = re.search(r"\[.*\]", raw, re.DOTALL)
if m:
try:
candidates = json.loads(m.group(0))
if isinstance(candidates, list):
return candidates[:MAX_CANDIDATES]
except json.JSONDecodeError:
pass
raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}") raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")

View File

@ -1,4 +1,14 @@
import os import os
from pathlib import Path
# 加载 .env 文件(systemd 可通过 EnvironmentFile 设置,此处作为手动运行的兜底)
_dotenv_path = Path(__file__).resolve().parent / ".env"
if _dotenv_path.exists():
try:
from dotenv import load_dotenv
load_dotenv(_dotenv_path)
except ImportError:
pass
# NestJS 内部 API # NestJS 内部 API
API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:3000") API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:3000")
@ -9,6 +19,7 @@ SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_BASE_URL = os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1") SILICONFLOW_BASE_URL = os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3") EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3")
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024")) EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024"))
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
# DeepSeek # DeepSeek
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "") DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")

View File

@ -45,11 +45,14 @@ async def run_import(job: dict):
# 2. 解析 # 2. 解析
await update_job_status(job_id, "PARSING", {"progress": 20}) await update_job_status(job_id, "PARSING", {"progress": 20})
text = await parse_document(file_path, mime_type) text = ""
if os.path.exists(file_path):
text = await parse_document(file_path, mime_type)
# 如果文件不在本地(纯文本导入),直接从 source/import 中取文本 # 如果文件不在本地(纯文本导入),直接从 source/import 中取文本
if not text and (job.get("rawText") or source.get("rawText")): detail_job = (detail or {}).get("job", {})
text = job.get("rawText", "") or source.get("rawText", "") if not text:
text = job.get("rawText") or source.get("rawText") or detail_job.get("rawText") or ""
if not text or len(text.strip()) < 10: if not text or len(text.strip()) < 10:
raise ValueError("文档解析后内容过少,可能为空白或损坏文件") raise ValueError("文档解析后内容过少,可能为空白或损坏文件")
@ -72,7 +75,7 @@ async def run_import(job: dict):
points = [] points = []
chunk_records = [] chunk_records = []
for i, (chunk, vec) in enumerate(zip(chunks, vectors)): for i, (chunk, vec) in enumerate(zip(chunks, vectors)):
chunk_id = f"chunk_{source_id}_{i}" chunk_id = str(uuid.uuid4())
points.append({ points.append({
"id": chunk_id, "id": chunk_id,
"vector": vec, "vector": vec,
@ -104,11 +107,14 @@ async def run_import(job: dict):
await upsert_points(points) await upsert_points(points)
await save_chunks(chunk_records) await save_chunks(chunk_records)
# 7. 生成候选知识点 # 7. 生成候选知识点(非致命:失败不影响导入完成)
await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90}) await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90})
candidates = await generate_candidates(text) try:
if candidates: candidates = await generate_candidates(text)
await save_candidates(user_id, kb_id, source_id, job_id, candidates) if candidates:
await save_candidates(user_id, kb_id, source_id, job_id, candidates)
except Exception as e:
print(f"[worker] 候选知识点生成失败(非致命): {e}")
# 8. 完成 # 8. 完成
await update_job_status(job_id, "COMPLETED", {"progress": 100}) await update_job_status(job_id, "COMPLETED", {"progress": 100})

View File

@ -6,3 +6,5 @@ markdown>=3.5
pandas>=2.0 pandas>=2.0
openpyxl>=3.1 openpyxl>=3.1
Pillow>=10.0 Pillow>=10.0
qdrant-client>=1.9
python-dotenv>=1.0

46
rag-worker/reranker.py Normal file
View File

@ -0,0 +1,46 @@
"""Rerank 服务:调用硅基流动 bge-reranker-v2-m3 对检索结果精排"""
import httpx
from config import SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL, RERANK_MODEL
async def rerank(
    query: str,
    documents: list[str],
    top_n: int = 5,
) -> list[dict]:
    """Re-score candidate documents against a query and return the top matches.

    Posts the query and documents to the ``/rerank`` endpoint under
    ``SILICONFLOW_BASE_URL`` using the model named by ``RERANK_MODEL``.

    Args:
        query: The search query the documents are scored against.
        documents: Candidate document texts to rank.
        top_n: Maximum number of results to return. It is capped at
            ``len(documents)``; a value of zero or less yields an empty list
            without calling the API.

    Returns:
        One dict per result, in the order the API returned them, each with:
          - ``index``: position of the document in the input ``documents`` list
          - ``score``: the ``relevance_score`` reported by the API
          - ``text``: the document text echoed by the API, or the original
            input text when the API did not echo one

    Raises:
        RuntimeError: If the API responds with a non-200 status code.
    """
    # Nothing to rank, or nothing requested: skip the network round trip.
    if not documents or top_n <= 0:
        return []

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{SILICONFLOW_BASE_URL}/rerank",
            headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"},
            json={
                "model": RERANK_MODEL,
                "query": query,
                "documents": documents,
                # Never ask for more results than there are documents.
                "top_n": min(top_n, len(documents)),
                "return_documents": True,
                "max_chunks_per_doc": 1024,
            },
        )

    if resp.status_code != 200:
        raise RuntimeError(f"Rerank API error: {resp.status_code} {resp.text}")

    ranked = []
    for item in resp.json()["results"]:
        position = item["index"]
        # The echoed document may be an object with a "text" field, a bare
        # string, null, or absent; only the first two carry usable text.
        echoed = item.get("document")
        if isinstance(echoed, dict):
            text = echoed.get("text")
        elif isinstance(echoed, str):
            text = echoed
        else:
            text = None
        if text is None:
            # Fall back to the caller's original text for this position.
            text = documents[position]
        ranked.append(
            {
                "index": position,
                "score": item["relevance_score"],
                "text": text,
            }
        )
    return ranked