api-server/rag-worker/candidate_generator.py
WangDL c9882c8d04
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 15s
add rerank module + bug fixes from e2e test
- New reranker.py: SiliconFlow bge-reranker-v2-m3 integration
- config.py: add RERANK_MODEL
- api_client.py: fix get_next_job/claim_job/get_job_detail unwrapping
- candidate_generator.py: fix .format() conflict with JSON braces
- import_pipeline.py: fix file existence check + UUID point IDs
- Add .gitignore for __pycache__

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 16:05:09 +08:00

107 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""候选知识点生成:调用 DeepSeek 分析文本,生成 ImportCandidate"""
import json
import httpx
from config import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL
# Upper bound for the estimated candidate count and for the number of
# candidates returned from a single model reply.
MAX_CANDIDATES = 30
# Lower bound for the estimated candidate count.
MIN_CANDIDATES = 3
# Estimation heuristic: roughly one knowledge point per this many characters.
CHARS_PER_CANDIDATE = 2000
# User-message template (Chinese). It asks the model for a JSON array of
# objects with the keys title, summary, content, tags, recallQuestions,
# difficulty and confidence, followed by the document itself.
# "{text}" is filled in with str.replace rather than str.format, because the
# embedded JSON example contains literal braces that format() would parse.
# NOTE(review): the difficulty/confidence lines lack the parentheses the
# title/summary/content lines have; the viewer flagged ambiguous Unicode, so
# full-width brackets may have been lost -- confirm against the repository copy.
_PROMPT = """你是一个学习助手。请分析以下文档内容,提取关键知识点。
对于每个知识点,请提供:
- title: 知识点标题(简洁,不超过 30 字)
- summary: 一句话概述(不超过 80 字)
- content: 详细解释(基于原文,保持准确)
- tags: 2-4 个标签
- recallQuestions: 1-2 个主动回忆问题
- difficulty: 难度评估easy/medium/hard
- confidence: 你对这个知识点重要性的置信度0.0-1.0
请以 JSON 数组格式返回,每个元素是一个知识点:
```json
[{
"title": "知识点标题",
"summary": "一句话概述",
"content": "详细解释...",
"tags": ["标签1", "标签2"],
"recallQuestions": ["问题1", "问题2"],
"difficulty": "medium",
"confidence": 0.85
}]
```
文档内容:
{text}
"""
async def generate_candidates(text: str, *, max_chars: int = 16000) -> list[dict]:
    """Generate candidate knowledge points for ``text`` with DeepSeek.

    Only the first ``max_chars`` characters of ``text`` are sent. The number of
    knowledge points requested from the model is estimated from that excerpt:
    one per ``CHARS_PER_CANDIDATE`` characters, clamped to
    ``[MIN_CANDIDATES, MAX_CANDIDATES]``.

    Args:
        text: Full document text.
        max_chars: Maximum number of characters of ``text`` placed in the
            prompt (keeps the request within the context budget). Defaults
            to 16000.

    Returns:
        Candidate dicts parsed from the model reply by ``_parse_json_response``.

    Raises:
        RuntimeError: The API answered with a non-200 status, or the reply has
            no ``choices[0].message.content`` string.
        ValueError: Propagated from ``_parse_json_response`` when the reply
            cannot be parsed as a JSON array.
    """
    # The model only ever sees this excerpt, so base the count estimate on it
    # rather than on the full document length.
    excerpt = text[:max_chars]
    expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, len(excerpt) // CHARS_PER_CANDIDATE))

    # str.replace instead of str.format: the template's JSON example contains
    # literal braces. replace() does not rescan the inserted text, so braces in
    # the document itself are harmless as well.
    prompt = _PROMPT.replace("{text}", excerpt)
    # Pass the estimate to the model so it actually steers how many
    # knowledge points are extracted.
    prompt += f"\n请从以上文档中提取约 {expected_count} 个知识点。\n"

    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.post(
            f"{DEEPSEEK_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
            json={
                "model": DEEPSEEK_MODEL,
                "messages": [
                    {"role": "system", "content": "你是一个专业的学习内容分析师。请始终返回有效的 JSON 数组。"},
                    {"role": "user", "content": prompt},
                ],
                "temperature": 0.3,
                "max_tokens": 4096,
            },
        )
        if resp.status_code != 200:
            raise RuntimeError(f"DeepSeek API error: {resp.status_code} {resp.text}")
        data = resp.json()

    try:
        raw = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as err:
        raise RuntimeError(f"DeepSeek API returned an unexpected payload: {str(data)[:500]}") from err
    # OpenAI-compatible schemas allow a null content; report that clearly
    # instead of letting the parser fail with a TypeError.
    if not isinstance(raw, str):
        raise RuntimeError(f"DeepSeek API returned no text content: {str(data)[:500]}")

    return _parse_json_response(raw, expected_count)
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
    """Extract the candidate knowledge points from a model reply.

    Recovery strategies are tried in order, and the first one that yields a
    JSON array wins:

    1. the body of a fenced ```json ... ``` (or bare ```) block;
    2. the span from the first ``[`` to the last ``]`` -- this also covers a
       reply that consists of nothing but a JSON array;
    3. salvage: the complete objects at the start of an array whose tail is
       missing or malformed, e.g. a reply cut off by ``max_tokens``.

    Args:
        raw: Text content of the model reply.
        expected_count: Number of candidates that was requested. Accepted for
            interface compatibility; it does not influence parsing.

    Returns:
        The object (dict) elements of the recovered array, in order, capped at
        ``MAX_CANDIDATES``. Non-object elements are dropped.

    Raises:
        ValueError: If no JSON array can be recovered from ``raw``.
    """
    import re

    # 1. Fenced code block.
    fence = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
    if fence:
        parsed = _load_json_list(fence.group(1).strip())
        if parsed is not None:
            return _keep_candidates(parsed)

    # 2. Outermost bracket span. A whole-reply json.loads could only succeed
    #    when this span already parses, so no separate attempt is needed.
    start = raw.find("[")
    end = raw.rfind("]")
    if start != -1 and end > start:
        parsed = _load_json_list(raw[start:end + 1])
        if parsed is not None:
            return _keep_candidates(parsed)

    # 3. Truncated or partly malformed array (or stray brackets in surrounding
    #    prose): keep whatever leading objects decode cleanly.
    salvaged = _salvage_leading_objects(raw)
    if salvaged:
        return _keep_candidates(salvaged)

    raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")


def _load_json_list(fragment: str) -> "list | None":
    """Return ``fragment`` decoded as JSON if it is an array, otherwise None."""
    try:
        value = json.loads(fragment)
    except json.JSONDecodeError:
        return None
    return value if isinstance(value, list) else None


def _keep_candidates(items: list) -> list[dict]:
    """Keep only the object (dict) elements of ``items``, capped at MAX_CANDIDATES."""
    return [item for item in items if isinstance(item, dict)][:MAX_CANDIDATES]


def _salvage_leading_objects(raw: str) -> list:
    """Decode the leading complete elements of a possibly truncated JSON array.

    Decoding starts at the first ``[`` that is followed (ignoring whitespace)
    by ``{``. It stops at the closing ``]``, at the end of ``raw``, or at the
    first element that fails to decode. Returns ``[]`` when no such array
    start exists.
    """
    import re

    opening = re.search(r"\[\s*\{", raw)
    if opening is None:
        return []
    decoder = json.JSONDecoder()
    items: list = []
    pos = opening.start() + 1
    size = len(raw)
    while True:
        # Skip the whitespace and commas separating array elements;
        # raw_decode does not skip leading whitespace on its own.
        while pos < size and raw[pos] in " \t\r\n,":
            pos += 1
        if pos >= size or raw[pos] == "]":
            return items
        try:
            item, pos = decoder.raw_decode(raw, pos)
        except json.JSONDecodeError:
            return items
        items.append(item)