"""业务 checkpoint 的 JSON 存储与恢复工具。""" from __future__ import annotations import json from dataclasses import asdict, fields, is_dataclass from pathlib import Path from typing import Any from .constants import SENSITIVE_KEYS from .models import AgentState def redact_mapping(value: Any) -> Any: """递归脱敏 dict/list 中的敏感字段。""" if isinstance(value, dict): result = {} for key, item in value.items(): if str(key) in SENSITIVE_KEYS: result[key] = "***" else: result[key] = redact_mapping(item) return result if isinstance(value, list): return [redact_mapping(item) for item in value] return value def save_checkpoint(state: Any, path: str | Path, *, redact: bool = True) -> Path: """保存 checkpoint;真实续跑场景可关闭脱敏以保留必要参数。""" checkpoint_path = Path(path) checkpoint_path.parent.mkdir(parents=True, exist_ok=True) payload = asdict(state) if is_dataclass(state) else state if redact: payload = redact_mapping(payload) checkpoint_path.write_text( json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8", ) return checkpoint_path def load_checkpoint(path: str | Path) -> dict[str, Any]: """从 JSON 文件读取原始 checkpoint 字典。""" return json.loads(Path(path).read_text(encoding="utf-8")) def agent_state_from_mapping(payload: dict[str, Any]) -> AgentState: """把 checkpoint 字典转换回 AgentState,忽略未知字段。""" allowed_fields = {item.name for item in fields(AgentState)} state_payload = {key: value for key, value in payload.items() if key in allowed_fields} return AgentState(**state_payload) def load_agent_state(path: str | Path) -> AgentState: """读取 checkpoint 文件并恢复 AgentState。""" return agent_state_from_mapping(load_checkpoint(path))