用户修复外部问题后输入 resume,会从失败的 action 重新执行,而不是结束整个流程。
回滚从 workflow 中拆出,新增显式命令:
- chat:`rollback [IP]`
- CLI:`rollback --checkpoint ... [--ip ...] [--stop-first|--no-stop-first]`
旧的 `confirm approve/reject` 只保留为旧 checkpoint 的兼容入口,新流程不再推荐使用。
LangGraph workflow 已移除回滚确认 interrupt 节点,失败暂停和续跑改走业务 checkpoint。
README、打包 README、`run.sh --help`、流程图、todo、提示词基线和测试都已同步。
267 lines · 11 KiB · Python
"""PAM 部署 Agent 的 action 级 LangGraph 运行器。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Literal
|
||
from uuid import uuid4
|
||
|
||
from .agent import PamDeployAgent
|
||
from .logging_utils import json_for_log
|
||
from .models import AgentState
|
||
|
||
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Which part of the deployment a graph drives: "global" runs only the global
# actions, "deploy" continues into per-IP actions after them.
GraphFlow = Literal["global", "deploy"]
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class LangGraphRunResult:
|
||
"""一次 LangGraph 执行或恢复后的结果摘要。"""
|
||
|
||
state: AgentState | None = None
|
||
report: str = ""
|
||
confirmation: dict[str, Any] = field(default_factory=dict)
|
||
interrupted: bool = False
|
||
chunks: list[dict[str, Any]] = field(default_factory=list)
|
||
|
||
|
||
class LangGraphDeploymentRuntime:
    """Run deployment actions through the action-level LangGraph graph.

    One instance owns one compiled graph and one checkpointer ``thread_id``.
    ``start`` runs the graph from an ``AgentState`` and returns a
    ``LangGraphRunResult``; ``resume`` only exists for the legacy interrupt
    confirmation path.
    """

    def __init__(
        self,
        *,
        agent: PamDeployAgent,
        thread_id: str | None = None,
        flow: GraphFlow = "deploy",
    ) -> None:
        """Compile the graph and choose the session thread id.

        Args:
            agent: Agent whose actions and report rendering the graph nodes call.
            thread_id: LangGraph checkpointer thread id; a random UUID4 string
                is generated when omitted or empty.
            flow: ``"global"`` runs only global actions; ``"deploy"`` also runs
                per-IP actions.

        Raises:
            RuntimeError: propagated from ``build_deployment_graph`` when
                ``langgraph`` is not installed.
        """
        self.agent = agent
        self.thread_id = thread_id or str(uuid4())
        self.flow = flow
        # Only a legacy LangGraph interrupt can set this (see ``_consume``).
        self._waiting_confirmation = False
        self._graph = build_deployment_graph(agent=self.agent, flow=self.flow)
        logger.info(
            "LangGraph runtime 初始化 thread_id=%s flow=%s agent=%s",
            self.thread_id,
            self.flow,
            type(self.agent).__name__,
        )

    @property
    def waiting_confirmation(self) -> bool:
        """Whether the last stream stopped at a legacy LangGraph interrupt."""
        return self._waiting_confirmation

    def start(self, state: AgentState) -> LangGraphRunResult:
        """Run the graph from ``state`` until it ends or the business state pauses.

        Args:
            state: Initial agent state, fed to the graph as ``agent_state``.

        Returns:
            Summary of the consumed stream (last state, report, raw chunks).
        """
        self._waiting_confirmation = False
        logger.info(
            "LangGraph start run_id=%s thread_id=%s flow=%s paused=%s pending=%s",
            state.run_id,
            self.thread_id,
            self.flow,
            state.paused,
            state.pending_confirmation,
        )
        return self._consume(self._graph.stream({"agent_state": state}, self._config()))

    def resume(self, *, approved: bool, note: str = "") -> LangGraphRunResult:
        """Answer a legacy LangGraph interrupt with an approve/reject decision.

        Compatibility entry point only: the graph built by
        ``build_deployment_graph`` has no interrupt nodes, so new flows do not
        normally call this.

        Args:
            approved: The confirmation decision.
            note: Free-form note forwarded with the decision.

        Returns:
            Summary of the stream produced by resuming the thread.

        Raises:
            RuntimeError: if ``langgraph`` is not installed.
        """
        try:
            from langgraph.types import Command
        except ImportError as exc:  # pragma: no cover - callers degrade when the dependency is missing
            raise RuntimeError("未安装 langgraph,无法恢复 interrupt。") from exc

        if not self._waiting_confirmation:
            # Kept non-fatal for backward compatibility, but make it visible:
            # the previous stream recorded no interrupt for this decision to answer.
            logger.warning(
                "LangGraph resume without pending interrupt thread_id=%s",
                self.thread_id,
            )

        decision = {"approved": approved, "note": note}
        logger.info("LangGraph resume thread_id=%s decision=%s note_len=%s", self.thread_id, approved, len(note))
        return self._consume(self._graph.stream(Command(resume=decision), self._config()))

    def _config(self) -> dict[str, Any]:
        """Return the per-thread config the LangGraph checkpointer is keyed on."""
        return {"configurable": {"thread_id": self.thread_id}}

    def _consume(self, chunks: Any) -> LangGraphRunResult:
        """Drain a LangGraph stream into a ``LangGraphRunResult``.

        Every chunk is kept in ``result.chunks``. A chunk containing
        ``__interrupt__`` marks the result as interrupted and records the
        legacy confirmation payload. For other chunks, each dict-valued entry
        contributes its ``agent_state`` (when it is an ``AgentState``) and
        ``report`` (when it is a ``str``); later chunks override earlier ones.
        Afterwards ``waiting_confirmation`` mirrors ``result.interrupted``.
        """
        result = LangGraphRunResult()
        # Chunks carry the whole AgentState, so formatting them can be costly;
        # only pay for json_for_log when INFO records will actually be emitted.
        log_chunks = logger.isEnabledFor(logging.INFO)
        for chunk in chunks:
            result.chunks.append(chunk)
            if log_chunks:
                logger.info("LangGraph chunk=%s", json_for_log(chunk, max_text_len=1600))
            if "__interrupt__" in chunk:
                result.interrupted = True
                result.confirmation = _extract_interrupt_value(chunk["__interrupt__"])
                logger.info("LangGraph interrupt thread_id=%s confirmation=%s", self.thread_id, json_for_log(result.confirmation))
                continue

            for value in chunk.values():
                if not isinstance(value, dict):
                    continue
                if isinstance(value.get("agent_state"), AgentState):
                    result.state = value["agent_state"]
                if isinstance(value.get("report"), str):
                    result.report = value["report"]

        self._waiting_confirmation = result.interrupted
        logger.info(
            "LangGraph consume 完成 thread_id=%s interrupted=%s waiting=%s state_run_id=%s report_len=%s",
            self.thread_id,
            result.interrupted,
            self._waiting_confirmation,
            result.state.run_id if result.state else "",
            len(result.report),
        )
        return result
|
||
|
||
|
||
def build_deployment_graph(*, agent: PamDeployAgent, flow: GraphFlow = "deploy"):
    """Build and compile the action-level LangGraph deployment graph.

    Topology::

        START -> entry -> {global_action | prepare_ip | report}
        global_action -> {global_action | prepare_ip | report}
        prepare_ip -> {ip_action | report}
        ip_action -> prepare_ip
        report -> END

    Each action node runs at most one agent action. Routing asks ``agent``
    what comes next, so the graph loops until the agent has no remaining
    action, or a legacy ``pending_confirmation`` is set, and always finishes
    through ``report``. No node calls a LangGraph interrupt.

    Args:
        agent: Provides ``next_global_action`` / ``run_global_action``,
            ``next_ip_action`` / ``run_ip_action`` and ``render_report``.
        flow: ``"global"`` goes to ``report`` once global actions are done;
            ``"deploy"`` continues into the per-IP phase.

    Returns:
        The compiled graph, checkpointed with an in-memory ``InMemorySaver``.

    Raises:
        RuntimeError: if ``langgraph`` is not installed.
    """
    logger.info("开始构建 LangGraph 部署图 flow=%s", flow)
    try:
        from langgraph.checkpoint.memory import InMemorySaver
        from langgraph.graph import END, START, StateGraph
    except ImportError as exc:  # pragma: no cover - callers degrade when the dependency is missing
        raise RuntimeError("未安装 langgraph,无法启用部署图。") from exc

    def entry_node(state: dict[str, Any]) -> dict[str, Any]:
        """Pass-through entry node so fresh and restored states route the same way."""
        agent_state = state["agent_state"]
        logger.info("LangGraph entry_node run_id=%s pending=%s paused=%s", agent_state.run_id, agent_state.pending_confirmation, agent_state.paused)
        return {"agent_state": agent_state}

    def global_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the next global action, if the agent reports one."""
        agent_state = state["agent_state"]
        action = agent.next_global_action(agent_state)
        if action:
            logger.info("LangGraph global_action_node run_id=%s action=%s", agent_state.run_id, action)
            agent.run_global_action(agent_state, action)
        return {"agent_state": agent_state}

    def prepare_ip_node(state: dict[str, Any]) -> dict[str, Any]:
        """Pick the next (ip, action) pair and store it in the graph state.

        Empty strings in ``current_ip`` / ``current_ip_action`` mean no IP
        work remains.
        """
        agent_state = state["agent_state"]
        work = agent.next_ip_action(agent_state)
        if work is None:
            logger.info("LangGraph prepare_ip_node 无待执行 IP action run_id=%s", agent_state.run_id)
            return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}
        ip, action = work
        logger.info("LangGraph prepare_ip_node run_id=%s ip=%s action=%s", agent_state.run_id, ip, action)
        return {"agent_state": agent_state, "current_ip": ip, "current_ip_action": action}

    def ip_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the single IP action chosen by ``prepare_ip``, then clear it."""
        agent_state = state["agent_state"]
        ip = str(state.get("current_ip", ""))
        action = str(state.get("current_ip_action", ""))
        if ip and action:
            logger.info("LangGraph ip_action_node run_id=%s ip=%s action=%s", agent_state.run_id, ip, action)
            agent.run_ip_action(agent_state, ip, action)
        return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}

    def report_node(state: dict[str, Any]) -> dict[str, Any]:
        """Render the report for the current agent state."""
        agent_state = state["agent_state"]
        logger.info("LangGraph report_node run_id=%s pending=%s paused=%s", agent_state.run_id, agent_state.pending_confirmation, agent_state.paused)
        return {
            "agent_state": agent_state,
            "report": agent.render_report(agent_state),
        }

    def route_entry(state: dict[str, Any]) -> str:
        """From the entry node, choose the global, IP or report branch."""
        agent_state = state["agent_state"]
        if agent_state.pending_confirmation:
            logger.info("LangGraph route_entry -> report legacy_pending run_id=%s", agent_state.run_id)
            return "report"
        if agent.next_global_action(agent_state):
            logger.info("LangGraph route_entry -> global_action run_id=%s", agent_state.run_id)
            return "global_action"
        if flow == "global":
            logger.info("LangGraph route_entry -> report run_id=%s", agent_state.run_id)
            return "report"
        logger.info("LangGraph route_entry -> prepare_ip run_id=%s", agent_state.run_id)
        return "prepare_ip"

    def route_after_global(state: dict[str, Any]) -> str:
        """After a global action, keep looping global actions, then move to the IP phase or report."""
        agent_state = state["agent_state"]
        # Same guard as route_entry / route_after_prepare_ip: once a legacy
        # confirmation is pending, no further action may start.
        if agent_state.pending_confirmation:
            logger.info("LangGraph route_after_global -> report legacy_pending run_id=%s", agent_state.run_id)
            return "report"
        if agent.next_global_action(agent_state):
            logger.info("LangGraph route_after_global -> global_action run_id=%s", agent_state.run_id)
            return "global_action"
        if flow == "global":
            logger.info("LangGraph route_after_global -> report run_id=%s", agent_state.run_id)
            return "report"
        logger.info("LangGraph route_after_global -> prepare_ip run_id=%s", agent_state.run_id)
        return "prepare_ip"

    def route_after_prepare_ip(state: dict[str, Any]) -> str:
        """After ``prepare_ip``, run the chosen IP action or finish with the report."""
        agent_state = state["agent_state"]
        if agent_state.pending_confirmation:
            logger.info("LangGraph route_after_prepare_ip -> report legacy_pending run_id=%s", agent_state.run_id)
            return "report"
        if state.get("current_ip_action"):
            logger.info("LangGraph route_after_prepare_ip -> ip_action run_id=%s ip=%s action=%s", agent_state.run_id, state.get("current_ip"), state.get("current_ip_action"))
            return "ip_action"
        logger.info("LangGraph route_after_prepare_ip -> report run_id=%s", agent_state.run_id)
        return "report"

    graph = StateGraph(dict)
    graph.add_node("entry", entry_node)
    graph.add_node("global_action", global_action_node)
    graph.add_node("prepare_ip", prepare_ip_node)
    graph.add_node("ip_action", ip_action_node)
    graph.add_node("report", report_node)

    graph.add_edge(START, "entry")
    graph.add_conditional_edges(
        "entry",
        route_entry,
        {
            "global_action": "global_action",
            "prepare_ip": "prepare_ip",
            "report": "report",
        },
    )
    graph.add_conditional_edges(
        "global_action",
        route_after_global,
        {
            "global_action": "global_action",
            "prepare_ip": "prepare_ip",
            "report": "report",
        },
    )
    graph.add_conditional_edges(
        "prepare_ip",
        route_after_prepare_ip,
        {"ip_action": "ip_action", "report": "report"},
    )
    # After each IP action, go back and pick the next one.
    graph.add_edge("ip_action", "prepare_ip")
    graph.add_edge("report", END)

    compiled = graph.compile(checkpointer=InMemorySaver())
    logger.info("LangGraph 部署图构建完成 flow=%s", flow)
    return compiled
|
||
|
||
|
||
def _extract_interrupt_value(interrupts: Any) -> dict[str, Any]:
    """Pull the confirmation payload out of a LangGraph ``__interrupt__`` entry.

    Only the first interrupt is used. Its ``value`` attribute is read when
    present; otherwise the item itself is the payload. A dict payload is
    returned as-is and anything else is wrapped as ``{"value": payload}``.
    An empty or falsy ``interrupts`` yields ``{}``.
    """
    if interrupts:
        head = interrupts[0]
        payload = getattr(head, "value", head)
        if isinstance(payload, dict):
            return payload
        return {"value": payload}
    return {}
|
||
|
||
|
||
def _parse_confirmation_decision(value: Any) -> tuple[bool, str]:
    """Turn an interrupt resume value into an ``(approved, note)`` pair.

    Accepted shapes:

    * ``dict`` -- ``approved`` and optional ``note`` keys. A string
      ``approved`` is parsed like a bare string (so ``"false"``/``"no"`` mean
      reject instead of being truthy); other values go through ``bool``.
      A missing or ``None`` note becomes ``""``.
    * ``bool`` -- used directly, with an empty note.
    * ``str`` -- approved when, stripped and lower-cased, it is one of
      ``approve/approved/yes/y/true``; the original string is the note.
    * anything else -- rejected, with ``str(value)`` as the note.
    """
    affirmative = ("approve", "approved", "yes", "y", "true")

    if isinstance(value, dict):
        raw_approved = value.get("approved", False)
        if isinstance(raw_approved, str):
            # bool("false") is True; parse strings explicitly so a textual
            # rejection can never be read as approval.
            approved = raw_approved.strip().lower() in affirmative
        else:
            approved = bool(raw_approved)
        note = value.get("note")
        return approved, "" if note is None else str(note)
    if isinstance(value, bool):
        return value, ""
    if isinstance(value, str):
        normalized = value.strip().lower()
        return normalized in affirmative, value
    return False, str(value)
|