diff --git a/README.md b/README.md index 72500b1..2dcd13e 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pam_deploy_graph/ llm/ # LLM structured output 接口、真实 HTTP client、提示词、规则 fallback 和 guardrails graph.py # LangGraph StateGraph 集成入口 langgraph_runtime.py # chat 人工确认点的 LangGraph interrupt 运行器 - mcp_client.py # MCP session/callable adapter 与 client 配置读取 + mcp_client.py # MCP stdio/HTTP/SSE client、鉴权 token 和配置读取 interactive.py # 常驻式 CLI 对话框,会话命令、确认和续跑 cli.py # CLI 入口 @@ -48,7 +48,7 @@ docs/ packaging/ build_linux_self_contained.sh # Linux 解压即用包构建脚本 README_linux_package.md # Linux 打包说明和包大小评估 - mcp_client.example.json # MCP stdio 配置示例 + mcp_client.example.json # MCP server URL + 鉴权配置示例 ``` ## 当前进度 @@ -72,20 +72,20 @@ packaging/ - 固化真实 LLM 提示词:意图识别、参数抽取、部署计划生成均要求 JSON structured output。 - 增加规则 fallback `RuleBasedLlmClient`,用于本地开发和测试。 - 增加 LLM 输出 guardrails,禁止计划中出现可执行脚本命令和非法 action。 -- 引入 `langgraph` 依赖,并提供 `build_langgraph()` 图工厂。 -- chat 人工确认点已接入 LangGraph interrupt/checkpointer:`run` 到待回滚确认时暂停,`approve/reject` 通过 `Command(resume=...)` 恢复。 -- 引入 MCP client adapter,可包装 SDK session 或普通 callable,并提供 JSON client 配置读取。 -- CLI/chat 支持 `--mcp-config` 直接加载 stdio MCP 配置并构造 MCP runner。 +- 引入 `langgraph` 依赖,CLI/chat 执行流程统一通过 action 级 LangGraph runtime 调度。 +- chat/CLI 人工确认点已接入 LangGraph interrupt/checkpointer:运行到待回滚确认时暂停,`approve/reject` 通过 `Command(resume=...)` 恢复。 +- 引入 MCP client adapter,可包装 SDK session、普通 callable、stdio server、HTTP/SSE server,并提供 JSON client 配置读取。 +- CLI/chat 支持 `--mcp-config` 直接加载 MCP server URL、鉴权和可选 tool 覆盖配置。 - 本地已安装 `langgraph` 和 `mcp`,并完成 LangGraph fake 全局流程 smoke。 - CLI `analyze` 输出已做敏感字段脱敏。 - 增加 `chat` 常驻式 CLI 对话框,支持自然语言分析、参数设置、执行确认、回滚确认、状态查看、事件查看、checkpoint 选择和续跑。 - chat 可选启用 `rich` / `prompt_toolkit`,支持更清晰输出、命令补全和输入历史。 - 增加 action 后 LLM/规则诊断,可通过 `--analyze-actions` 或 `llm action-analysis on` 显式开启。 -- 添加基础测试,当前本地结果为 `37 passed, 1 skipped`。 +- 添加基础测试,当前本地结果为 `42 passed, 1 skipped`。 未完成: -- 尚未接入真实 MCP session;当前已把 client adapter、tool 映射和配置格式准备好。 +- 尚未执行真实 PAM_NODE MCP 调用;当前已把 MCP HTTP/SSE/stdio client、鉴权和 tools 自动发现准备好。 - 尚未执行真实脚本 action 或真实 PAM_NODE MCP 调用。 ## LLM 配置 @@ -115,7 +115,9 @@ python -m pam_deploy_graph.cli analyze \ ## MCP Client 配置 -CLI/chat 已支持通过 `--mcp-config` 直接加载 MCP 配置。当前内置支持 stdio transport;配置文件里提供 MCP server 启动命令后,Agent 会在调用 PAM_NODE action 时创建 MCP stdio session。 +CLI/chat 已支持通过 `--mcp-config` 直接加载 MCP 配置。常用场景只需要配置 MCP `server_url` 和独立鉴权信息;Agent 会连接 MCP server,调用 `list_tools` 自动发现 server 暴露的 tools,再按 action 名自动匹配。 + +MCP 鉴权 token 获取方式与 HOME 一致,默认按 `client_credentials` POST 到 `/oauth/token` 风格接口;但 MCP 使用独立的 `token_url`、`client_id`、`client_secret`,不会复用 HOME 的账号密码。 CLI 示例: @@ -142,31 +144,43 @@ agent = PamDeployAgent(mcp_runner=runner) ```json { "server_name": "pam-node-prod", + "transport": "streamable_http", + "server_url": "https://pam-node-mcp.example.com/mcp", + "auth": { + "token_url": "https://pam-node-auth.example.com/oauth/token", + "client_id": "mcp_client_id", + "client_secret": "mcp_client_secret", + "grant_type": "client_credentials" + }, + "timeout_seconds": 60, + "sse_read_timeout_seconds": 300, + "headers": {} +} +``` + +字段说明: + +- `transport`:支持 `streamable_http`、`sse`、`stdio`。一般远端 MCP server 用 `streamable_http` 或 `sse`。 +- `server_url`:MCP server 地址,例如 `/mcp` 或 `/sse` endpoint。 +- `auth.token_url`:MCP 鉴权 token 地址,和 HOME 获取 token 的表单方式一致,但地址和账号密码独立。 +- `auth.client_id` / `auth.client_secret`:MCP 独立账号密码。 +- `headers`:除鉴权外需要额外带给 MCP server 的静态请求头。 +- `action_tools`:通常不用配置。只有 server 暴露的 tool 名称不符合 `get-online-ips`、`get_online_ips`、`pam_get_online_ips` 这类约定时,才用它覆盖 action -> tool,例如 `{ "get-online-ips": "custom_list_ips" }`。 + +如果是本地 stdio MCP server,也仍然支持: + +```json +{ "transport": "stdio", "command": "/opt/pam-node-mcp/server", "args": ["--stdio"], "cwd": "/opt/pam-node-mcp", "env": { "PAM_NODE_ENV": "prod" - }, - "timeout_seconds": 60, - "tool_names": { - "get-online-ips": "pam_get_online_ips", - "create-download-task": "pam_create_download_task", - "poll-download-progress": "pam_poll_download_progress", - "upgrade-ip": "pam_upgrade_ip", - "poll-upgrade-progress": "pam_poll_upgrade_progress", - "start-ip": "pam_start_ip", - "stop-ip": "pam_stop_ip", - "verify-ip": "pam_verify_ip", - "download-log": "pam_download_log", - "rollback-ip": "pam_rollback_ip" } } ``` -如果不传 `tool_names`,`McpActionRunner` 会使用上面的默认 action -> tool 映射。 - ## 使用方式 整体逻辑结构流程图: @@ -251,9 +265,10 @@ python -m pam_deploy_graph.cli run-deploy --config doc_scripts/config.txt.exampl ```bash python -m pam_deploy_graph.cli confirm --checkpoint runtime/checkpoints/demo.json --decision approve --confirm -python -m pam_deploy_graph.cli resume --checkpoint runtime/checkpoints/demo.json --confirm ``` +`confirm` 会通过 LangGraph interrupt resume 处理确认,并在确认后继续执行后续图节点;如果进程中断或需要再次续跑,再执行 `resume` 即可。 + 拒绝回滚: ```bash diff --git a/docs/current_logic_flow.md b/docs/current_logic_flow.md index e5bedbf..ec8d54a 100644 --- a/docs/current_logic_flow.md +++ b/docs/current_logic_flow.md @@ -20,13 +20,14 @@ flowchart TD CLI --> AGENT[PamDeployAgent] CHAT --> AGENT - CHAT --> LGR[langgraph_runtime.py chat interrupt 运行器] + CLI --> LGR[langgraph_runtime.py action 级 LangGraph runtime] + CHAT --> LGR PARAMS --> AGENT RULE --> AGENT REAL --> AGENT LGR --> AGENT - LGR --> LGCHECK[LangGraph InMemorySaver checkpointer] + LGR --> LGCHECK[LangGraph InMemorySaver checkpointer/interrupt] AGENT --> ROUTER[ActionRouter] ROUTER --> SCRIPT[ScriptActionRunner] ROUTER --> MCP[McpActionRunner] @@ -34,7 +35,7 @@ flowchart TD SCRIPT --> DEPLOY[doc_scripts/deploy.sh 或 deploy.ps1] MCP --> MCPFACTORY[mcp_factory.py 读取 --mcp-config] - MCPFACTORY --> MCPCLIENT[mcp_client.py: stdio/Session/Function adapter] + MCPFACTORY --> MCPCLIENT[mcp_client.py: stdio/HTTP/SSE adapter + token auth] FAKE --> FIXTURE[测试 fixture 或默认 fake 返回值] AGENT --> CHECKPOINT[checkpoint_store.py] @@ -65,11 +66,11 @@ flowchart TD A[create_state 创建运行状态] --> B[normalize_params 合并默认参数并校验必填项] B --> C[write_config 写脚本配置文件] C --> D[build_action_backends 生成 action 路由表] - D --> E[run_deploy_flow] + D --> E[LangGraph entry 节点] E --> F{是否存在 pending_confirmation} - F -- 是 --> P[暂停并保存 checkpoint] - F -- 否 --> G[run_global_flow 全局阶段] + F -- 是 --> P[confirm interrupt 节点] + F -- 否 --> G[global_action 节点循环] G --> G1[get-token] G1 --> G2[create-version] @@ -79,14 +80,14 @@ flowchart TD G5 --> G6[get-online-ips] G6 --> G7[create-download-task] G7 --> G8[poll-download-progress] - G8 --> H[run_ip_flow 逐 IP 阶段] + G8 --> H[prepare_ip 节点选择下一个 IP action] H --> I[resolve_target_ips 计算目标 IP] - I --> J[upgrade-ip] - J --> K[poll-upgrade-progress] - K --> L[start-ip] - L --> M[verify-ip] - M --> N[download-log] + I --> J[ip_action 节点执行 upgrade-ip] + J --> K[ip_action 节点执行 poll-upgrade-progress] + K --> L[ip_action 节点执行 start-ip] + L --> M[ip_action 节点执行 verify-ip] + M --> N[ip_action 节点执行 download-log] N --> O{还有下一个 IP} O -- 是 --> J O -- 否 --> R[render_report 输出报告] @@ -131,7 +132,7 @@ flowchart TD E --> F[设置 pending_confirmation=rollback-ip:IP] F --> G[保存 checkpoint 并暂停] - G --> LG{是否来自 chat} + G --> LG{是否来自 CLI/chat 图运行} LG -- 是 --> LGI[LangGraph interrupt 输出确认请求] LGI --> LGRS[approve/reject 通过 Command resume 恢复] LGRS --> H{用户决定} @@ -153,11 +154,12 @@ flowchart TD - `ip_states[ip].status == SUCCESS`:成功 IP 会跳过。 - `ip_states[ip].completed_steps`:同一个 IP 已完成的 action 会跳过。 - `pending_confirmation`:存在待确认事项时,部署流程不继续执行,必须先 `approve` 或 `reject`。 -- chat 会话内的确认点由 `langgraph_runtime.py` 通过 LangGraph interrupt 和 InMemorySaver 托管;命令行一次性 `confirm/resume` 仍读取业务 checkpoint JSON。 +- CLI/chat 的运行调度由 `langgraph_runtime.py` 通过 action 级 LangGraph 节点执行;chat 和 CLI confirm 的确认点使用 LangGraph interrupt 和 InMemorySaver。 +- 跨进程续跑仍读取业务 checkpoint JSON;LangGraph checkpointer 负责单进程图恢复和 interrupt resume。 - checkpoint 为了真实续跑会保存完整参数,请放在受控目录中。 ## 真实外部能力接入点 - 真实 LLM:`llm.openai_compatible.OpenAICompatibleLlmClient`,通过 `PAM_LLM_BASE_URL`、`PAM_LLM_API_KEY`、`PAM_LLM_MODEL` 或 CLI 参数配置。 -- 真实 MCP:CLI/chat 可通过 `--mcp-config` 加载 stdio MCP 配置,内部由 `mcp_factory.py` 构造 `McpActionRunner`。 +- 真实 MCP:CLI/chat 可通过 `--mcp-config` 加载 streamable_http、sse 或 stdio MCP 配置,HTTP/SSE 支持独立 token 鉴权,并通过 `list_tools` 自动发现 server tools。 - 真实脚本:PAM_HOME action 通过 `doc_scripts/deploy.sh` 或 `deploy.ps1` 调用。 diff --git a/packaging/README_linux_package.md b/packaging/README_linux_package.md index dc51069..7907ea6 100644 --- a/packaging/README_linux_package.md +++ b/packaging/README_linux_package.md @@ -51,7 +51,7 @@ pam-deploy-agent-linux-x86_64/ - `doc_scripts` 不会打入项目设计文档、测试脚本、Windows bat/PowerShell 脚本。 - 发布包内的 `README.md` 来自 `packaging/README_packaged_agent.md`,只说明打包后 Agent 的使用方式。 -- 发布包内的 `mcp_client.example.json` 是 MCP stdio 配置示例,需要按真实 MCP server 修改。 +- 发布包内的 `mcp_client.example.json` 是 MCP server URL + 独立鉴权配置示例,需要按真实 MCP server 和 token 地址修改。 - 项目开发用 README 不会复制到发布包内。 ## 解压后运行 diff --git a/packaging/README_packaged_agent.md b/packaging/README_packaged_agent.md index 0aa07a6..f9c0d9f 100644 --- a/packaging/README_packaged_agent.md +++ b/packaging/README_packaged_agent.md @@ -119,9 +119,10 @@ PAM> exit ```bash ./run.sh confirm --checkpoint runtime/checkpoints/demo.json --decision approve --confirm -./run.sh resume --checkpoint runtime/checkpoints/demo.json --confirm ``` +`confirm` 会通过 LangGraph interrupt resume 处理确认,并在确认后继续执行后续图节点;进程中断或需要再次续跑时,再使用 `resume`。 + 拒绝回滚: ```bash @@ -167,35 +168,33 @@ PAM> llm fallback ## MCP 配置 -`--mcp-config` 指向 MCP client JSON 配置文件。当前支持 stdio transport: +`--mcp-config` 指向 MCP client JSON 配置文件。一般只需要配置 MCP server 地址和独立鉴权信息;Agent 会从 MCP server `list_tools` 自动发现可用 tool,不需要手写所有 action。 + +MCP token 获取方式与 HOME 一致,默认按 `client_credentials` POST 到 token 地址;但 MCP 使用独立的 `token_url`、`client_id`、`client_secret`。 ```json { "server_name": "pam-node-prod", - "transport": "stdio", - "command": "/opt/pam-node-mcp/server", - "args": ["--stdio"], - "cwd": "/opt/pam-node-mcp", - "env": { - "PAM_NODE_ENV": "prod" + "transport": "streamable_http", + "server_url": "https://pam-node-mcp.example.com/mcp", + "auth": { + "token_url": "https://pam-node-auth.example.com/oauth/token", + "client_id": "mcp_client_id", + "client_secret": "mcp_client_secret", + "grant_type": "client_credentials" }, - "timeout_seconds": 60, - "tool_names": { - "get-online-ips": "pam_get_online_ips", - "verify-ip": "pam_verify_ip", - "rollback-ip": "pam_rollback_ip" - } + "timeout_seconds": 60 } ``` 字段说明: -- `command`:MCP server 启动命令。 -- `args`:MCP server 启动参数。 -- `cwd`:MCP server 工作目录,可为空。 -- `env`:传给 MCP server 的环境变量,可为空。 +- `transport`:支持 `streamable_http`、`sse`、`stdio`。 +- `server_url`:MCP server 地址。 +- `auth.token_url`:MCP token 获取地址。 +- `auth.client_id` / `auth.client_secret`:MCP 独立账号密码。 - `timeout_seconds`:单次 tool 调用超时时间。 -- `tool_names`:Agent action 到 MCP tool name 的映射。 +- `action_tools`:可选覆盖项。通常不需要配置;只有 server tool 名称不符合 `get-online-ips`、`get_online_ips`、`pam_get_online_ips` 这类约定时才需要。 ## 注意事项 diff --git a/packaging/build_linux_self_contained.sh b/packaging/build_linux_self_contained.sh index b2e7e94..550d0da 100644 --- a/packaging/build_linux_self_contained.sh +++ b/packaging/build_linux_self_contained.sh @@ -114,8 +114,9 @@ PAM 部署 Agent 解压即用包 指定目标工作站 IP。可重复传入多次。 --mcp-config <路径> - MCP client JSON 配置文件。hybrid_node_mcp 策略、resume 或 confirm - 需要执行 MCP action 时使用。 + MCP client JSON 配置文件。通常配置 server_url 和独立鉴权信息; + Agent 会从 server list_tools 自动发现 tools。hybrid_node_mcp 策略、 + resume 或 confirm 需要执行 MCP action 时使用。 示例:mcp_client.example.json --confirm @@ -151,6 +152,7 @@ LLM 环境变量: ./run.sh run-deploy --config doc_scripts/config.txt.example --strategy fake --checkpoint runtime/checkpoints/demo.json --confirm ./run.sh confirm --checkpoint runtime/checkpoints/demo.json --decision approve --confirm + # 如果进程中断或需要再次续跑: ./run.sh resume --checkpoint runtime/checkpoints/demo.json --confirm 查看子命令原始参数: @@ -160,9 +162,10 @@ LLM 环境变量: 说明: 1. 本包已包含 Python 运行时和 Python 依赖,目标机器不需要安装 Python 包。 2. doc_scripts 只包含运行必需文件:deploy.sh、config.txt.example、PAM_AUTO_DEPLY_SKILL.md。 - 3. mcp_client.example.json 是 MCP stdio 配置示例,需要按真实 MCP server 修改。 - 4. chat 内可使用 params、events、list checkpoints、load checkpoint、llm config、mcp config 等命令。 - 5. checkpoint 会保存完整运行参数,请放在受控目录。 + 3. mcp_client.example.json 是 MCP server URL + 独立鉴权配置示例,需要按真实 MCP server 修改。 + 4. confirm 会通过 LangGraph interrupt resume 处理确认,并继续后续图节点;进程中断时再使用 resume。 + 5. chat 内可使用 params、events、list checkpoints、load checkpoint、llm config、mcp config 等命令。 + 6. checkpoint 会保存完整运行参数,请放在受控目录。 HELP_TEXT } diff --git a/packaging/mcp_client.example.json b/packaging/mcp_client.example.json index 6a96c02..00ac4c1 100644 --- a/packaging/mcp_client.example.json +++ b/packaging/mcp_client.example.json @@ -1,23 +1,18 @@ { "server_name": "pam-node-prod", - "transport": "stdio", - "command": "/opt/pam-node-mcp/server", - "args": ["--stdio"], - "cwd": "/opt/pam-node-mcp", - "env": { - "PAM_NODE_ENV": "prod" + "transport": "streamable_http", + "server_url": "https://pam-node-mcp.example.com/mcp", + "auth": { + "token_url": "https://pam-node-auth.example.com/oauth/token", + "client_id": "mcp_client_id", + "client_secret": "mcp_client_secret", + "grant_type": "client_credentials" }, "timeout_seconds": 60, - "tool_names": { - "get-online-ips": "pam_get_online_ips", - "create-download-task": "pam_create_download_task", - "poll-download-progress": "pam_poll_download_progress", - "upgrade-ip": "pam_upgrade_ip", - "poll-upgrade-progress": "pam_poll_upgrade_progress", - "start-ip": "pam_start_ip", - "stop-ip": "pam_stop_ip", - "verify-ip": "pam_verify_ip", - "download-log": "pam_download_log", - "rollback-ip": "pam_rollback_ip" - } + "sse_read_timeout_seconds": 300, + "headers": { + "X-PAM-Env": "prod" + }, + "_comment_action_tools": "通常不需要配置 action_tools。Agent 会从 MCP server list_tools 自动发现 tool;只有 server tool 名称不符合 get-online-ips/get_online_ips/pam_get_online_ips 这类约定时,才配置 action_tools 覆盖。", + "action_tools": {} } diff --git a/pam_deploy_graph/agent.py b/pam_deploy_graph/agent.py index ce508b5..3ef45a8 100644 --- a/pam_deploy_graph/agent.py +++ b/pam_deploy_graph/agent.py @@ -168,30 +168,45 @@ class PamDeployAgent: def run_global_flow(self, state: AgentState) -> AgentState: """执行全局部署阶段,并跳过 checkpoint 中已完成的步骤。""" + while True: + action = self.next_global_action(state) + if action is None: + return state + self.run_global_action(state, action) + + def next_global_action(self, state: AgentState) -> str | None: + """返回下一个未完成的全局 action。""" for action in GLOBAL_ACTION_SEQUENCE: if action in state.completed_global_steps: continue - kwargs: dict[str, Any] = {} - if action == "publish-version": - kwargs["hash_code"] = state.hash_code - result = self.router.run_action(state, action, **kwargs) - state.events.append( - { - "type": "ACTION_DONE" if result.ok else "ACTION_FAIL", - "stage": action, - "backend": result.backend, - "message": result.error_summary or "ok", - } - ) - self._append_action_analysis(state, action, result) - if not result.ok: - state.last_failed_step = action - self._save_checkpoint(state) - raise RuntimeError(f"{action} 执行失败: {result.error_summary}") - self._apply_result(state, action, result.values) - state.completed_global_steps.append(action) - state.last_success_step = action + return action + return None + + def run_global_action(self, state: AgentState, action: str) -> AgentState: + """执行一个全局 action,并把结果写回 AgentState。""" + if action in state.completed_global_steps: + return state + kwargs: dict[str, Any] = {} + if action == "publish-version": + kwargs["hash_code"] = state.hash_code + result = self.router.run_action(state, action, **kwargs) + state.events.append( + { + "type": "ACTION_DONE" if result.ok else "ACTION_FAIL", + "stage": action, + "backend": result.backend, + "message": result.error_summary or "ok", + } + ) + self._append_action_analysis(state, action, result) + if not result.ok: + state.last_failed_step = action self._save_checkpoint(state) + raise RuntimeError(f"{action} 执行失败: {result.error_summary}") + self._apply_result(state, action, result.values) + state.completed_global_steps.append(action) + state.last_success_step = action + self._save_checkpoint(state) return state def run_deploy_flow(self, state: AgentState) -> AgentState: @@ -205,9 +220,18 @@ class PamDeployAgent: def run_ip_flow(self, state: AgentState) -> AgentState: """执行逐 IP 部署流程,失败时停在人工确认点。""" + while True: + work = self.next_ip_action(state) + if work is None: + return state + ip, action = work + self.run_ip_action(state, ip, action) + + def next_ip_action(self, state: AgentState) -> tuple[str, str] | None: + """返回下一个待执行的单 IP action,并按需初始化 IP 状态。""" if state.pending_confirmation: self._save_checkpoint(state) - return state + return None self._resolve_target_ips(state) for ip in state.target_ips: ip_state = state.ip_states.get(ip) @@ -217,7 +241,7 @@ class PamDeployAgent: if ip_state.get("rollback_status") == "PENDING_AGENT_CONFIRMATION": state.pending_confirmation = f"rollback-ip:{ip}" self._save_checkpoint(state) - return state + return None continue if not ip_state: state.events.append({"type": "IP_START", "ip": ip, "message": "start"}) @@ -232,38 +256,46 @@ class PamDeployAgent: } state.ip_states[ip] = ip_state + completed_steps = ip_state.setdefault("completed_steps", []) for action in IP_ACTION_SEQUENCE: - completed_steps = ip_state.setdefault("completed_steps", []) - if action in completed_steps: - continue - result = self.router.run_action(state, action, ip=ip) - failed = (not result.ok) or self._business_failed(action, result.values) - state.events.append( - { - "type": "ACTION_FAIL" if failed else "ACTION_DONE", - "stage": action, - "backend": result.backend, - "ip": ip, - "message": result.error_summary or result.values.get("MESSAGE", "ok"), - } - ) - self._append_action_analysis(state, action, result, ip=ip) - - if failed: - self._record_ip_failure(state, ip, action, result.error_summary or str(result.values)) - if action != "download-log": - self._download_log_best_effort(state, ip) - state.pending_confirmation = f"rollback-ip:{ip}" - self._save_checkpoint(state) - return state - - self._apply_ip_result(ip_state, action, result.values) - completed_steps.append(action) - self._save_checkpoint(state) + if action not in completed_steps: + return ip, action ip_state["status"] = "SUCCESS" state.events.append({"type": "IP_DONE", "ip": ip, "message": "success"}) self._save_checkpoint(state) + return None + + def run_ip_action(self, state: AgentState, ip: str, action: str) -> AgentState: + """执行一个单 IP action,并在失败时设置人工确认点。""" + ip_state = state.ip_states[ip] + completed_steps = ip_state.setdefault("completed_steps", []) + if action in completed_steps: + return state + result = self.router.run_action(state, action, ip=ip) + failed = (not result.ok) or self._business_failed(action, result.values) + state.events.append( + { + "type": "ACTION_FAIL" if failed else "ACTION_DONE", + "stage": action, + "backend": result.backend, + "ip": ip, + "message": result.error_summary or result.values.get("MESSAGE", "ok"), + } + ) + self._append_action_analysis(state, action, result, ip=ip) + + if failed: + self._record_ip_failure(state, ip, action, result.error_summary or str(result.values)) + if action != "download-log": + self._download_log_best_effort(state, ip) + state.pending_confirmation = f"rollback-ip:{ip}" + self._save_checkpoint(state) + return state + + self._apply_ip_result(ip_state, action, result.values) + completed_steps.append(action) + self._save_checkpoint(state) return state def build_confirmation_request(self, state: AgentState) -> dict[str, Any]: diff --git a/pam_deploy_graph/cli.py b/pam_deploy_graph/cli.py index 226733b..af8892d 100644 --- a/pam_deploy_graph/cli.py +++ b/pam_deploy_graph/cli.py @@ -9,6 +9,7 @@ from dataclasses import asdict from .agent import PamDeployAgent from .checkpoint_store import load_agent_state, redact_mapping from .interactive import run_interactive_chat +from .langgraph_runtime import LangGraphDeploymentRuntime, LangGraphRunResult from .llm import build_llm_client from .mcp_factory import build_mcp_runner_from_config from .params_loader import load_params_file @@ -45,6 +46,25 @@ def print_pause_payload(agent: PamDeployAgent, state) -> None: print(json.dumps({"checkpoint": state.checkpoint_path}, ensure_ascii=False, indent=2)) +def run_graph_once(agent: PamDeployAgent, state, *, flow: str = "deploy") -> LangGraphRunResult: + """用 LangGraph runtime 执行一次状态,返回图执行结果。""" + runtime = LangGraphDeploymentRuntime(agent=agent, flow=flow) # type: ignore[arg-type] + return runtime.start(state) + + +def print_graph_result(agent: PamDeployAgent, result: LangGraphRunResult) -> None: + """输出 LangGraph 执行结果、报告和暂停信息。""" + state = result.state + if result.report: + print(result.report) + elif state is not None: + print(agent.render_report(state)) + if result.interrupted and result.confirmation: + print(json.dumps({"confirmation": result.confirmation}, ensure_ascii=False, indent=2)) + if state is not None: + print_pause_payload(agent, state) + + def main() -> None: """解析 CLI 参数并分发到对应命令。""" parser = argparse.ArgumentParser(prog="pam-deploy-agent") @@ -144,29 +164,29 @@ def main() -> None: execution_strategy=args.strategy, checkpoint_path=args.checkpoint, ) - state = agent.run_global_flow(state) - print(json.dumps({"events": state.events}, ensure_ascii=False, indent=2)) - print_pause_payload(agent, state) + result = run_graph_once(agent, state, flow="global") + if result.state is not None: + print(json.dumps({"events": result.state.events}, ensure_ascii=False, indent=2)) + print_pause_payload(agent, result.state) return if args.command == "resume": state = load_agent_state(args.checkpoint) state.checkpoint_path = state.checkpoint_path or args.checkpoint - state = agent.run_deploy_flow(state) - print(agent.render_report(state)) - print_pause_payload(agent, state) + result = run_graph_once(agent, state, flow="deploy") + print_graph_result(agent, result) return if args.command == "confirm": state = load_agent_state(args.checkpoint) state.checkpoint_path = state.checkpoint_path or args.checkpoint - state = agent.confirm_pending( - state, - approved=args.decision == "approve", - operator_note=args.note, - ) - print(agent.render_report(state)) - print_pause_payload(agent, state) + runtime = LangGraphDeploymentRuntime(agent=agent, flow="deploy") + first = runtime.start(state) + if first.interrupted: + result = runtime.resume(approved=args.decision == "approve", note=args.note) + print_graph_result(agent, result) + return + print_graph_result(agent, first) return state = agent.create_state( @@ -175,9 +195,8 @@ def main() -> None: checkpoint_path=args.checkpoint, target_ips=args.target_ip, ) - state = agent.run_deploy_flow(state) - print(agent.render_report(state)) - print_pause_payload(agent, state) + result = run_graph_once(agent, state, flow="deploy") + print_graph_result(agent, result) if __name__ == "__main__": diff --git a/pam_deploy_graph/constants.py b/pam_deploy_graph/constants.py index 64485e8..8eba2d8 100644 --- a/pam_deploy_graph/constants.py +++ b/pam_deploy_graph/constants.py @@ -69,6 +69,8 @@ DEFAULT_PARAMS = { # 日志、报告和 LLM 输入中需要脱敏的字段。 SENSITIVE_KEYS = { "CLIENT_SECRET", + "MCP_CLIENT_SECRET", + "MCP_TOKEN", "TOKEN", "Authorization", "access_token", diff --git a/pam_deploy_graph/graph.py b/pam_deploy_graph/graph.py index 18ed3f9..f099521 100644 --- a/pam_deploy_graph/graph.py +++ b/pam_deploy_graph/graph.py @@ -1,27 +1,31 @@ -"""PAM 部署 Agent 的 LangGraph 集成入口。""" +"""PAM 部署 Agent 的 LangGraph 图工厂。""" from __future__ import annotations -from typing import Any, Literal +from typing import Any from .agent import PamDeployAgent - -GraphFlow = Literal["global", "deploy"] +from .langgraph_runtime import GraphFlow def build_langgraph(agent: PamDeployAgent | None = None, flow: GraphFlow = "deploy"): - """把现有 Agent 节点组装成 LangGraph StateGraph。""" + """构建兼容旧输入格式的 action 级 LangGraph 部署图。 + + 输入 state 支持直接传 `params`,图内会先调用 `create_state`;CLI/chat + 默认使用 `LangGraphDeploymentRuntime`,该 runtime 直接接收 `AgentState` + 并支持 interrupt/checkpointer。 + """ try: from langgraph.graph import END, START, StateGraph except ImportError as exc: # pragma: no cover - 依赖可选安装状态 - raise RuntimeError( - "未安装 langgraph。请先执行 `pip install -e .` 安装项目依赖。" - ) from exc + raise RuntimeError("未安装 langgraph。请先执行 `pip install -e .` 安装项目依赖。") from exc runtime = agent or PamDeployAgent() def create_state_node(state: dict[str, Any]) -> dict[str, Any]: """根据输入参数创建 AgentState。""" + if "agent_state" in state: + return {"agent_state": state["agent_state"]} agent_state = runtime.create_state( params=state["params"], execution_strategy=state.get("execution_strategy", "hybrid_node_mcp"), @@ -29,38 +33,98 @@ def build_langgraph(agent: PamDeployAgent | None = None, flow: GraphFlow = "depl script_entry=state.get("script_entry"), config_path=state.get("config_path"), trace_file_path=state.get("trace_file_path"), + checkpoint_path=state.get("checkpoint_path"), target_ips=state.get("target_ips"), ) return {"agent_state": agent_state} - def run_global_node(state: dict[str, Any]) -> dict[str, Any]: - """运行全局部署阶段。""" - agent_state = runtime.run_global_flow(state["agent_state"]) + def global_action_node(state: dict[str, Any]) -> dict[str, Any]: + """执行一个全局 action。""" + agent_state = state["agent_state"] + action = runtime.next_global_action(agent_state) + if action: + runtime.run_global_action(agent_state, action) return {"agent_state": agent_state} - def run_ip_node(state: dict[str, Any]) -> dict[str, Any]: - """运行逐 IP 部署阶段。""" - agent_state = runtime.run_ip_flow(state["agent_state"]) - return {"agent_state": agent_state} + def prepare_ip_node(state: dict[str, Any]) -> dict[str, Any]: + """选择下一个 IP action。""" + agent_state = state["agent_state"] + work = runtime.next_ip_action(agent_state) + if work is None: + return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""} + ip, action = work + return {"agent_state": agent_state, "current_ip": ip, "current_ip_action": action} + + def ip_action_node(state: dict[str, Any]) -> dict[str, Any]: + """执行一个 IP action。""" + agent_state = state["agent_state"] + ip = str(state.get("current_ip", "")) + action = str(state.get("current_ip_action", "")) + if ip and action: + runtime.run_ip_action(agent_state, ip, action) + return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""} def report_node(state: dict[str, Any]) -> dict[str, Any]: """渲染最终部署报告。""" - return {"report": runtime.render_report(state["agent_state"])} + return { + "agent_state": state["agent_state"], + "report": runtime.render_report(state["agent_state"]), + } + + def route_entry(state: dict[str, Any]) -> str: + """入口路由。""" + agent_state = state["agent_state"] + if agent_state.pending_confirmation: + return "report" + if runtime.next_global_action(agent_state): + return "global_action" + if flow == "global": + return "report" + return "prepare_ip" + + def route_after_global(state: dict[str, Any]) -> str: + """全局 action 后路由。""" + agent_state = state["agent_state"] + if runtime.next_global_action(agent_state): + return "global_action" + if flow == "global": + return "report" + return "prepare_ip" + + def route_after_prepare_ip(state: dict[str, Any]) -> str: + """IP 准备节点后路由。""" + agent_state = state["agent_state"] + if agent_state.pending_confirmation: + return "report" + if state.get("current_ip_action"): + return "ip_action" + return "report" graph = StateGraph(dict) graph.add_node("create_state", create_state_node) - graph.add_node("run_global", run_global_node) - graph.add_node("run_ip", run_ip_node) + graph.add_node("global_action", global_action_node) + graph.add_node("prepare_ip", prepare_ip_node) + graph.add_node("ip_action", ip_action_node) graph.add_node("report", report_node) graph.add_edge(START, "create_state") - graph.add_edge("create_state", "run_global") - if flow == "global": - graph.add_edge("run_global", END) - else: - graph.add_edge("run_global", "run_ip") - graph.add_edge("run_ip", "report") - graph.add_edge("report", END) + graph.add_conditional_edges( + "create_state", + route_entry, + {"global_action": "global_action", "prepare_ip": "prepare_ip", "report": "report"}, + ) + graph.add_conditional_edges( + "global_action", + route_after_global, + {"global_action": "global_action", "prepare_ip": "prepare_ip", "report": "report"}, + ) + graph.add_conditional_edges( + "prepare_ip", + route_after_prepare_ip, + {"ip_action": "ip_action", "report": "report"}, + ) + graph.add_edge("ip_action", "prepare_ip") + graph.add_edge("report", END) return graph.compile() diff --git a/pam_deploy_graph/interactive.py b/pam_deploy_graph/interactive.py index 22b9097..0b54156 100644 --- a/pam_deploy_graph/interactive.py +++ b/pam_deploy_graph/interactive.py @@ -81,7 +81,7 @@ class InteractiveCliSession: self._load_existing_checkpoint_if_any() while True: try: - line = self.input("PAM> ") + line = self.input("pam-deploy-agent> ") except EOFError: self.output("bye") return diff --git a/pam_deploy_graph/langgraph_runtime.py b/pam_deploy_graph/langgraph_runtime.py index 190e028..e5736d7 100644 --- a/pam_deploy_graph/langgraph_runtime.py +++ b/pam_deploy_graph/langgraph_runtime.py @@ -3,12 +3,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from uuid import uuid4 from .agent import PamDeployAgent from .models import AgentState +GraphFlow = Literal["global", "deploy"] + @dataclass(slots=True) class LangGraphRunResult: @@ -22,14 +24,21 @@ class LangGraphRunResult: class LangGraphDeploymentRuntime: - """用 LangGraph interrupt/checkpointer 托管 chat 中的人工确认流程。""" + """用 LangGraph 节点调度部署 action,并托管人工确认 interrupt。""" - def __init__(self, *, agent: PamDeployAgent, thread_id: str | None = None) -> None: + def __init__( + self, + *, + agent: PamDeployAgent, + thread_id: str | None = None, + flow: GraphFlow = "deploy", + ) -> None: """初始化图实例和会话线程 ID。""" self.agent = agent self.thread_id = thread_id or str(uuid4()) + self.flow = flow self._waiting_confirmation = False - self._graph = self._build_graph() + self._graph = build_deployment_graph(agent=self.agent, flow=self.flow) @property def waiting_confirmation(self) -> bool: @@ -51,56 +60,6 @@ class LangGraphDeploymentRuntime: decision = {"approved": approved, "note": note} return self._consume(self._graph.stream(Command(resume=decision), self._config())) - def _build_graph(self): - """构建 deploy -> confirm interrupt -> deploy 的循环图。""" - try: - from langgraph.checkpoint.memory import InMemorySaver - from langgraph.graph import END, START, StateGraph - from langgraph.types import interrupt - except ImportError as exc: # pragma: no cover - 依赖缺失时由调用方降级 - raise RuntimeError("未安装 langgraph,无法启用 chat interrupt。") from exc - - def deploy_node(state: dict[str, Any]) -> dict[str, Any]: - """执行部署流,遇到 pending_confirmation 时由路由转入确认节点。""" - agent_state = self.agent.run_deploy_flow(state["agent_state"]) - return {"agent_state": agent_state} - - def confirm_node(state: dict[str, Any]) -> dict[str, Any]: - """把确认请求交给 LangGraph interrupt,并在恢复后执行确认动作。""" - agent_state = state["agent_state"] - request = self.agent.build_confirmation_request(agent_state) - decision = interrupt(request) - approved, note = _parse_confirmation_decision(decision) - agent_state = self.agent.confirm_pending( - agent_state, - approved=approved, - operator_note=note, - ) - return {"agent_state": agent_state} - - def report_node(state: dict[str, Any]) -> dict[str, Any]: - """渲染当前状态报告。""" - return {"report": self.agent.render_report(state["agent_state"])} - - def route_after_deploy(state: dict[str, Any]) -> str: - """根据是否存在 pending_confirmation 决定下一步。""" - agent_state = state["agent_state"] - return "confirm" if agent_state.pending_confirmation else "report" - - graph = StateGraph(dict) - graph.add_node("deploy", deploy_node) - graph.add_node("confirm", confirm_node) - graph.add_node("report", report_node) - graph.add_edge(START, "deploy") - graph.add_conditional_edges( - "deploy", - route_after_deploy, - {"confirm": "confirm", "report": "report"}, - ) - graph.add_edge("confirm", "deploy") - graph.add_edge("report", END) - return graph.compile(checkpointer=InMemorySaver()) - def _config(self) -> dict[str, Any]: """生成 LangGraph checkpointer 使用的线程配置。""" return {"configurable": {"thread_id": self.thread_id}} @@ -127,6 +86,133 @@ class LangGraphDeploymentRuntime: return result +def build_deployment_graph(*, agent: PamDeployAgent, flow: GraphFlow = "deploy"): + """构建 action 级别的 LangGraph 部署图。""" + try: + from langgraph.checkpoint.memory import InMemorySaver + from langgraph.graph import END, START, StateGraph + from langgraph.types import interrupt + except ImportError as exc: # pragma: no cover - 依赖缺失时由调用方降级 + raise RuntimeError("未安装 langgraph,无法启用部署图。") from exc + + def entry_node(state: dict[str, Any]) -> dict[str, Any]: + """保留入口节点,便于统一路由已有 state 或恢复 state。""" + return {"agent_state": state["agent_state"]} + + def global_action_node(state: dict[str, Any]) -> dict[str, Any]: + """执行一个全局 action。""" + agent_state = state["agent_state"] + action = agent.next_global_action(agent_state) + if action: + agent.run_global_action(agent_state, action) + return {"agent_state": agent_state} + + def prepare_ip_node(state: dict[str, Any]) -> dict[str, Any]: + """选择下一个 IP action,并写入图状态。""" + agent_state = state["agent_state"] + work = agent.next_ip_action(agent_state) + if work is None: + return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""} + ip, action = work + return {"agent_state": agent_state, "current_ip": ip, "current_ip_action": action} + + def ip_action_node(state: dict[str, Any]) -> dict[str, Any]: + """执行一个单 IP action。""" + agent_state = state["agent_state"] + ip = str(state.get("current_ip", "")) + action = str(state.get("current_ip_action", "")) + if ip and action: + agent.run_ip_action(agent_state, ip, action) + return {"agent_state": agent_state, "current_ip": "", "current_ip_action": ""} + + def confirm_node(state: dict[str, Any]) -> dict[str, Any]: + """把确认请求交给 LangGraph interrupt,并在恢复后执行确认动作。""" + agent_state = state["agent_state"] + request = agent.build_confirmation_request(agent_state) + decision = interrupt(request) + approved, note = _parse_confirmation_decision(decision) + agent_state = agent.confirm_pending( + agent_state, + approved=approved, + operator_note=note, + ) + return {"agent_state": agent_state} + + def report_node(state: dict[str, Any]) -> dict[str, Any]: + """渲染当前状态报告。""" + return { + "agent_state": state["agent_state"], + "report": agent.render_report(state["agent_state"]), + } + + def route_entry(state: dict[str, Any]) -> str: + """从入口决定进入全局、IP、确认或报告节点。""" + agent_state = state["agent_state"] + if agent_state.pending_confirmation: + return "confirm" + if agent.next_global_action(agent_state): + return "global_action" + if flow == "global": + return "report" + return "prepare_ip" + + def route_after_global(state: dict[str, Any]) -> str: + """全局 action 后继续全局循环或进入 IP 阶段。""" + agent_state = state["agent_state"] + if agent.next_global_action(agent_state): + return "global_action" + if flow == "global": + return "report" + return "prepare_ip" + + def route_after_prepare_ip(state: dict[str, Any]) -> str: + """IP 准备节点后进入确认、单 IP action 或报告。""" + agent_state = state["agent_state"] + if agent_state.pending_confirmation: + return "confirm" + if state.get("current_ip_action"): + return "ip_action" + return "report" + + graph = StateGraph(dict) + graph.add_node("entry", entry_node) + graph.add_node("global_action", global_action_node) + graph.add_node("prepare_ip", prepare_ip_node) + graph.add_node("ip_action", ip_action_node) + graph.add_node("confirm", confirm_node) + graph.add_node("report", report_node) + + graph.add_edge(START, "entry") + graph.add_conditional_edges( + "entry", + route_entry, + { + "confirm": "confirm", + "global_action": "global_action", + "prepare_ip": "prepare_ip", + "report": "report", + }, + ) + graph.add_conditional_edges( + "global_action", + route_after_global, + { + "global_action": "global_action", + "prepare_ip": "prepare_ip", + "report": "report", + }, + ) + graph.add_conditional_edges( + "prepare_ip", + route_after_prepare_ip, + {"confirm": "confirm", "ip_action": "ip_action", "report": "report"}, + ) + graph.add_edge("ip_action", "prepare_ip") + graph.add_edge("confirm", "entry") + graph.add_edge("report", END) + return graph.compile(checkpointer=InMemorySaver()) + + def _extract_interrupt_value(interrupts: Any) -> dict[str, Any]: """从 LangGraph interrupt 对象中提取确认请求字典。""" if not interrupts: diff --git a/pam_deploy_graph/mcp_client.py b/pam_deploy_graph/mcp_client.py index 3bdddd5..04710e1 100644 --- a/pam_deploy_graph/mcp_client.py +++ b/pam_deploy_graph/mcp_client.py @@ -7,6 +7,9 @@ callable 或 SDK session 适配成这个接口,避免业务代码绑定具体 from __future__ import annotations import json +import time +import urllib.parse +import urllib.request from datetime import timedelta from collections.abc import Callable from dataclasses import dataclass, field @@ -14,23 +17,68 @@ from pathlib import Path from typing import Any +@dataclass(frozen=True) +class McpAuthConfig: + """MCP server 鉴权 token 配置。""" + + token_url: str = "" + client_id: str = "" + client_secret: str = "" + grant_type: str = "client_credentials" + header_name: str = "Authorization" + header_prefix: str = "Bearer" + token_field: str = "access_token" + expires_in_field: str = "expires_in" + extra_form: dict[str, str] = field(default_factory=dict) + + @classmethod + def from_mapping(cls, payload: dict[str, Any] | None) -> "McpAuthConfig | None": + """从 JSON auth 字典构造 MCP 鉴权配置。""" + if not payload: + return None + if not isinstance(payload, dict): + raise ValueError("MCP auth 必须是 JSON object") + token_url = str(payload.get("token_url", "")) + base_url = str(payload.get("base_url", "")) + if not token_url and base_url: + token_url = base_url.rstrip("/") + "/oauth/token" + extra_form = payload.get("extra_form") or {} + if not isinstance(extra_form, dict): + raise ValueError("MCP auth.extra_form 必须是 JSON object") + return cls( + token_url=token_url, + client_id=str(payload.get("client_id", "")), + client_secret=str(payload.get("client_secret", "")), + grant_type=str(payload.get("grant_type", "client_credentials")), + header_name=str(payload.get("header_name", "Authorization")), + header_prefix=str(payload.get("header_prefix", "Bearer")), + token_field=str(payload.get("token_field", "access_token")), + expires_in_field=str(payload.get("expires_in_field", "expires_in")), + extra_form={str(key): str(value) for key, value in extra_form.items()}, + ) + + @dataclass(frozen=True) class McpClientConfig: """真实 MCP session 建立后需要传给 runner 的配置。""" server_name: str = "pam-node" - transport: str = "stdio" + transport: str = "streamable_http" + server_url: str = "" command: str = "" args: list[str] = field(default_factory=list) env: dict[str, str] | None = None cwd: str = "" + headers: dict[str, str] = field(default_factory=dict) + auth: McpAuthConfig | None = None timeout_seconds: float = 60 + sse_read_timeout_seconds: float = 300 tool_names: dict[str, str] = field(default_factory=dict) @classmethod def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig": """从 JSON 字典构造 MCP client 配置。""" - tool_names = payload.get("tool_names") or payload.get("tools") or {} + tool_names = payload.get("tool_names") or payload.get("action_tools") or payload.get("tools") or {} if not isinstance(tool_names, dict): raise ValueError("MCP tool_names 必须是 JSON object") args = payload.get("args") or [] @@ -39,14 +87,24 @@ class McpClientConfig: env = payload.get("env") if env is not None and not isinstance(env, dict): raise ValueError("MCP env 必须是 JSON object") + headers = payload.get("headers") or {} + if not isinstance(headers, dict): + raise ValueError("MCP headers 必须是 JSON object") + server_url = str(payload.get("server_url") or payload.get("url") or "") + command = str(payload.get("command", "")) + transport = str(payload.get("transport") or ("stdio" if command else "streamable_http")) return cls( server_name=str(payload.get("server_name", "pam-node")), - transport=str(payload.get("transport", "stdio")), - command=str(payload.get("command", "")), + transport=transport, + server_url=server_url, + command=command, args=[str(item) for item in args], env={str(key): str(value) for key, value in env.items()} if env else None, cwd=str(payload.get("cwd", "")), + headers={str(key): str(value) for key, value in headers.items()}, + auth=McpAuthConfig.from_mapping(payload.get("auth")), timeout_seconds=float(payload.get("timeout_seconds", 60)), + sse_read_timeout_seconds=float(payload.get("sse_read_timeout_seconds", 300)), tool_names={str(key): str(value) for key, value in tool_names.items()}, ) @@ -92,6 +150,11 @@ class SessionMcpToolClient: result = self.session.call_tool(tool_name, arguments) return normalize_mcp_sdk_result(result) + def list_tools(self) -> list[str]: + """从 SDK session 获取 tool 名称列表。""" + result = self.session.list_tools() + return normalize_mcp_tool_list(result) + class StdioMcpToolClient: """通过 MCP Python SDK 启动 stdio server 并调用 tool。""" @@ -142,6 +205,160 @@ class StdioMcpToolClient: return anyio.run(call_once) + def list_tools(self) -> list[str]: + """创建一次 MCP stdio session,读取 server 暴露的 tool 列表。""" + try: + import anyio + from mcp import ClientSession + from mcp.client.stdio import StdioServerParameters, stdio_client + except ImportError as exc: # pragma: no cover - 依赖安装状态 + raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc + + async def list_once() -> list[str]: + server = StdioServerParameters( + command=self.command, + args=self.args, + env=self.env, + cwd=self.cwd, + ) + async with stdio_client(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.list_tools() + return normalize_mcp_tool_list(result) + + return anyio.run(list_once) + + +class OAuthTokenProvider: + """按 HOME 相同的 client_credentials 方式获取 MCP 鉴权 token。""" + + def __init__(self, config: McpAuthConfig, *, timeout_seconds: float = 30) -> None: + """保存鉴权配置和 token 缓存。""" + if not config.token_url: + raise ValueError("MCP auth 必须提供 token_url 或 auth.base_url") + if not config.client_id or not config.client_secret: + raise ValueError("MCP auth 必须提供独立的 client_id 和 client_secret") + self.config = config + self.timeout_seconds = timeout_seconds + self._token = "" + self._expires_at = 0.0 + + def authorization_headers(self) -> dict[str, str]: + """返回带 token 的请求头。""" + token = self.get_token() + prefix = self.config.header_prefix.strip() + value = f"{prefix} {token}" if prefix else token + return {self.config.header_name: value} + + def get_token(self) -> str: + """获取可用 token,未过期时复用缓存。""" + now = time.time() + if self._token and now < self._expires_at: + return self._token + payload = { + "grant_type": self.config.grant_type, + "client_id": self.config.client_id, + "client_secret": self.config.client_secret, + **self.config.extra_form, + } + data = urllib.parse.urlencode(payload).encode("utf-8") + request = urllib.request.Request( + self.config.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + method="POST", + ) + with urllib.request.urlopen(request, timeout=self.timeout_seconds) as response: + raw = response.read().decode("utf-8") + result = json.loads(raw) + token = str(result.get(self.config.token_field, "")) + if not token: + raise RuntimeError("MCP auth token 响应缺少 access_token") + expires_in = _safe_float(result.get(self.config.expires_in_field), 3600) + self._token = token + self._expires_at = now + max(expires_in - 60, 1) + return token + + +class HttpMcpToolClient: + """通过 MCP HTTP/SSE server URL 调用 tool。""" + + def __init__( + self, + *, + url: str, + transport: str = "streamable_http", + headers: dict[str, str] | None = None, + auth_provider: OAuthTokenProvider | None = None, + timeout_seconds: float = 60, + sse_read_timeout_seconds: float = 300, + ) -> None: + """保存 HTTP/SSE MCP server 连接参数。""" + if not url: + raise ValueError("HTTP/SSE MCP 配置必须提供 server_url") + if transport not in ("streamable_http", "sse"): + raise ValueError(f"不支持的 HTTP MCP transport: {transport}") + self.url = url + self.transport = transport + self.headers = dict(headers or {}) + self.auth_provider = auth_provider + self.timeout_seconds = timeout_seconds + self.sse_read_timeout_seconds = sse_read_timeout_seconds + + def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """连接 MCP server,调用 tool 后关闭 session。""" + return self._run_session(lambda session: session.call_tool(tool_name, arguments)) + + def list_tools(self) -> list[str]: + """连接 MCP server,读取 server 暴露的 tool 名称。""" + return self._run_session(lambda session: session.list_tools(), normalize_tools=True) + + def _build_headers(self) -> dict[str, str]: + """合并静态 headers 和动态鉴权 token。""" + headers = dict(self.headers) + if self.auth_provider is not None: + headers.update(self.auth_provider.authorization_headers()) + return headers + + def _run_session(self, operation: Callable[[Any], Any], *, normalize_tools: bool = False) -> Any: + """创建一次 HTTP/SSE MCP session 并执行指定操作。""" + try: + import anyio + from mcp import ClientSession + from mcp.client.sse import sse_client + from mcp.client.streamable_http import streamablehttp_client + except ImportError as exc: # pragma: no cover - 依赖安装状态 + raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc + + async def call_once() -> Any: + headers = self._build_headers() + if self.transport == "sse": + async with sse_client( + self.url, + headers=headers, + timeout=self.timeout_seconds, + sse_read_timeout=self.sse_read_timeout_seconds, + ) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await operation(session) + return normalize_mcp_tool_list(result) if normalize_tools else normalize_mcp_sdk_result(result) + + async with streamablehttp_client( + self.url, + headers=headers, + timeout=timedelta(seconds=self.timeout_seconds), + sse_read_timeout=timedelta(seconds=self.sse_read_timeout_seconds), + ) as streams: + read_stream, write_stream = streams[0], streams[1] + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await operation(session) + return normalize_mcp_tool_list(result) if normalize_tools else normalize_mcp_sdk_result(result) + + return anyio.run(call_once) + def normalize_mcp_sdk_result(result: Any) -> Any: """把常见 MCP SDK 返回结构归一化成 dict/list/string。""" @@ -165,3 +382,30 @@ def normalize_mcp_sdk_result(result: Any) -> Any: return joined return result + + +def normalize_mcp_tool_list(result: Any) -> list[str]: + """把 MCP list_tools 返回值归一化为 tool name 列表。""" + tools = getattr(result, "tools", None) + if tools is None and isinstance(result, dict): + tools = result.get("tools") + names: list[str] = [] + for item in tools or []: + if isinstance(item, str): + names.append(item) + continue + if isinstance(item, dict) and item.get("name"): + names.append(str(item["name"])) + continue + name = getattr(item, "name", None) + if name: + names.append(str(name)) + return names + + +def _safe_float(value: Any, default: float) -> float: + """把值安全转换为 float。""" + try: + return float(value) + except (TypeError, ValueError): + return default diff --git a/pam_deploy_graph/mcp_factory.py b/pam_deploy_graph/mcp_factory.py index 9e3c7ce..dc29ba6 100644 --- a/pam_deploy_graph/mcp_factory.py +++ b/pam_deploy_graph/mcp_factory.py @@ -4,7 +4,13 @@ from __future__ import annotations from pathlib import Path -from .mcp_client import McpClientConfig, StdioMcpToolClient, load_mcp_client_config +from .mcp_client import ( + HttpMcpToolClient, + McpClientConfig, + OAuthTokenProvider, + StdioMcpToolClient, + load_mcp_client_config, +) from .mcp_runner import McpActionRunner @@ -25,4 +31,18 @@ def build_mcp_client(config: McpClientConfig): cwd=config.cwd or None, timeout_seconds=config.timeout_seconds, ) + if config.transport in ("streamable_http", "sse"): + auth_provider = ( + OAuthTokenProvider(config.auth, timeout_seconds=config.timeout_seconds) + if config.auth is not None + else None + ) + return HttpMcpToolClient( + url=config.server_url, + transport=config.transport, + headers=config.headers, + auth_provider=auth_provider, + timeout_seconds=config.timeout_seconds, + sse_read_timeout_seconds=config.sse_read_timeout_seconds, + ) raise ValueError(f"不支持的 MCP transport: {config.transport}") diff --git a/pam_deploy_graph/mcp_runner.py b/pam_deploy_graph/mcp_runner.py index a7058eb..65df751 100644 --- a/pam_deploy_graph/mcp_runner.py +++ b/pam_deploy_graph/mcp_runner.py @@ -15,6 +15,10 @@ class McpToolClient(Protocol): """调用指定 MCP tool,并返回工具原始输出。""" ... + def list_tools(self) -> list[str]: + """返回 MCP server 暴露的 tool 名称列表。""" + ... + DEFAULT_NODE_MCP_TOOLS = { "get-online-ips": "pam_get_online_ips", @@ -40,7 +44,8 @@ class McpActionRunner: ) -> None: """保存 MCP client 和 action 到 tool name 的映射。""" self.client = client - self.tool_names = tool_names or DEFAULT_NODE_MCP_TOOLS.copy() + self.tool_names = tool_names or {} + self._discovered_tools: list[str] | None = None def run( self, @@ -55,9 +60,7 @@ class McpActionRunner: """执行一个 PAM_NODE action,并归一化为 ActionResult。""" if self.client is None: raise RuntimeError("尚未配置 MCP client") - tool_name = self.tool_names.get(action) - if not tool_name: - raise ValueError(f"action 未映射 MCP tool: {action}") + tool_name = self._resolve_tool_name(action) arguments = self._build_arguments( action, params=params, @@ -71,6 +74,41 @@ class McpActionRunner: return parse_mcp_result(action, {}, ok=False, tool_name=tool_name, error=str(exc)) return parse_mcp_result(action, payload, ok=True, tool_name=tool_name) + def _resolve_tool_name(self, action: str) -> str: + """根据显式映射、server tools 自动发现和默认约定解析 tool name。""" + explicit = self.tool_names.get(action) + if explicit: + return explicit + + discovered = self._list_discovered_tools() + if discovered: + candidates = _tool_name_candidates(action) + by_lower = {name.lower(): name for name in discovered} + for candidate in candidates: + matched = by_lower.get(candidate.lower()) + if matched: + return matched + available = ", ".join(discovered) + raise ValueError(f"MCP server 未发现 action 对应 tool: {action}; 已发现: {available}") + + fallback = DEFAULT_NODE_MCP_TOOLS.get(action) + if fallback: + return fallback + raise ValueError(f"action 未映射 MCP tool: {action}") + + def _list_discovered_tools(self) -> list[str]: + """读取并缓存 MCP server 暴露的 tool 名称。""" + if self._discovered_tools is not None: + return self._discovered_tools + if self.client is None or not hasattr(self.client, "list_tools"): + self._discovered_tools = [] + return self._discovered_tools + try: + self._discovered_tools = list(self.client.list_tools()) + except Exception: + self._discovered_tools = [] + return self._discovered_tools + def _build_arguments( self, action: str, @@ -98,3 +136,16 @@ class McpActionRunner: if action == "rollback-ip": arguments["stopFirst"] = stop_first return {key: value for key, value in arguments.items() if value not in (None, "")} + + +def _tool_name_candidates(action: str) -> list[str]: + """生成 action 自动匹配 MCP tool 的候选名称。""" + snake = action.replace("-", "_") + return [ + action, + snake, + f"pam_{snake}", + f"pam_node_{snake}", + f"pam.node.{snake}", + f"pam-node.{action}", + ] diff --git a/tests/test_mcp_client.py b/tests/test_mcp_client.py index 3b64a20..bb88cc5 100644 --- a/tests/test_mcp_client.py +++ b/tests/test_mcp_client.py @@ -1,11 +1,15 @@ from pam_deploy_graph.mcp_client import ( FunctionMcpToolClient, + HttpMcpToolClient, load_mcp_client_config, + normalize_mcp_tool_list, + OAuthTokenProvider, SessionMcpToolClient, StdioMcpToolClient, normalize_mcp_sdk_result, ) from pam_deploy_graph.mcp_factory import build_mcp_runner_from_config +from pam_deploy_graph.mcp_runner import McpActionRunner def test_function_mcp_client_wraps_callable(): @@ -30,6 +34,16 @@ def test_session_mcp_client_normalizes_text_json_content(): assert client.call_tool("tool", {}) == {"ok": True} +def test_normalize_mcp_tool_list(): + result = type( + "Tools", + (), + {"tools": [type("Tool", (), {"name": "pam_get_online_ips"})(), {"name": "verify-ip"}]}, + )() + + assert normalize_mcp_tool_list(result) == ["pam_get_online_ips", "verify-ip"] + + def test_load_mcp_client_config(tmp_path): path = tmp_path / "mcp.json" path.write_text( @@ -54,6 +68,33 @@ def test_load_mcp_client_config(tmp_path): assert config.tool_names["get-online-ips"] == "custom_ips" +def test_load_http_mcp_client_config_with_auth(tmp_path): + path = tmp_path / "mcp.json" + path.write_text( + """ + { + "server_name": "pam-node-prod", + "transport": "streamable_http", + "server_url": "https://pam-node.example.com/mcp", + "auth": { + "token_url": "https://pam-node-auth.example.com/oauth/token", + "client_id": "mcp-client", + "client_secret": "mcp-secret" + } + } + """, + encoding="utf-8", + ) + + config = load_mcp_client_config(path) + + assert config.transport == "streamable_http" + assert config.server_url == "https://pam-node.example.com/mcp" + assert config.auth is not None + assert config.auth.client_id == "mcp-client" + assert config.auth.client_secret == "mcp-secret" + + def test_build_mcp_runner_from_stdio_config(tmp_path): path = tmp_path / "mcp.json" path.write_text( @@ -65,3 +106,92 @@ def test_build_mcp_runner_from_stdio_config(tmp_path): assert isinstance(runner.client, StdioMcpToolClient) assert runner.tool_names["verify-ip"] == "custom_verify" + + +def test_build_mcp_runner_from_http_config(tmp_path): + path = tmp_path / "mcp.json" + path.write_text( + """ + { + "transport": "sse", + "server_url": "https://pam-node.example.com/sse", + "auth": { + "token_url": "https://pam-node-auth.example.com/oauth/token", + "client_id": "mcp-client", + "client_secret": "mcp-secret" + } + } + """, + encoding="utf-8", + ) + + runner = build_mcp_runner_from_config(path) + + assert isinstance(runner.client, HttpMcpToolClient) + assert runner.client.transport == "sse" + + +def test_oauth_token_provider_uses_home_style_form(monkeypatch, tmp_path): + config = load_mcp_client_config( + _write_json_config( + tmp_path, + { + "transport": "streamable_http", + "server_url": "https://pam-node.example.com/mcp", + "auth": { + "token_url": "https://pam-node-auth.example.com/oauth/token", + "client_id": "mcp-client", + "client_secret": "mcp-secret", + }, + }, + ) + ) + assert config.auth is not None + calls = [] + + class Response: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self): + return b'{"access_token": "token-1", "expires_in": 3600}' + + def fake_urlopen(request, timeout): + calls.append((request, timeout)) + return Response() + + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + provider = OAuthTokenProvider(config.auth) + + headers = provider.authorization_headers() + + assert headers == {"Authorization": "Bearer token-1"} + body = calls[0][0].data.decode("utf-8") + assert "grant_type=client_credentials" in body + assert "client_id=mcp-client" in body + assert "client_secret=mcp-secret" in body + + +def test_mcp_runner_auto_discovers_tool_name(): + class Client: + def list_tools(self): + return ["pam_get_online_ips"] + + def call_tool(self, tool_name, arguments): + return {"IP": ["192.168.1.10"], "COUNT": 1, "TOOL": tool_name} + + runner = McpActionRunner(client=Client()) + + result = runner.run("get-online-ips", params={}) + + assert result.ok is True + assert result.tool_name == "pam_get_online_ips" + + +def _write_json_config(tmpdir, payload): + path = tmpdir / "mcp.json" + path.write_text(__import__("json").dumps(payload), encoding="utf-8") + return str(path)