api-server/rag-worker/candidate_generator.py
WangDL fbdae9078f
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 22s
feat: Python RAG Worker + NestJS 内部 API(文档解析/切片/embedding/Qdrant/候选生成)
2026-05-19 22:35:12 +08:00

104 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""候选知识点生成:调用 DeepSeek 分析文本,生成 ImportCandidate"""
import json
import httpx
from config import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL
# Upper bound on the number of candidates: parsed lists are truncated to this
# length, and the expected-count estimate is clamped to it.
MAX_CANDIDATES = 30
# Lower bound for the expected-count estimate computed in generate_candidates.
MIN_CANDIDATES = 3
# Estimate granularity: one expected candidate per this many input characters
# (len(text) // CHARS_PER_CANDIDATE).
CHARS_PER_CANDIDATE = 2000
# User-prompt template sent to the model. `{text}` marks where the document
# text is inserted. The prompt asks for a JSON array whose elements carry
# title / summary / content / tags / recallQuestions / difficulty / confidence.
# NOTE(review): the embedded JSON example contains literal `{` / `}`. str.format()
# treats those as replacement fields, so the placeholder has to be filled
# without str.format() (or the example braces have to be escaped).
_PROMPT = """你是一个学习助手。请分析以下文档内容,提取关键知识点。
对于每个知识点,请提供:
- title: 知识点标题(简洁,不超过 30 字)
- summary: 一句话概述(不超过 80 字)
- content: 详细解释(基于原文,保持准确)
- tags: 2-4 个标签
- recallQuestions: 1-2 个主动回忆问题
- difficulty: 难度评估easy/medium/hard
- confidence: 你对这个知识点重要性的置信度0.0-1.0
请以 JSON 数组格式返回,每个元素是一个知识点:
```json
[{
"title": "知识点标题",
"summary": "一句话概述",
"content": "详细解释...",
"tags": ["标签1", "标签2"],
"recallQuestions": ["问题1", "问题2"],
"difficulty": "medium",
"confidence": 0.85
}]
```
文档内容:
{text}
"""
async def generate_candidates(text: str) -> list[dict]:
    """Generate candidate knowledge points from *text* via the DeepSeek chat API.

    Args:
        text: Document text to analyse. Only the first 16000 characters are
            placed in the prompt.

    Returns:
        The list extracted from the model's reply by ``_parse_json_response``
        (truncated there to ``MAX_CANDIDATES`` items).

    Raises:
        RuntimeError: If the API answers with a non-200 status code.
        ValueError: If no JSON array can be extracted from the reply.
    """
    # Rough estimate of how many candidates the text should yield, clamped to
    # [MIN_CANDIDATES, MAX_CANDIDATES]; it is handed to the parser.
    expected_count = max(
        MIN_CANDIDATES,
        min(MAX_CANDIDATES, len(text) // CHARS_PER_CANDIDATE),
    )

    # The template embeds a literal JSON example, so str.format() would parse
    # its braces as replacement fields and raise KeyError. Substitute the
    # placeholder directly instead. The slice caps the context length.
    prompt = _PROMPT.replace("{text}", text[:16000])

    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.post(
            f"{DEEPSEEK_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
            json={
                "model": DEEPSEEK_MODEL,
                "messages": [
                    {"role": "system", "content": "你是一个专业的学习内容分析师。请始终返回有效的 JSON 数组。"},
                    {"role": "user", "content": prompt},
                ],
                "temperature": 0.3,
                "max_tokens": 4096,
            },
        )
        if resp.status_code != 200:
            raise RuntimeError(f"DeepSeek API error: {resp.status_code} {resp.text}")
        data = resp.json()
        raw = data["choices"][0]["message"]["content"]

    # Pull the JSON array out of the free-form reply.
    return _parse_json_response(raw, expected_count)
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
    """Extract the JSON array of candidates from the model's reply.

    Strategies are tried in order; the first one that yields a list wins:

    1. Parse the whole reply as JSON.
    2. Parse the contents of the first fenced code block.
    3. Scan the reply for JSON arrays embedded in surrounding prose.

    Args:
        raw: Text content of the model's reply.
        expected_count: Estimated number of candidates. Currently unused;
            kept so the signature stays unchanged.

    Returns:
        The parsed list, truncated to ``MAX_CANDIDATES`` items.

    Raises:
        ValueError: If ``raw`` is not a string or no JSON array can be
            extracted from it.
    """
    import re  # local import: the module's top-level imports do not include it

    if not isinstance(raw, str):
        # e.g. a null message content; report it like any other unparsable reply.
        raise ValueError(f"无法解析 AI 候选知识点回复: {raw!r}")

    # Strategies 1 and 2: the whole reply, then the first fenced block.
    snippets = [raw]
    fenced = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL)
    if fenced:
        snippets.append(fenced.group(1))
    for snippet in snippets:
        try:
            parsed = json.loads(snippet)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list):
            return parsed[:MAX_CANDIDATES]

    # Strategy 3: arrays embedded in prose.
    embedded = _first_embedded_list(raw)
    if embedded is not None:
        return embedded[:MAX_CANDIDATES]

    raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")


def _first_embedded_list(raw: str):
    """Return the best JSON array embedded in *raw*, or ``None`` if there is none.

    Every ``[`` is tried as the start of a JSON value, so brackets in the
    surrounding prose do not break extraction (a greedy first-``[``-to-last-``]``
    match would). The first non-empty array made up only of objects is
    preferred; otherwise the first array that decodes at all is returned.
    """
    decoder = json.JSONDecoder()
    fallback = None
    pos = raw.find("[")
    while pos != -1:
        try:
            value, end = decoder.raw_decode(raw, pos)
        except json.JSONDecodeError:
            # Not a JSON array here (e.g. a footnote marker); try the next "[".
            pos = raw.find("[", pos + 1)
            continue
        # raw[pos] is "[", so a successful decode always produced a list.
        if value and all(isinstance(item, dict) for item in value):
            return value
        if fallback is None:
            fallback = value
        # Resume after the decoded array so its nested arrays are not revisited.
        pos = raw.find("[", end)
    return fallback