"""LangGraph interrupt runner for the human-confirmation points of the chat flow."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal
from uuid import uuid4

from .agent import PamDeployAgent
from .models import AgentState

# "global" ends the run after the global actions; "deploy" continues into the
# per-IP actions (see the routing helpers inside build_deployment_graph).
GraphFlow = Literal["global", "deploy"]

# Normalized strings that count as an approval. Anything else is treated as a
# rejection, so an unrecognized answer can never approve a deployment step.
_APPROVE_WORDS = frozenset({"approve", "approved", "yes", "y", "true"})


@dataclass(slots=True)
class LangGraphRunResult:
    """Summary of one LangGraph run (``start``) or resume (``resume``)."""

    # Last AgentState seen in the streamed node outputs, or None if none was seen.
    state: AgentState | None = None
    # Last rendered report; in this graph only the "report" node produces one.
    report: str = ""
    # Payload given to ``interrupt`` when the run stopped for confirmation.
    confirmation: dict[str, Any] = field(default_factory=dict)
    # True when the run stopped at a confirmation interrupt.
    interrupted: bool = False
    # Every raw chunk yielded by the stream, in order.
    chunks: list[dict[str, Any]] = field(default_factory=list)


class LangGraphDeploymentRuntime:
    """Schedule deployment actions through LangGraph nodes and host the
    human-confirmation interrupts.

    One runtime owns one compiled graph (with an in-memory checkpointer) and
    one ``thread_id``: ``start`` begins a run, ``resume`` answers a pending
    confirmation on the same thread.
    """

    def __init__(
        self,
        *,
        agent: PamDeployAgent,
        thread_id: str | None = None,
        flow: GraphFlow = "deploy",
    ) -> None:
        """Build the graph instance and choose the conversation thread ID.

        Args:
            agent: Deployment agent whose actions the graph nodes invoke.
            thread_id: LangGraph thread ID; a random UUID4 string is used when
                it is omitted or empty.
            flow: ``"global"`` or ``"deploy"``; see ``GraphFlow``.

        Raises:
            RuntimeError: If ``langgraph`` is not installed.
        """
        self.agent = agent
        self.thread_id = thread_id or str(uuid4())
        self.flow = flow
        self._waiting_confirmation = False
        self._graph = build_deployment_graph(agent=self.agent, flow=self.flow)

    @property
    def waiting_confirmation(self) -> bool:
        """Whether the last start/resume stopped at a confirmation interrupt."""
        return self._waiting_confirmation

    def start(self, state: AgentState) -> LangGraphRunResult:
        """Run the graph from ``state`` until it finishes or needs confirmation.

        Args:
            state: Agent state to start from.

        Returns:
            Summary of the run; ``interrupted`` is True if it stopped for
            human confirmation.
        """
        self._waiting_confirmation = False
        return self._consume(self._graph.stream({"agent_state": state}, self._config()))

    def resume(self, *, approved: bool, note: str = "") -> LangGraphRunResult:
        """Hand the operator's decision back to LangGraph and continue the run.

        Args:
            approved: Whether the operator approved the pending step.
            note: Free-text operator note; the confirm node passes it to
                ``agent.confirm_pending`` as ``operator_note``.

        Returns:
            Summary of the resumed run; it may stop at another confirmation.

        Raises:
            RuntimeError: If ``langgraph`` is not installed.
        """
        # NOTE(review): this does not check ``waiting_confirmation`` first;
        # callers are expected to resume only after an interrupted run.
        try:
            from langgraph.types import Command
        except ImportError as exc:  # pragma: no cover - callers degrade when the dependency is missing
            raise RuntimeError("未安装 langgraph,无法恢复 interrupt。") from exc
        decision = {"approved": approved, "note": note}
        return self._consume(self._graph.stream(Command(resume=decision), self._config()))

    def _config(self) -> dict[str, Any]:
        """Build the per-thread config the LangGraph checkpointer keys on."""
        return {"configurable": {"thread_id": self.thread_id}}

    def _consume(self, chunks: Any) -> LangGraphRunResult:
        """Drain a LangGraph stream and extract state, report and interrupt.

        Args:
            chunks: Iterable of stream chunks (dicts). Presumably node updates
                shaped ``{node_name: node_output}`` — confirm against the
                stream mode in use.

        Returns:
            The collected ``LangGraphRunResult``. Also updates
            ``waiting_confirmation`` from ``result.interrupted``.
        """
        result = LangGraphRunResult()
        for chunk in chunks:
            result.chunks.append(chunk)
            if "__interrupt__" in chunk:
                # Interrupt chunk: remember the confirmation request payload.
                result.interrupted = True
                result.confirmation = _extract_interrupt_value(chunk["__interrupt__"])
                continue
            for value in chunk.values():
                if not isinstance(value, dict):
                    continue
                if isinstance(value.get("agent_state"), AgentState):
                    result.state = value["agent_state"]
                if isinstance(value.get("report"), str):
                    result.report = value["report"]
        self._waiting_confirmation = result.interrupted
        return result


def build_deployment_graph(*, agent: PamDeployAgent, flow: GraphFlow = "deploy"):
    """Build the action-level LangGraph deployment graph.

    Topology: ``START -> entry``; ``entry`` routes to ``confirm``,
    ``global_action``, ``prepare_ip`` or ``report``; ``global_action`` loops
    on itself while global actions remain; ``prepare_ip -> ip_action ->
    prepare_ip`` loops over per-IP actions; ``confirm`` returns to ``entry``;
    ``report -> END``.

    Args:
        agent: Deployment agent whose methods the nodes call.
        flow: ``"global"`` skips the per-IP stage and goes straight to the
            report once global actions are exhausted.

    Returns:
        The compiled graph, using an ``InMemorySaver`` checkpointer.

    Raises:
        RuntimeError: If ``langgraph`` is not installed.
    """
    try:
        from langgraph.checkpoint.memory import InMemorySaver
        from langgraph.graph import END, START, StateGraph
        from langgraph.types import interrupt
    except ImportError as exc:  # pragma: no cover - callers degrade when the dependency is missing
        raise RuntimeError("未安装 langgraph,无法启用部署图。") from exc

    # Every node below returns "agent_state" explicitly.
    # NOTE(review): with a plain ``dict`` state schema LangGraph presumably
    # treats each node's return value as the new state — confirm.

    def entry_node(state: dict[str, Any]) -> dict[str, Any]:
        """Explicit entry node so fresh and resumed states share one routing step."""
        return {"agent_state": state["agent_state"]}

    def global_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the next global action, if any."""
        agent_state = state["agent_state"]
        action = agent.next_global_action(agent_state)
        if action:
            agent.run_global_action(agent_state, action)
        return {"agent_state": agent_state}

    def prepare_ip_node(state: dict[str, Any]) -> dict[str, Any]:
        """Pick the next per-IP action and store it in the graph state.

        Empty ``current_ip`` / ``current_ip_action`` mean there is no work left.
        """
        agent_state = state["agent_state"]
        work = agent.next_ip_action(agent_state)
        if work is None:
            return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}
        ip, action = work
        return {"agent_state": agent_state, "current_ip": ip, "current_ip_action": action}

    def ip_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the single-IP action selected by ``prepare_ip`` and clear it."""
        agent_state = state["agent_state"]
        ip = str(state.get("current_ip", ""))
        action = str(state.get("current_ip_action", ""))
        if ip and action:
            agent.run_ip_action(agent_state, ip, action)
        return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}

    def confirm_node(state: dict[str, Any]) -> dict[str, Any]:
        """Pause on a LangGraph interrupt, then apply the operator's decision.

        NOTE(review): LangGraph replays an interrupted node from its start on
        resume, so ``build_confirmation_request`` runs again then; it should
        be side-effect free — confirm.
        """
        agent_state = state["agent_state"]
        request = agent.build_confirmation_request(agent_state)
        decision = interrupt(request)
        approved, note = _parse_confirmation_decision(decision)
        agent_state = agent.confirm_pending(
            agent_state,
            approved=approved,
            operator_note=note,
        )
        return {"agent_state": agent_state}

    def report_node(state: dict[str, Any]) -> dict[str, Any]:
        """Render the report for the current state."""
        return {
            "agent_state": state["agent_state"],
            "report": agent.render_report(state["agent_state"]),
        }

    def route_global_or_ip(agent_state: AgentState) -> str:
        """Shared routing once no confirmation is pending: global loop, IP stage or report."""
        if agent.next_global_action(agent_state):
            return "global_action"
        return "report" if flow == "global" else "prepare_ip"

    def route_entry(state: dict[str, Any]) -> str:
        """From the entry node choose confirm, global action, IP stage or report."""
        agent_state = state["agent_state"]
        if agent_state.pending_confirmation:
            return "confirm"
        return route_global_or_ip(agent_state)

    def route_after_global(state: dict[str, Any]) -> str:
        """After a global action, keep looping globally or move to the IP stage."""
        return route_global_or_ip(state["agent_state"])

    def route_after_prepare_ip(state: dict[str, Any]) -> str:
        """After IP preparation go to confirm, the single-IP action, or the report."""
        agent_state = state["agent_state"]
        if agent_state.pending_confirmation:
            return "confirm"
        if state.get("current_ip_action"):
            return "ip_action"
        return "report"

    graph = StateGraph(dict)
    graph.add_node("entry", entry_node)
    graph.add_node("global_action", global_action_node)
    graph.add_node("prepare_ip", prepare_ip_node)
    graph.add_node("ip_action", ip_action_node)
    graph.add_node("confirm", confirm_node)
    graph.add_node("report", report_node)

    graph.add_edge(START, "entry")
    graph.add_conditional_edges(
        "entry",
        route_entry,
        {
            "confirm": "confirm",
            "global_action": "global_action",
            "prepare_ip": "prepare_ip",
            "report": "report",
        },
    )
    graph.add_conditional_edges(
        "global_action",
        route_after_global,
        {
            "global_action": "global_action",
            "prepare_ip": "prepare_ip",
            "report": "report",
        },
    )
    graph.add_conditional_edges(
        "prepare_ip",
        route_after_prepare_ip,
        {"confirm": "confirm", "ip_action": "ip_action", "report": "report"},
    )
    graph.add_edge("ip_action", "prepare_ip")
    # After a confirmation, re-enter routing so the next step is recomputed.
    graph.add_edge("confirm", "entry")
    graph.add_edge("report", END)
    return graph.compile(checkpointer=InMemorySaver())


def _extract_interrupt_value(interrupts: Any) -> dict[str, Any]:
    """Extract the confirmation-request dict from LangGraph interrupt objects.

    Only the first interrupt is used. Its ``.value`` is read when present
    (otherwise the item itself is used). A non-dict payload is wrapped as
    ``{"value": payload}``. Empty input yields ``{}``.
    """
    if not interrupts:
        return {}
    first = interrupts[0]
    value = getattr(first, "value", first)
    return value if isinstance(value, dict) else {"value": value}


def _is_approval(raw: Any) -> bool:
    """Interpret one approval value: strings by keyword, everything else by truthiness."""
    if isinstance(raw, str):
        return raw.strip().lower() in _APPROVE_WORDS
    return bool(raw)


def _parse_confirmation_decision(value: Any) -> tuple[bool, str]:
    """Parse an interrupt resume value into ``(approved, note)``.

    Accepted shapes:
      * dict ``{"approved": ..., "note": ...}`` (what ``resume`` sends). A
        string ``approved`` is parsed by keyword, so ``"false"`` / ``"no"``
        reject instead of being truthy; other values use ``bool()``. A
        missing or ``None`` note becomes ``""``.
      * ``bool``: used as-is, with an empty note.
      * ``str``: approves only for approve/approved/yes/y/true
        (case- and surrounding-whitespace-insensitive); the raw string is
        returned as the note.
      * anything else: rejected, with ``str(value)`` as the note (``""`` for
        ``None``).
    """
    if isinstance(value, dict):
        note = value.get("note")
        return _is_approval(value.get("approved", False)), "" if note is None else str(note)
    if isinstance(value, bool):
        return value, ""
    if isinstance(value, str):
        return _is_approval(value), value
    return False, "" if value is None else str(value)