2026-06-04 10:04:23 +08:00

142 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""PAM 部署 Agent 的 LangGraph 图工厂。"""
from __future__ import annotations
from typing import Any
from .agent import PamDeployAgent
from .langgraph_runtime import GraphFlow
def build_langgraph(agent: PamDeployAgent | None = None, flow: GraphFlow = "deploy"):
    """Build the action-level LangGraph deployment graph (legacy-input compatible).

    The input state may carry raw ``params``; the graph first calls
    ``create_state`` to turn them into an ``AgentState``. An already-built
    ``agent_state`` may be passed instead and is used as-is. CLI/chat use
    ``LangGraphDeploymentRuntime`` by default, which accepts an ``AgentState``
    directly and supports interrupt/checkpointer.

    Args:
        agent: Agent used to create state and execute actions. A fresh
            ``PamDeployAgent()`` is constructed when omitted (or falsy).
        flow: ``"global"`` renders the report right after the global actions;
            any other value continues into the per-IP actions.

    Returns:
        The compiled LangGraph graph.

    Raises:
        RuntimeError: if ``langgraph`` cannot be imported. The original
            ``ImportError`` is chained as ``__cause__``.
    """
    try:
        from langgraph.graph import END, START, StateGraph
    except ImportError as exc:  # pragma: no cover - depends on optional install
        raise RuntimeError("未安装 langgraph。请先执行 `pip install -e .` 安装项目依赖。") from exc

    runtime = agent or PamDeployAgent()

    def create_state_node(state: dict[str, Any]) -> dict[str, Any]:
        """Create the AgentState from the input parameters, or reuse a provided one."""
        if "agent_state" in state:
            return {"agent_state": state["agent_state"]}
        if "params" not in state:
            # Fail with an actionable message instead of a bare KeyError('params').
            raise KeyError("graph input must contain either 'agent_state' or 'params'")
        agent_state = runtime.create_state(
            params=state["params"],
            execution_strategy=state.get("execution_strategy", "hybrid_node_mcp"),
            execution_mode=state.get("execution_mode", "fixed_runtime"),
            run_id=state.get("run_id"),
            script_entry=state.get("script_entry"),
            config_path=state.get("config_path"),
            trace_file_path=state.get("trace_file_path"),
            checkpoint_path=state.get("checkpoint_path"),
            target_ips=state.get("target_ips"),
            planned_actions=state.get("planned_actions"),
            mode_reason=state.get("mode_reason", ""),
            mode_risk_level=state.get("mode_risk_level", "medium"),
            mode_requires_confirmation=state.get("mode_requires_confirmation", True),
        )
        return {"agent_state": agent_state}

    def global_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the next global action, if there is one."""
        agent_state = state["agent_state"]
        action = runtime.next_global_action(agent_state)
        if action:
            runtime.run_global_action(agent_state, action)
        return {"agent_state": agent_state}

    def prepare_ip_node(state: dict[str, Any]) -> dict[str, Any]:
        """Select the next (ip, action) pair; empty strings mean nothing is left."""
        agent_state = state["agent_state"]
        work = runtime.next_ip_action(agent_state)
        if work is None:
            return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}
        ip, action = work
        return {"agent_state": agent_state, "current_ip": ip, "current_ip_action": action}

    def ip_action_node(state: dict[str, Any]) -> dict[str, Any]:
        """Run the IP action chosen by ``prepare_ip`` and clear the selection."""
        agent_state = state["agent_state"]
        ip = str(state.get("current_ip", ""))
        action = str(state.get("current_ip_action", ""))
        if ip and action:
            runtime.run_ip_action(agent_state, ip, action)
        # Clear the selection so a stale (ip, action) is never carried forward.
        return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""}

    def report_node(state: dict[str, Any]) -> dict[str, Any]:
        """Render the final deployment report."""
        return {
            "agent_state": state["agent_state"],
            "report": runtime.render_report(state["agent_state"]),
        }

    def route_next_global_or_ip(state: dict[str, Any]) -> str:
        """Route after ``create_state`` and after every global action.

        A pending confirmation goes straight to the report. Otherwise any
        remaining global actions run first. After that the graph moves to the
        per-IP actions, unless ``flow == "global"``.
        """
        agent_state = state["agent_state"]
        # Checked after each global action as well as at entry. An action that
        # raises a confirmation must halt the run instead of letting further
        # global actions execute (or the same blocked action loop forever).
        if agent_state.pending_confirmation:
            return "report"
        if runtime.next_global_action(agent_state):
            return "global_action"
        if flow == "global":
            return "report"
        return "prepare_ip"

    def route_after_prepare_ip(state: dict[str, Any]) -> str:
        """Route after IP selection: run the chosen action, or finish with the report."""
        agent_state = state["agent_state"]
        if agent_state.pending_confirmation:
            return "report"
        if state.get("current_ip_action"):
            return "ip_action"
        return "report"

    next_step_targets = {
        "global_action": "global_action",
        "prepare_ip": "prepare_ip",
        "report": "report",
    }

    graph = StateGraph(dict)
    graph.add_node("create_state", create_state_node)
    graph.add_node("global_action", global_action_node)
    graph.add_node("prepare_ip", prepare_ip_node)
    graph.add_node("ip_action", ip_action_node)
    graph.add_node("report", report_node)

    graph.add_edge(START, "create_state")
    graph.add_conditional_edges("create_state", route_next_global_or_ip, dict(next_step_targets))
    graph.add_conditional_edges("global_action", route_next_global_or_ip, dict(next_step_targets))
    graph.add_conditional_edges(
        "prepare_ip",
        route_after_prepare_ip,
        {"ip_action": "ip_action", "report": "report"},
    )
    # NOTE(review): each IP action costs two super-steps (prepare_ip + ip_action),
    # so runs with many IP actions may exceed LangGraph's default recursion_limit
    # unless callers raise it in the invoke config. Confirm against callers.
    graph.add_edge("ip_action", "prepare_ip")
    graph.add_edge("report", END)
    return graph.compile()
def build_graph_or_none(agent: PamDeployAgent | None = None, flow: GraphFlow = "deploy"):
    """Build the deployment graph, or return None when LangGraph is not installed.

    This lets callers fall back to a non-LangGraph path. Only the
    "langgraph missing" failure is swallowed. ``build_langgraph`` signals it
    with a RuntimeError whose ``__cause__`` is the original ImportError. Any
    other RuntimeError is re-raised rather than being misreported as a missing
    dependency.

    Args:
        agent: Forwarded to ``build_langgraph``.
        flow: Forwarded to ``build_langgraph``.

    Returns:
        The compiled graph, or None if LangGraph cannot be imported.
    """
    try:
        return build_langgraph(agent=agent, flow=flow)
    except RuntimeError as exc:
        if isinstance(exc.__cause__, ImportError):
            return None
        raise