142 lines
5.5 KiB
Python
142 lines
5.5 KiB
Python
"""PAM 部署 Agent 的 LangGraph 图工厂。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from typing import Any
|
||
|
||
from .agent import PamDeployAgent
|
||
from .langgraph_runtime import GraphFlow
|
||
|
||
|
||
def build_langgraph(agent: PamDeployAgent | None = None, flow: GraphFlow = "deploy"):
    """Build the action-level LangGraph deployment graph (legacy input format).

    The input state may carry raw ``params``; the first node turns them into an
    ``AgentState`` via ``runtime.create_state``. If the input state already has
    an ``agent_state`` entry, it is reused unchanged. CLI/chat use
    ``LangGraphDeploymentRuntime`` by default, which accepts ``AgentState``
    directly and supports interrupt/checkpointer.

    Graph shape::

        START -> create_state -> (global_action)* -> (prepare_ip -> ip_action)* -> report -> END

    Whenever ``agent_state.pending_confirmation`` is set, routing jumps
    straight to ``report`` so no further action runs.

    Args:
        agent: Agent that executes every step. When omitted (or falsy) a new
            ``PamDeployAgent()`` is created.
        flow: ``"global"`` runs only global actions; any other value also runs
            the per-IP actions afterwards.

    Returns:
        The compiled LangGraph graph.

    Raises:
        RuntimeError: If ``langgraph`` cannot be imported. The error is chained
            from the original ``ImportError``.

    NOTE(review): each global action costs one graph step and each IP action
    two (``prepare_ip`` + ``ip_action``). Large plans may need a higher
    ``recursion_limit`` in the invoke config; confirm against real plan sizes.
    """
    try:
        from langgraph.graph import END, START, StateGraph
    except ImportError as exc:  # pragma: no cover - depends on optional dependency install state
        raise RuntimeError("未安装 langgraph。请先执行 `pip install -e .` 安装项目依赖。") from exc

    runtime = agent or PamDeployAgent()

    def create_state_node(state: dict[str, Any]) -> dict[str, Any]:
        """Create the AgentState from the input params, or reuse a supplied one."""
        if "agent_state" in state:
            return {"agent_state": state["agent_state"]}
        agent_state = runtime.create_state(
            params=state["params"],
            execution_strategy=state.get("execution_strategy", "hybrid_node_mcp"),
            execution_mode=state.get("execution_mode", "fixed_runtime"),
            run_id=state.get("run_id"),
            script_entry=state.get("script_entry"),
            config_path=state.get("config_path"),
            trace_file_path=state.get("trace_file_path"),
            checkpoint_path=state.get("checkpoint_path"),
            target_ips=state.get("target_ips"),
            planned_actions=state.get("planned_actions"),
            mode_reason=state.get("mode_reason", ""),
            mode_risk_level=state.get("mode_risk_level", "medium"),
            mode_requires_confirmation=state.get("mode_requires_confirmation", True),
        )
        return {"agent_state": agent_state}

    def global_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the next pending global action, if there is one."""
        agent_state = state["agent_state"]
        action = runtime.next_global_action(agent_state)
        if action:
            runtime.run_global_action(agent_state, action)
        return {"agent_state": agent_state}

    def prepare_ip_node(state: dict[str, Any]) -> dict[str, Any]:
        """Select the next (ip, action) pair; empty strings mean nothing is left."""
        agent_state = state["agent_state"]
        work = runtime.next_ip_action(agent_state)
        if work is None:
            return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}
        ip, action = work
        return {"agent_state": agent_state, "current_ip": ip, "current_ip_action": action}

    def ip_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the IP action chosen by ``prepare_ip`` and clear the selection."""
        agent_state = state["agent_state"]
        ip = str(state.get("current_ip", ""))
        action = str(state.get("current_ip_action", ""))
        if ip and action:
            runtime.run_ip_action(agent_state, ip, action)
        return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}

    def report_node(state: dict[str, Any]) -> dict[str, Any]:
        """Render the final deployment report."""
        return {
            "agent_state": state["agent_state"],
            "report": runtime.render_report(state["agent_state"]),
        }

    def route_global_stage(state: dict[str, Any]) -> str:
        """Route at graph entry and after every global action.

        A pending confirmation must stop execution at both points. Previously
        only the entry router checked it, so a global action that requested
        confirmation could be followed by further actions.
        """
        agent_state = state["agent_state"]
        if agent_state.pending_confirmation:
            return "report"
        if runtime.next_global_action(agent_state):
            return "global_action"
        if flow == "global":
            return "report"
        return "prepare_ip"

    def route_after_prepare_ip(state: dict[str, Any]) -> str:
        """Route after ``prepare_ip``: run the selected action or finish."""
        agent_state = state["agent_state"]
        if agent_state.pending_confirmation:
            return "report"
        if state.get("current_ip_action"):
            return "ip_action"
        return "report"

    global_stage_targets = {
        "global_action": "global_action",
        "prepare_ip": "prepare_ip",
        "report": "report",
    }

    graph = StateGraph(dict)
    graph.add_node("create_state", create_state_node)
    graph.add_node("global_action", global_action_node)
    graph.add_node("prepare_ip", prepare_ip_node)
    graph.add_node("ip_action", ip_action_node)
    graph.add_node("report", report_node)

    graph.add_edge(START, "create_state")
    graph.add_conditional_edges("create_state", route_global_stage, global_stage_targets)
    graph.add_conditional_edges("global_action", route_global_stage, global_stage_targets)
    graph.add_conditional_edges(
        "prepare_ip",
        route_after_prepare_ip,
        {"ip_action": "ip_action", "report": "report"},
    )
    # Loop back so IP actions are executed one at a time until none remain.
    graph.add_edge("ip_action", "prepare_ip")
    graph.add_edge("report", END)

    return graph.compile()
|
||
|
||
|
||
def build_graph_or_none(agent: PamDeployAgent | None = None, flow: GraphFlow = "deploy"):
    """Return the deployment graph, or ``None`` when LangGraph is not installed.

    This lets callers fall back to a non-LangGraph path. Only the
    "langgraph is missing" failure is converted to ``None``: that is the
    ``RuntimeError`` which ``build_langgraph`` chains from an ``ImportError``.
    Any other ``RuntimeError`` propagates, so genuine failures are not hidden
    behind a silent ``None``.

    Args:
        agent: Forwarded to ``build_langgraph``.
        flow: Forwarded to ``build_langgraph``.

    Returns:
        The compiled graph, or ``None`` if LangGraph cannot be imported.
    """
    try:
        return build_langgraph(agent=agent, flow=flow)
    except RuntimeError as exc:
        if isinstance(exc.__cause__, ImportError):
            return None
        raise
|