api-server/rag-worker/api_client.py
WangDL fbdae9078f
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 22s
feat: Python RAG Worker + NestJS 内部 API(文档解析/切片/embedding/Qdrant/候选生成)
2026-05-19 22:35:12 +08:00

95 lines
2.8 KiB
Python

import logging

import httpx

from config import API_BASE_URL, RAG_WORKER_SECRET, WORKER_ID
# Headers passed with every request in this module: RAG_WORKER_SECRET as a
# bearer token and WORKER_ID in the X-Worker-Id header.
_auth_headers = {
"Authorization": f"Bearer {RAG_WORKER_SECRET}",
"X-Worker-Id": WORKER_ID,
}
async def get_next_job() -> dict | None:
    """Fetch the next QUEUED import job from the internal API.

    Returns:
        The value under the ``data`` key of the JSON response (falling back
        to ``job`` when ``data`` is missing or falsy), or ``None`` when the
        status is not 200, the body is empty, or the body is not a JSON
        object.

    Raises:
        httpx.HTTPError: On transport-level failures (connect, timeout, ...).
        ValueError: If a non-empty 200 body is not valid JSON.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(
            f"{API_BASE_URL}/internal/rag/jobs/next",
            headers=_auth_headers,
        )
        # A 200 with no body would make resp.json() raise; treat it as
        # "no job available" instead of crashing the polling loop.
        # NOTE(review): presumably the server sends this when its handler
        # returns null -- confirm against the controller.
        if resp.status_code != 200 or not resp.content:
            return None
        payload = resp.json()
        # Guard against a JSON body that is not an object (e.g. null, a list),
        # which has no .get().
        if not isinstance(payload, dict):
            return None
        return payload.get("data") or payload.get("job")
async def claim_job(job_id: str) -> bool:
    """Claim an import job for this worker.

    Args:
        job_id: Identifier of the job to claim; interpolated into the URL path.

    Returns:
        ``True`` when the server answers with any 2xx status, ``False``
        otherwise.

    Raises:
        httpx.HTTPError: On transport-level failures (connect, timeout, ...).
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}/claim",
            headers=_auth_headers,
        )
        # Accept the whole 2xx range rather than exactly 200: POST endpoints
        # frequently answer 201 Created, which the strict comparison would
        # misreport as a failed claim.
        # NOTE(review): confirm the status code the claim endpoint returns.
        return resp.is_success
async def heartbeat(job_id: str) -> bool:
    """Send a heartbeat for a job.

    Args:
        job_id: Identifier of the job; interpolated into the URL path.

    Returns:
        ``True`` when the server answers with any 2xx status, ``False``
        otherwise.

    Raises:
        httpx.HTTPError: On transport-level failures (connect, timeout, ...).
    """
    # Shorter timeout (10s) than the other calls in this module.
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}/heartbeat",
            headers=_auth_headers,
        )
        # Any 2xx counts as success; a POST endpoint may answer 201 rather
        # than 200, and that must not be reported as a missed heartbeat.
        # NOTE(review): confirm the status code the heartbeat endpoint returns.
        return resp.is_success
async def update_job_status(job_id: str, status: str, data: dict | None = None) -> bool:
    """Update the status of an import job.

    Args:
        job_id: Identifier of the job; interpolated into the URL path.
        status: New status value, sent under the ``status`` key.
        data: Optional extra fields merged into the JSON body. A ``status``
            key in ``data`` is ignored in favour of the ``status`` argument.

    Returns:
        ``True`` when the server answers with any 2xx status, ``False``
        otherwise (a warning is logged in that case).

    Raises:
        httpx.HTTPError: On transport-level failures (connect, timeout, ...).
    """
    # Put "status" last so the explicit argument cannot be overwritten by a
    # same-named key inside ``data``.
    body = {**(data or {}), "status": status}
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}/status",
            headers=_auth_headers,
            json=body,
        )
    if not resp.is_success:
        # Surface rejected updates instead of silently dropping the response.
        logging.getLogger(__name__).warning(
            "update_job_status failed: job=%s status=%s http=%s",
            job_id,
            status,
            resp.status_code,
        )
    return resp.is_success
async def save_chunks(chunks: list[dict]) -> bool:
    """Save a batch of KnowledgeChunk records through the internal API.

    Args:
        chunks: Chunk payloads, sent as-is under the ``chunks`` key.

    Returns:
        ``True`` when the server answers with any 2xx status, ``False``
        otherwise (a warning is logged in that case).

    Raises:
        httpx.HTTPError: On transport-level failures (connect, timeout, ...).
    """
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/chunks",
            headers=_auth_headers,
            json={"chunks": chunks},
        )
    if not resp.is_success:
        # Without this check a rejected batch is indistinguishable from a
        # saved one, so the caller could carry on as if the chunks existed.
        logging.getLogger(__name__).warning(
            "save_chunks failed: count=%d http=%s",
            len(chunks),
            resp.status_code,
        )
    return resp.is_success
async def save_candidates(
    user_id: str,
    kb_id: str,
    source_id: str,
    import_id: str,
    candidates: list[dict],
) -> bool:
    """Save candidate knowledge points through the internal API.

    Args:
        user_id: Sent as ``userId``.
        kb_id: Sent as ``knowledgeBaseId``.
        source_id: Sent as ``sourceId``.
        import_id: Sent as ``importId``.
        candidates: Candidate payloads, sent as-is under ``candidates``.

    Returns:
        ``True`` when the server answers with any 2xx status, ``False``
        otherwise (a warning is logged in that case).

    Raises:
        httpx.HTTPError: On transport-level failures (connect, timeout, ...).
    """
    body = {
        "userId": user_id,
        "knowledgeBaseId": kb_id,
        "sourceId": source_id,
        "importId": import_id,
        "candidates": candidates,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{API_BASE_URL}/internal/rag/candidates",
            headers=_auth_headers,
            json=body,
        )
    if not resp.is_success:
        # Report rejected saves instead of discarding the response unseen.
        logging.getLogger(__name__).warning(
            "save_candidates failed: import=%s count=%d http=%s",
            import_id,
            len(candidates),
            resp.status_code,
        )
    return resp.is_success
async def get_job_detail(job_id: str) -> dict | None:
    """Fetch the details of a job (including its source information).

    Args:
        job_id: Identifier of the job; interpolated into the URL path.

    Returns:
        The decoded JSON body of a 200 response, returned whole (no
        ``data``/``job`` unwrapping), or ``None`` when the status is not 200
        or the body is empty.

    Raises:
        httpx.HTTPError: On transport-level failures (connect, timeout, ...).
        ValueError: If a non-empty 200 body is not valid JSON.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(
            f"{API_BASE_URL}/internal/rag/jobs/{job_id}",
            headers=_auth_headers,
        )
        # An empty 200 body cannot be decoded; report it as "no detail"
        # rather than letting resp.json() raise.
        if resp.status_code != 200 or not resp.content:
            return None
        return resp.json()