From c9882c8d04108dcff805d8e61421850bcf328a71 Mon Sep 17 00:00:00 2001 From: WangDL Date: Wed, 20 May 2026 16:05:09 +0800 Subject: [PATCH] add rerank module + bug fixes from e2e test - New reranker.py: SiliconFlow bge-reranker-v2-m3 integration - config.py: add RERANK_MODEL - api_client.py: fix get_next_job/claim_job/get_job_detail unwrapping - candidate_generator.py: fix .format() conflict with JSON braces - import_pipeline.py: fix file existence check + UUID point IDs - Add .gitignore for __pycache__ Co-Authored-By: Claude Opus 4.7 --- rag-worker/.gitignore | 1 + rag-worker/api_client.py | 10 +++-- rag-worker/candidate_generator.py | 49 +++++++++++++------------ rag-worker/config.py | 11 ++++++ rag-worker/pipelines/import_pipeline.py | 22 +++++++---- rag-worker/requirements.txt | 2 + rag-worker/reranker.py | 46 +++++++++++++++++++++++ 7 files changed, 107 insertions(+), 34 deletions(-) create mode 100644 rag-worker/.gitignore create mode 100644 rag-worker/reranker.py diff --git a/rag-worker/.gitignore b/rag-worker/.gitignore new file mode 100644 index 0000000..c18dd8d --- /dev/null +++ b/rag-worker/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/rag-worker/api_client.py b/rag-worker/api_client.py index b078e68..23ac97f 100644 --- a/rag-worker/api_client.py +++ b/rag-worker/api_client.py @@ -16,7 +16,9 @@ async def get_next_job() -> dict | None: ) if resp.status_code == 200: data = resp.json() - return data.get("data") or data.get("job") + result = data.get("data") or data + if isinstance(result, dict): + return result.get("job") return None @@ -26,8 +28,9 @@ async def claim_job(job_id: str) -> bool: resp = await client.post( f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim", headers=_auth_headers, + json={"workerId": WORKER_ID}, ) - return resp.status_code == 200 + return resp.status_code in (200, 201) async def heartbeat(job_id: str) -> bool: @@ -90,5 +93,6 @@ async def get_job_detail(job_id: str) -> dict | None: headers=_auth_headers, ) if resp.status_code == 200: - return resp.json() + data = resp.json() + return data.get("data") or data return None diff --git a/rag-worker/candidate_generator.py b/rag-worker/candidate_generator.py index ed4d218..f5e27c2 100644 --- a/rag-worker/candidate_generator.py +++ b/rag-worker/candidate_generator.py @@ -43,7 +43,7 @@ async def generate_candidates(text: str) -> list[dict]: text_len = len(text) expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE)) - prompt = _PROMPT.format(text=text[:16000]) # 限制上下文长度 + prompt = _PROMPT.replace("{text}", text[:16000]) # 限制上下文长度 async with httpx.AsyncClient(timeout=120) as client: resp = await client.post( @@ -71,7 +71,31 @@ async def generate_candidates(text: str) -> list[dict]: def _parse_json_response(raw: str, expected_count: int) -> list[dict]: """从 AI 回复中提取 JSON 数组""" - # 尝试直接解析 + import re + + # 1. 提取 ```json ... ``` 块 + m = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL) + if m: + inner = m.group(1).strip() + try: + candidates = json.loads(inner) + if isinstance(candidates, list): + return candidates[:MAX_CANDIDATES] + except json.JSONDecodeError: + pass + + # 2. 提取 [ ... ] 块(从第一个 [ 到最后一个 ]) + start = raw.find("[") + end = raw.rfind("]") + if start != -1 and end != -1 and end > start: + try: + candidates = json.loads(raw[start:end + 1]) + if isinstance(candidates, list): + return candidates[:MAX_CANDIDATES] + except json.JSONDecodeError: + pass + + # 3. 直接解析整个回复 try: candidates = json.loads(raw) if isinstance(candidates, list): @@ -79,25 +103,4 @@ def _parse_json_response(raw: str, expected_count: int) -> list[dict]: except json.JSONDecodeError: pass - # 提取 ```json ... ``` 块 - import re - m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL) - if m: - try: - candidates = json.loads(m.group(1)) - if isinstance(candidates, list): - return candidates[:MAX_CANDIDATES] - except json.JSONDecodeError: - pass - - # 提取 [ ... ] 块 - m = re.search(r"\[.*\]", raw, re.DOTALL) - if m: - try: - candidates = json.loads(m.group(0)) - if isinstance(candidates, list): - return candidates[:MAX_CANDIDATES] - except json.JSONDecodeError: - pass - raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}") diff --git a/rag-worker/config.py b/rag-worker/config.py index 2dd54bc..9c1c632 100644 --- a/rag-worker/config.py +++ b/rag-worker/config.py @@ -1,4 +1,14 @@ import os +from pathlib import Path + +# 加载 .env 文件(systemd 可通过 EnvironmentFile 设置,此处作为手动运行的兜底) +_dotenv_path = Path(__file__).resolve().parent / ".env" +if _dotenv_path.exists(): + try: + from dotenv import load_dotenv + load_dotenv(_dotenv_path) + except ImportError: + pass # NestJS 内部 API API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:3000") @@ -9,6 +19,7 @@ SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "") SILICONFLOW_BASE_URL = os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1") EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3") EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024")) +RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") # DeepSeek DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "") diff --git a/rag-worker/pipelines/import_pipeline.py b/rag-worker/pipelines/import_pipeline.py index d23058e..3c59ec2 100644 --- a/rag-worker/pipelines/import_pipeline.py +++ b/rag-worker/pipelines/import_pipeline.py @@ -45,11 +45,14 @@ async def run_import(job: dict): # 2. 解析 await update_job_status(job_id, "PARSING", {"progress": 20}) - text = await parse_document(file_path, mime_type) + text = "" + if os.path.exists(file_path): + text = await parse_document(file_path, mime_type) # 如果文件不在本地(纯文本导入),直接从 source/import 中取文本 - if not text and (job.get("rawText") or source.get("rawText")): - text = job.get("rawText", "") or source.get("rawText", "") + detail_job = (detail or {}).get("job", {}) + if not text: + text = job.get("rawText") or source.get("rawText") or detail_job.get("rawText") or "" if not text or len(text.strip()) < 10: raise ValueError("文档解析后内容过少,可能为空白或损坏文件") @@ -72,7 +75,7 @@ async def run_import(job: dict): points = [] chunk_records = [] for i, (chunk, vec) in enumerate(zip(chunks, vectors)): - chunk_id = f"chunk_{source_id}_{i}" + chunk_id = str(uuid.uuid4()) points.append({ "id": chunk_id, "vector": vec, @@ -104,11 +107,14 @@ async def run_import(job: dict): await upsert_points(points) await save_chunks(chunk_records) - # 7. 生成候选知识点 + # 7. 生成候选知识点(非致命:失败不影响导入完成) await update_job_status(job_id, "GENERATING_CANDIDATES", {"progress": 90}) - candidates = await generate_candidates(text) - if candidates: - await save_candidates(user_id, kb_id, source_id, job_id, candidates) + try: + candidates = await generate_candidates(text) + if candidates: + await save_candidates(user_id, kb_id, source_id, job_id, candidates) + except Exception as e: + print(f"[worker] 候选知识点生成失败(非致命): {e}") # 8. 完成 await update_job_status(job_id, "COMPLETED", {"progress": 100}) diff --git a/rag-worker/requirements.txt b/rag-worker/requirements.txt index 2e426d4..8d6402c 100644 --- a/rag-worker/requirements.txt +++ b/rag-worker/requirements.txt @@ -6,3 +6,5 @@ markdown>=3.5 pandas>=2.0 openpyxl>=3.1 Pillow>=10.0 +qdrant-client>=1.9 +python-dotenv>=1.0 diff --git a/rag-worker/reranker.py b/rag-worker/reranker.py new file mode 100644 index 0000000..dd2edf4 --- /dev/null +++ b/rag-worker/reranker.py @@ -0,0 +1,46 @@ +"""Rerank 服务:调用硅基流动 bge-reranker-v2-m3 对检索结果精排""" + +import httpx +from config import SILICONFLOW_API_KEY, SILICONFLOW_BASE_URL, RERANK_MODEL + + +async def rerank( + query: str, + documents: list[str], + top_n: int = 5, +) -> list[dict]: + """对候选文档重新打分排序,返回 top_n 结果。 + + 每个结果包含: + - index: 原始 documents 数组中的位置 + - score: 相关性分数 (0.0-1.0) + - text: 文档原文 + """ + if not documents: + return [] + + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post( + f"{SILICONFLOW_BASE_URL}/rerank", + headers={"Authorization": f"Bearer {SILICONFLOW_API_KEY}"}, + json={ + "model": RERANK_MODEL, + "query": query, + "documents": documents, + "top_n": min(top_n, len(documents)), + "return_documents": True, + "max_chunks_per_doc": 1024, + }, + ) + if resp.status_code != 200: + raise RuntimeError(f"Rerank API error: {resp.status_code} {resp.text}") + + data = resp.json() + return [ + { + "index": r["index"], + "score": r["relevance_score"], + "text": r.get("document", {}).get("text", documents[r["index"]]), + } + for r in data["results"] + ]