104 lines
3.2 KiB
Python
104 lines
3.2 KiB
Python
"""候选知识点生成:调用 DeepSeek 分析文本,生成 ImportCandidate"""
|
||
|
||
import json
|
||
import httpx
|
||
from config import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL
|
||
|
||
MAX_CANDIDATES = 30
|
||
MIN_CANDIDATES = 3
|
||
CHARS_PER_CANDIDATE = 2000
|
||
|
||
_PROMPT = """你是一个学习助手。请分析以下文档内容,提取关键知识点。
|
||
|
||
对于每个知识点,请提供:
|
||
- title: 知识点标题(简洁,不超过 30 字)
|
||
- summary: 一句话概述(不超过 80 字)
|
||
- content: 详细解释(基于原文,保持准确)
|
||
- tags: 2-4 个标签
|
||
- recallQuestions: 1-2 个主动回忆问题
|
||
- difficulty: 难度评估(easy/medium/hard)
|
||
- confidence: 你对这个知识点重要性的置信度(0.0-1.0)
|
||
|
||
请以 JSON 数组格式返回,每个元素是一个知识点:
|
||
```json
|
||
[{
|
||
"title": "知识点标题",
|
||
"summary": "一句话概述",
|
||
"content": "详细解释...",
|
||
"tags": ["标签1", "标签2"],
|
||
"recallQuestions": ["问题1?", "问题2?"],
|
||
"difficulty": "medium",
|
||
"confidence": 0.85
|
||
}]
|
||
```
|
||
|
||
文档内容:
|
||
{text}
|
||
"""
|
||
|
||
|
||
async def generate_candidates(text: str) -> list[dict]:
|
||
"""用 DeepSeek 生成候选知识点"""
|
||
# 估算生成数量
|
||
text_len = len(text)
|
||
expected_count = max(MIN_CANDIDATES, min(MAX_CANDIDATES, text_len // CHARS_PER_CANDIDATE))
|
||
|
||
prompt = _PROMPT.format(text=text[:16000]) # 限制上下文长度
|
||
|
||
async with httpx.AsyncClient(timeout=120) as client:
|
||
resp = await client.post(
|
||
f"{DEEPSEEK_BASE_URL}/chat/completions",
|
||
headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"},
|
||
json={
|
||
"model": DEEPSEEK_MODEL,
|
||
"messages": [
|
||
{"role": "system", "content": "你是一个专业的学习内容分析师。请始终返回有效的 JSON 数组。"},
|
||
{"role": "user", "content": prompt},
|
||
],
|
||
"temperature": 0.3,
|
||
"max_tokens": 4096,
|
||
},
|
||
)
|
||
if resp.status_code != 200:
|
||
raise RuntimeError(f"DeepSeek API error: {resp.status_code} {resp.text}")
|
||
|
||
data = resp.json()
|
||
raw = data["choices"][0]["message"]["content"]
|
||
|
||
# 提取 JSON
|
||
return _parse_json_response(raw, expected_count)
|
||
|
||
|
||
def _parse_json_response(raw: str, expected_count: int) -> list[dict]:
|
||
"""从 AI 回复中提取 JSON 数组"""
|
||
# 尝试直接解析
|
||
try:
|
||
candidates = json.loads(raw)
|
||
if isinstance(candidates, list):
|
||
return candidates[:MAX_CANDIDATES]
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
# 提取 ```json ... ``` 块
|
||
import re
|
||
m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL)
|
||
if m:
|
||
try:
|
||
candidates = json.loads(m.group(1))
|
||
if isinstance(candidates, list):
|
||
return candidates[:MAX_CANDIDATES]
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
# 提取 [ ... ] 块
|
||
m = re.search(r"\[.*\]", raw, re.DOTALL)
|
||
if m:
|
||
try:
|
||
candidates = json.loads(m.group(0))
|
||
if isinstance(candidates, list):
|
||
return candidates[:MAX_CANDIDATES]
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
raise ValueError(f"无法解析 AI 候选知识点回复: {raw[:500]}")
|