diff --git a/README.md b/README.md index b14994b..2d06bd2 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,10 @@ pam_deploy_graph/ config_writer.py # 生成脚本 action 所需 config 文件 checkpoint_store.py # 业务 checkpoint JSON 读写 params_loader.py # 读取 JSON 或 config.txt 风格参数文件 - llm/ # LLM structured output 接口、规则 fallback 和 guardrails + llm/ # LLM structured output 接口、真实 HTTP client、提示词、规则 fallback 和 guardrails graph.py # LangGraph StateGraph 集成入口 - mcp_client.py # MCP session/callable adapter + mcp_client.py # MCP session/callable adapter 与 client 配置读取 + interactive.py # 常驻式 CLI 对话框,会话命令、确认和续跑 cli.py # CLI 入口 tests/ @@ -37,6 +38,7 @@ tests/ test_params_loader.py test_script_runner.py test_skill_policy.py + test_interactive_cli.py ``` ## 当前进度 @@ -53,24 +55,111 @@ tests/ - 实现 fake 全局流程和完整部署流程,便于不触碰真实环境地验证 Agent 路由。 - 实现逐 IP 处理骨架:升级、轮询、启动、校验、日志下载。 - 实现单 IP 失败后的待回滚确认状态,不自动执行回滚。 +- 实现人工确认入口:`confirm --decision approve|reject` 只处理待确认回滚。 +- 实现 checkpoint 自动保存和 `resume` 续跑:全局步骤、成功 IP、单 IP 已完成 action 会跳过。 - 实现 LLM structured output 骨架:意图识别、参数抽取、部署计划生成。 +- 实现 OpenAI-compatible 真实 LLM client,支持 `base_url` / `api_key` / `model` 配置。 +- 固化真实 LLM 提示词:意图识别、参数抽取、部署计划生成均要求 JSON structured output。 - 增加规则 fallback `RuleBasedLlmClient`,用于本地开发和测试。 - 增加 LLM 输出 guardrails,禁止计划中出现可执行脚本命令和非法 action。 - 引入 `langgraph` 依赖,并提供 `build_langgraph()` 图工厂。 -- 引入 MCP client adapter,可包装 SDK session 或普通 callable。 +- 引入 MCP client adapter,可包装 SDK session 或普通 callable,并提供 JSON client 配置读取。 - 本地已安装 `langgraph` 和 `mcp`,并完成 LangGraph fake 全局流程 smoke。 - CLI `analyze` 输出已做敏感字段脱敏。 -- 添加基础测试,当前本地结果为 `22 passed, 1 skipped`。 +- 增加 `chat` 常驻式 CLI 对话框,支持自然语言分析、参数设置、执行确认、回滚确认、状态查看和续跑。 +- 添加基础测试,当前本地结果为 `31 passed, 1 skipped`。 未完成: -- 尚未接入真实 MCP client。 -- 尚未接入真实 LLM 服务,目前使用规则 fallback。 -- 尚未实现人工确认 interrupt、断点续跑完整图流程和单 IP 子流程。 +- 尚未接入真实 MCP session;当前已把 client adapter、tool 映射和配置格式准备好。 - 尚未执行真实脚本 action 或真实 PAM_NODE MCP 调用。 +## LLM 配置 + +默认不配置 LLM 时,`analyze` 使用本地规则 fallback。配置真实 LLM 后,会走 OpenAI-compatible `/chat/completions`: + +```powershell +$env:PAM_LLM_BASE_URL="https://your-llm.example.com/v1" +$env:PAM_LLM_API_KEY="your-api-key" +$env:PAM_LLM_MODEL="your-model-name" + +python -m pam_deploy_graph.cli analyze --config doc_scripts/config.txt.example --text "请用 MCP 预演部署 HET PAM Node 版本 2.0.5,不要动环境" +``` + +也可以直接用 CLI 参数覆盖环境变量: + +```bash +python -m pam_deploy_graph.cli analyze \ + --config doc_scripts/config.txt.example \ + --text "请用 MCP 预演部署 HET PAM Node 版本 2.0.5,不要动环境" \ + --llm-base-url https://your-llm.example.com/v1 \ + --llm-api-key your-api-key \ + --llm-model your-model-name +``` + +真实 LLM 调用位置在 `pam_deploy_graph/llm/openai_compatible.py`,提示词在 `pam_deploy_graph/llm/prompts.py`。发送给 LLM 的 `base_params` 会脱敏,`CLIENT_SECRET` 不会进入 prompt;本地生成计划后仍会执行 guardrails 校验。 + +## MCP Client 配置 + +真实 MCP session 由外部接入,Agent 只依赖同步 `call_tool(name, arguments)` 接口。接入方式: + +```python +from pam_deploy_graph.agent import PamDeployAgent +from pam_deploy_graph.mcp_client import SessionMcpToolClient, load_mcp_client_config +from pam_deploy_graph.mcp_runner import McpActionRunner + +config = load_mcp_client_config("mcp_client.json") +client = SessionMcpToolClient(session) # session 是你接入真实 MCP 后得到的 SDK session +runner = McpActionRunner(client=client, tool_names=config.tool_names or None) +agent = PamDeployAgent(mcp_runner=runner) +``` + +`mcp_client.json` 示例: + +```json +{ + "server_name": "pam-node-prod", + "tool_names": { + "get-online-ips": "pam_get_online_ips", + "create-download-task": "pam_create_download_task", + "poll-download-progress": "pam_poll_download_progress", + "upgrade-ip": "pam_upgrade_ip", + "poll-upgrade-progress": "pam_poll_upgrade_progress", + "start-ip": "pam_start_ip", + "stop-ip": "pam_stop_ip", + "verify-ip": "pam_verify_ip", + "download-log": "pam_download_log", + "rollback-ip": "pam_rollback_ip" + } +} +``` + +如果不传 `tool_names`,`McpActionRunner` 会使用上面的默认 action -> tool 映射。 + ## 使用方式 +交互式对话框: + +```bash +python -m pam_deploy_graph.cli chat --config doc_scripts/config.txt.example --strategy fake --checkpoint runtime/checkpoints/chat-demo.json +``` + +启动后可输入自然语言需求或会话命令: + +```text +PAM> 请用 MCP 预演部署 HET PAM Node 版本 2.0.5,不要动环境 +PAM> preview +PAM> set VERSION_NUMBER=2.0.6 +PAM> run +即将执行真实 action;确认执行请输入 yes: yes +PAM> status +PAM> approve +PAM> resume +PAM> exit +``` + +`chat` 默认仍要求在会话内显式输入 `run` 和 `yes` 才会执行 action;如果某个 IP 失败,会提示输入 `approve` 或 `reject [原因]`。`chat` 也支持 `--llm-base-url` / `--llm-api-key` / `--llm-model`,配置方式和 `analyze` 一致。 + 预演: ```bash @@ -86,9 +175,24 @@ python -m pam_deploy_graph.cli run-global --config doc_scripts/config.txt.exampl fake 完整部署流程验证: ```bash -python -m pam_deploy_graph.cli run-deploy --config doc_scripts/config.txt.example --strategy fake --confirm +python -m pam_deploy_graph.cli run-deploy --config doc_scripts/config.txt.example --strategy fake --checkpoint runtime/checkpoints/demo.json --confirm ``` +如果某个 IP 失败并进入待回滚确认,先查看输出中的 `confirmation`,再人工决定: + +```bash +python -m pam_deploy_graph.cli confirm --checkpoint runtime/checkpoints/demo.json --decision approve --confirm +python -m pam_deploy_graph.cli resume --checkpoint runtime/checkpoints/demo.json --confirm +``` + +拒绝回滚: + +```bash +python -m pam_deploy_graph.cli confirm --checkpoint runtime/checkpoints/demo.json --decision reject --note "人工决定暂不回滚" --confirm +``` + +checkpoint 用于断点续跑,会保存完整运行状态和参数。为了支持真实续跑,Agent 写入 checkpoint 时不会脱敏参数;请把 checkpoint 放在受控目录中。如果不传 `--checkpoint`,流程仍可运行,但不能跨进程 `resume`。 + 结构化理解和计划生成: ```bash @@ -104,9 +208,6 @@ pytest -q ## 下一步建议 1. 接入真实 PAM_NODE MCP session,并用 `SessionMcpToolClient` 包装。 -2. 用 fake runner 补齐完整部署主流程和单 IP 子流程测试。 -3. 引入 LangGraph,把当前 Agent 节点接入 `StateGraph`。 -4. 增加人工确认节点:参数确认、IP 范围确认、回滚确认。 -5. 接入真实 LLM 服务,实现 `RuleBasedLlmClient` 同协议替换。 -6. 完善 checkpoint 恢复:全局步骤跳过、成功 IP 跳过、pending rollback 恢复。 -7. 在测试环境中做 smoke:HOME 脚本 `get-token/get-node-url` + NODE MCP `get-online-ips`。 +2. 在测试环境中做 smoke:HOME 脚本 `get-token/get-node-url` + NODE MCP `get-online-ips`。 +3. 把当前 checkpoint/confirmation 语义继续接入 LangGraph interrupt/checkpointer。 +4. 继续细化参数确认、IP 范围确认的交互式 UI 或上层编排。 diff --git a/pam_deploy_graph/agent.py b/pam_deploy_graph/agent.py index b89912a..7bbb758 100644 --- a/pam_deploy_graph/agent.py +++ b/pam_deploy_graph/agent.py @@ -11,10 +11,11 @@ from pathlib import Path from typing import Any from .action_router import ActionRouter, build_action_backends +from .checkpoint_store import save_checkpoint from .config_writer import write_config from .constants import DEFAULT_PARAMS, GLOBAL_ACTION_SEQUENCE, IP_ACTION_SEQUENCE, REQUIRED_PARAMS from .fake_runner import FakeActionRunner -from .llm import RuleBasedLlmClient, validate_deploy_plan, validate_intent_result +from .llm import LlmClient, RuleBasedLlmClient, validate_deploy_plan, validate_intent_result from .mcp_runner import McpActionRunner from .models import AgentState, ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult from .script_runner import ScriptActionRunner, select_script_entry @@ -29,7 +30,7 @@ class PamDeployAgent: script_base_dir: str | Path = "doc_scripts", mcp_runner: McpActionRunner | None = None, fake_runner: FakeActionRunner | None = None, - llm_client: RuleBasedLlmClient | None = None, + llm_client: LlmClient | None = None, ) -> None: self.skill_policy = load_skill_policy(skill_path) self.script_base_dir = Path(script_base_dir) @@ -98,6 +99,7 @@ class PamDeployAgent: script_entry: str | None = None, config_path: str | None = None, trace_file_path: str | None = None, + checkpoint_path: str | None = None, target_ips: list[str] | None = None, ) -> AgentState: normalized = self.normalize_params(params) @@ -116,6 +118,7 @@ class PamDeployAgent: script_base_dir=str(self.script_base_dir), config_path=actual_config_path, trace_file_path=actual_trace_path, + checkpoint_path=checkpoint_path or "", target_ips=target_ips or [], ) @@ -151,6 +154,8 @@ class PamDeployAgent: def run_global_flow(self, state: AgentState) -> AgentState: for action in GLOBAL_ACTION_SEQUENCE: + if action in state.completed_global_steps: + continue kwargs: dict[str, Any] = {} if action == "publish-version": kwargs["hash_code"] = state.hash_code @@ -165,33 +170,54 @@ class PamDeployAgent: ) if not result.ok: state.last_failed_step = action + self._save_checkpoint(state) raise RuntimeError(f"{action} failed: {result.error_summary}") self._apply_result(state, action, result.values) state.completed_global_steps.append(action) state.last_success_step = action + self._save_checkpoint(state) return state def run_deploy_flow(self, state: AgentState) -> AgentState: + if state.pending_confirmation: + self._save_checkpoint(state) + return state self.run_global_flow(state) self.run_ip_flow(state) return state def run_ip_flow(self, state: AgentState) -> AgentState: + if state.pending_confirmation: + self._save_checkpoint(state) + return state self._resolve_target_ips(state) for ip in state.target_ips: - state.events.append({"type": "IP_START", "ip": ip, "message": "start"}) - ip_state = { - "status": "RUNNING", - "completed_steps": [], - "failed_stage": "", - "failure_reason": "", - "rollback_status": "ROLLBACK_NOT_RUN", - "rollback_stop_first": False, - "log_file": "", - } - state.ip_states[ip] = ip_state + ip_state = state.ip_states.get(ip) + if ip_state and ip_state.get("status") == "SUCCESS": + continue + if ip_state and ip_state.get("status") == "FAILED": + if ip_state.get("rollback_status") == "PENDING_AGENT_CONFIRMATION": + state.pending_confirmation = f"rollback-ip:{ip}" + self._save_checkpoint(state) + return state + continue + if not ip_state: + state.events.append({"type": "IP_START", "ip": ip, "message": "start"}) + ip_state = { + "status": "RUNNING", + "completed_steps": [], + "failed_stage": "", + "failure_reason": "", + "rollback_status": "ROLLBACK_NOT_RUN", + "rollback_stop_first": False, + "log_file": "", + } + state.ip_states[ip] = ip_state for action in IP_ACTION_SEQUENCE: + completed_steps = ip_state.setdefault("completed_steps", []) + if action in completed_steps: + continue result = self.router.run_action(state, action, ip=ip) failed = (not result.ok) or self._business_failed(action, result.values) state.events.append( @@ -209,13 +235,85 @@ class PamDeployAgent: if action != "download-log": self._download_log_best_effort(state, ip) state.pending_confirmation = f"rollback-ip:{ip}" + self._save_checkpoint(state) return state self._apply_ip_result(ip_state, action, result.values) - ip_state["completed_steps"].append(action) + completed_steps.append(action) + self._save_checkpoint(state) ip_state["status"] = "SUCCESS" state.events.append({"type": "IP_DONE", "ip": ip, "message": "success"}) + self._save_checkpoint(state) + return state + + def build_confirmation_request(self, state: AgentState) -> dict[str, Any]: + if not state.pending_confirmation: + return {} + kind, _, value = state.pending_confirmation.partition(":") + if kind == "rollback-ip": + ip_state = state.ip_states.get(value, {}) + return { + "type": "rollback-ip", + "ip": value, + "failed_stage": ip_state.get("failed_stage", ""), + "failure_reason": ip_state.get("failure_reason", ""), + "rollback_stop_first": bool(ip_state.get("rollback_stop_first", False)), + "allowed_decisions": ["approve", "reject"], + } + return { + "type": kind, + "value": value, + "allowed_decisions": ["approve", "reject"], + } + + def confirm_pending(self, state: AgentState, *, approved: bool, operator_note: str = "") -> AgentState: + request = self.build_confirmation_request(state) + if not request: + raise ValueError("No pending confirmation") + if request["type"] != "rollback-ip": + raise ValueError(f"Unsupported confirmation type: {request['type']}") + + ip = request["ip"] + ip_state = state.ip_states[ip] + if not approved: + ip_state["rollback_status"] = "REJECTED_BY_OPERATOR" + state.events.append( + { + "type": "CONFIRMATION_REJECTED", + "stage": "rollback-ip", + "ip": ip, + "message": operator_note or "rollback rejected by operator", + } + ) + state.pending_confirmation = "" + self._save_checkpoint(state) + return state + + result = self.router.run_action( + state, + "rollback-ip", + ip=ip, + stop_first=bool(ip_state.get("rollback_stop_first", False)), + ) + ip_state["rollback_status"] = "ROLLBACK_DONE" if result.ok else "ROLLBACK_FAILED" + state.events.append( + { + "type": "ACTION_DONE" if result.ok else "ACTION_FAIL", + "stage": "rollback-ip", + "backend": result.backend, + "ip": ip, + "message": result.error_summary or result.values.get("MESSAGE", "ok"), + } + ) + if result.ok: + state.pending_confirmation = "" + state.last_success_step = "rollback-ip" + state.last_failed_step = "" + else: + state.pending_confirmation = f"rollback-ip:{ip}" + state.last_failed_step = "rollback-ip" + self._save_checkpoint(state) return state def _apply_result(self, state: AgentState, action: str, values: dict[str, Any]) -> None: @@ -308,6 +406,10 @@ class PamDeployAgent: } ) + def _save_checkpoint(self, state: AgentState) -> None: + if state.checkpoint_path: + save_checkpoint(state, state.checkpoint_path, redact=False) + def render_report(self, state: AgentState) -> str: success = sum(1 for item in state.ip_states.values() if item.get("status") == "SUCCESS") failed = sum(1 for item in state.ip_states.values() if item.get("status") == "FAILED") diff --git a/pam_deploy_graph/checkpoint_store.py b/pam_deploy_graph/checkpoint_store.py index aa540a0..a010c79 100644 --- a/pam_deploy_graph/checkpoint_store.py +++ b/pam_deploy_graph/checkpoint_store.py @@ -3,11 +3,12 @@ from __future__ import annotations import json -from dataclasses import asdict, is_dataclass +from dataclasses import asdict, fields, is_dataclass from pathlib import Path from typing import Any from .constants import SENSITIVE_KEYS +from .models import AgentState def redact_mapping(value: Any) -> Any: @@ -24,12 +25,14 @@ def redact_mapping(value: Any) -> Any: return value -def save_checkpoint(state: Any, path: str | Path) -> Path: +def save_checkpoint(state: Any, path: str | Path, *, redact: bool = True) -> Path: checkpoint_path = Path(path) checkpoint_path.parent.mkdir(parents=True, exist_ok=True) payload = asdict(state) if is_dataclass(state) else state + if redact: + payload = redact_mapping(payload) checkpoint_path.write_text( - json.dumps(redact_mapping(payload), ensure_ascii=False, indent=2), + json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8", ) return checkpoint_path @@ -38,3 +41,12 @@ def save_checkpoint(state: Any, path: str | Path) -> Path: def load_checkpoint(path: str | Path) -> dict[str, Any]: return json.loads(Path(path).read_text(encoding="utf-8")) + +def agent_state_from_mapping(payload: dict[str, Any]) -> AgentState: + allowed_fields = {item.name for item in fields(AgentState)} + state_payload = {key: value for key, value in payload.items() if key in allowed_fields} + return AgentState(**state_payload) + + +def load_agent_state(path: str | Path) -> AgentState: + return agent_state_from_mapping(load_checkpoint(path)) diff --git a/pam_deploy_graph/cli.py b/pam_deploy_graph/cli.py index 2c807f9..9ea22a2 100644 --- a/pam_deploy_graph/cli.py +++ b/pam_deploy_graph/cli.py @@ -7,10 +7,30 @@ import json from dataclasses import asdict from .agent import PamDeployAgent -from .checkpoint_store import redact_mapping +from .checkpoint_store import load_agent_state, redact_mapping +from .interactive import run_interactive_chat +from .llm import build_llm_client from .params_loader import load_params_file +def add_llm_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--llm-base-url") + parser.add_argument("--llm-api-key") + parser.add_argument("--llm-model") + + +def require_confirm(args: argparse.Namespace) -> None: + if not getattr(args, "confirm", False): + raise SystemExit("Refusing to execute actions without --confirm.") + + +def print_pause_payload(agent: PamDeployAgent, state) -> None: + if state.pending_confirmation: + print(json.dumps({"confirmation": agent.build_confirmation_request(state)}, ensure_ascii=False, indent=2)) + if state.checkpoint_path: + print(json.dumps({"checkpoint": state.checkpoint_path}, ensure_ascii=False, indent=2)) + + def main() -> None: parser = argparse.ArgumentParser(prog="pam-deploy-agent") sub = parser.add_subparsers(dest="command", required=True) @@ -22,21 +42,48 @@ def main() -> None: analyze = sub.add_parser("analyze") analyze.add_argument("--text", required=True) analyze.add_argument("--config") + add_llm_args(analyze) + + chat = sub.add_parser("chat") + chat.add_argument("--config", required=True) + chat.add_argument("--strategy", default="fake", choices=["hybrid_node_mcp", "script_only", "fake"]) + chat.add_argument("--target-ip", action="append", default=[]) + chat.add_argument("--checkpoint") + add_llm_args(chat) run = sub.add_parser("run-global") run.add_argument("--config", required=True) run.add_argument("--strategy", default="fake", choices=["hybrid_node_mcp", "script_only", "fake"]) + run.add_argument("--checkpoint") run.add_argument("--confirm", action="store_true") deploy = sub.add_parser("run-deploy") deploy.add_argument("--config", required=True) deploy.add_argument("--strategy", default="fake", choices=["hybrid_node_mcp", "script_only", "fake"]) deploy.add_argument("--target-ip", action="append", default=[]) + deploy.add_argument("--checkpoint") deploy.add_argument("--confirm", action="store_true") + resume = sub.add_parser("resume") + resume.add_argument("--checkpoint", required=True) + resume.add_argument("--confirm", action="store_true") + + confirm = sub.add_parser("confirm") + confirm.add_argument("--checkpoint", required=True) + confirm.add_argument("--decision", required=True, choices=["approve", "reject"]) + confirm.add_argument("--note", default="") + confirm.add_argument("--confirm", action="store_true") + args = parser.parse_args() params = load_params_file(args.config) if getattr(args, "config", None) else {} - agent = PamDeployAgent() + llm_client = None + if args.command in ("analyze", "chat"): + llm_client = build_llm_client( + base_url=args.llm_base_url, + api_key=args.llm_api_key, + model=args.llm_model, + ) + agent = PamDeployAgent(llm_client=llm_client) if args.command == "analyze": result = agent.analyze_request(args.text, params) @@ -44,25 +91,61 @@ def main() -> None: print(json.dumps(payload, ensure_ascii=False, indent=2)) return + if args.command == "chat": + run_interactive_chat( + agent=agent, + params=params, + strategy=args.strategy, + checkpoint_path=args.checkpoint, + target_ips=args.target_ip, + ) + return + if args.command == "preview": print(agent.preview(params, args.strategy)) return - if not args.confirm: - raise SystemExit("Refusing to execute actions without --confirm.") + require_confirm(args) if args.command == "run-global": - state = agent.create_state(params=params, execution_strategy=args.strategy) + state = agent.create_state( + params=params, + execution_strategy=args.strategy, + checkpoint_path=args.checkpoint, + ) state = agent.run_global_flow(state) print(json.dumps({"events": state.events}, ensure_ascii=False, indent=2)) + print_pause_payload(agent, state) + return + + if args.command == "resume": + state = load_agent_state(args.checkpoint) + state.checkpoint_path = state.checkpoint_path or args.checkpoint + state = agent.run_deploy_flow(state) + print(agent.render_report(state)) + print_pause_payload(agent, state) + return + + if args.command == "confirm": + state = load_agent_state(args.checkpoint) + state.checkpoint_path = state.checkpoint_path or args.checkpoint + state = agent.confirm_pending( + state, + approved=args.decision == "approve", + operator_note=args.note, + ) + print(agent.render_report(state)) + print_pause_payload(agent, state) return state = agent.create_state( params=params, execution_strategy=args.strategy, + checkpoint_path=args.checkpoint, target_ips=args.target_ip, ) state = agent.run_deploy_flow(state) print(agent.render_report(state)) + print_pause_payload(agent, state) if __name__ == "__main__": diff --git a/pam_deploy_graph/interactive.py b/pam_deploy_graph/interactive.py new file mode 100644 index 0000000..847e3f5 --- /dev/null +++ b/pam_deploy_graph/interactive.py @@ -0,0 +1,290 @@ +"""Interactive CLI session for the PAM deploy agent.""" + +from __future__ import annotations + +import time +from dataclasses import asdict +from pathlib import Path +from typing import Any, Callable + +from .agent import PamDeployAgent +from .checkpoint_store import load_agent_state, redact_mapping +from .models import AgentState, ExecutionStrategy + +InputFunc = Callable[[str], str] +OutputFunc = Callable[[str], None] + +COMMAND_HELP = """可用命令: + help 显示帮助 + preview 查看当前参数和执行策略 + analyze <需求> 只做理解和计划,不执行 + set KEY=VALUE 修改当前会话参数 + run 创建部署任务并执行 + status 查看当前运行状态 + approve 确认待处理回滚 + reject [原因] 拒绝待处理回滚 + resume 从当前 checkpoint 续跑 + checkpoint 显示 checkpoint 路径 + exit 退出 + +也可以直接输入自然语言需求,Agent 会先分析并更新会话参数;执行仍需输入 run。 +""" + + +class InteractiveCliSession: + def __init__( + self, + *, + agent: PamDeployAgent, + params: dict[str, Any], + strategy: ExecutionStrategy = "hybrid_node_mcp", + checkpoint_path: str | None = None, + target_ips: list[str] | None = None, + input_func: InputFunc = input, + output_func: OutputFunc = print, + ) -> None: + self.agent = agent + self.params = dict(params) + self.strategy = strategy + self.checkpoint_path = checkpoint_path or _default_checkpoint_path() + self.target_ips = list(target_ips or []) + self.input = input_func + self.output = output_func + self.state: AgentState | None = None + self.last_analysis: dict[str, Any] | None = None + + def run(self) -> None: + self.output("PAM Deploy Agent interactive session") + self.output("输入 help 查看命令,输入 exit 退出。") + self._load_existing_checkpoint_if_any() + while True: + try: + line = self.input("PAM> ") + except EOFError: + self.output("bye") + return + if not self.handle_line(line): + return + + def handle_line(self, line: str) -> bool: + text = line.strip() + if not text: + return True + + command, _, rest = text.partition(" ") + normalized = command.lower() + + if normalized in ("exit", "quit", "q"): + self.output("bye") + return False + if normalized in ("help", "?"): + self.output(COMMAND_HELP.rstrip()) + return True + if normalized == "preview": + self.output(self.agent.preview(self.params, self.strategy)) + return True + if normalized == "analyze": + self._analyze(rest.strip()) + return True + if normalized == "set": + self._set_param(rest.strip()) + return True + if normalized in ("run", "deploy", "execute"): + self._run_deploy() + return True + if normalized == "resume": + self._resume() + return True + if normalized == "status": + self._status() + return True + if normalized == "approve": + self._confirm(approved=True, note=rest.strip()) + return True + if normalized == "reject": + self._confirm(approved=False, note=rest.strip()) + return True + if normalized == "checkpoint": + self.output(f"checkpoint: {self.checkpoint_path}") + return True + + self._analyze(text) + return True + + def _analyze(self, text: str) -> None: + if not text: + self.output("请输入要分析的自然语言需求,例如:analyze 请用 MCP 预演部署 HET。") + return + + result = self.agent.analyze_request(text, self.params) + self.last_analysis = result + param_result = result["params"] + intent_result = result["intent"] + plan = result["plan"] + self.params = dict(param_result.extracted_params) + self.strategy = _choose_strategy(intent_result.strategy_preference, self.strategy) + + user_ips = param_result.extracted_control.get("user_specified_ips") + if isinstance(user_ips, list): + self.target_ips = [str(item) for item in user_ips] + + safe_payload = redact_mapping({key: asdict(value) for key, value in result.items()}) + self.output("已生成结构化理解:") + self.output(f"- intent: {intent_result.intent}") + self.output(f"- strategy: {self.strategy}") + self.output(f"- summary: {plan.summary}") + if param_result.missing_required_params: + self.output("- missing: " + ", ".join(param_result.missing_required_params)) + if self.target_ips: + self.output("- target_ips: " + ", ".join(self.target_ips)) + self.output("执行请输 run;查看完整 JSON 可用一次性 analyze 命令。") + self.output(_format_redacted_params(safe_payload["params"]["extracted_params"])) + + def _set_param(self, assignment: str) -> None: + if "=" not in assignment: + self.output("格式:set KEY=VALUE") + return + key, value = assignment.split("=", 1) + key = key.strip() + if not key: + self.output("参数名不能为空。") + return + self.params[key] = value.strip() + self.output(f"已设置 {key}") + + def _run_deploy(self) -> None: + if self.state and self.state.pending_confirmation: + self._print_confirmation() + return + + if not self._ask_yes_no("即将执行真实 action;确认执行请输入 yes: "): + self.output("已取消执行。") + return + + self.state = self.agent.create_state( + params=self.params, + execution_strategy=self.strategy, + checkpoint_path=self.checkpoint_path, + target_ips=self.target_ips, + ) + self._execute_current_state() + + def _resume(self) -> None: + if self.state is None: + checkpoint = Path(self.checkpoint_path) + if not checkpoint.exists(): + self.output("当前没有可续跑的 checkpoint。") + return + self.state = load_agent_state(checkpoint) + self.state.checkpoint_path = self.state.checkpoint_path or str(checkpoint) + self._execute_current_state() + + def _execute_current_state(self) -> None: + if self.state is None: + self.output("当前没有运行状态。") + return + self.state = self.agent.run_deploy_flow(self.state) + self.output(self.agent.render_report(self.state)) + if self.state.pending_confirmation: + self._print_confirmation() + self.output(f"checkpoint: {self.state.checkpoint_path or self.checkpoint_path}") + + def _status(self) -> None: + if self.state is None: + self.output("当前还没有运行状态。") + self.output(f"checkpoint: {self.checkpoint_path}") + return + self.output(self.agent.render_report(self.state)) + if self.state.pending_confirmation: + self._print_confirmation() + + def _confirm(self, *, approved: bool, note: str = "") -> None: + if self.state is None: + checkpoint = Path(self.checkpoint_path) + if checkpoint.exists(): + self.state = load_agent_state(checkpoint) + self.state.checkpoint_path = self.state.checkpoint_path or str(checkpoint) + else: + self.output("当前没有待确认任务。") + return + if not self.state.pending_confirmation: + self.output("当前没有待确认任务。") + return + + self.state = self.agent.confirm_pending(self.state, approved=approved, operator_note=note) + self.output(self.agent.render_report(self.state)) + if self.state.pending_confirmation: + self._print_confirmation() + + def _print_confirmation(self) -> None: + if self.state is None: + return + request = self.agent.build_confirmation_request(self.state) + if not request: + return + self.output("需要人工确认:") + self.output(f"- type: {request.get('type')}") + if request.get("ip"): + self.output(f"- ip: {request['ip']}") + if request.get("failed_stage"): + self.output(f"- failed_stage: {request['failed_stage']}") + if request.get("failure_reason"): + self.output(f"- reason: {request['failure_reason']}") + self.output("输入 approve 执行回滚,或 reject [原因] 拒绝回滚。") + + def _ask_yes_no(self, prompt: str) -> bool: + try: + answer = self.input(prompt).strip().lower() + except EOFError: + return False + return answer in ("yes", "y") + + def _load_existing_checkpoint_if_any(self) -> None: + checkpoint = Path(self.checkpoint_path) + if not checkpoint.exists(): + return + self.state = load_agent_state(checkpoint) + self.state.checkpoint_path = self.state.checkpoint_path or str(checkpoint) + self.output(f"已加载 checkpoint: {checkpoint}") + if self.state.pending_confirmation: + self._print_confirmation() + + +def run_interactive_chat( + *, + agent: PamDeployAgent, + params: dict[str, Any], + strategy: ExecutionStrategy, + checkpoint_path: str | None = None, + target_ips: list[str] | None = None, + input_func: InputFunc = input, + output_func: OutputFunc = print, +) -> InteractiveCliSession: + session = InteractiveCliSession( + agent=agent, + params=params, + strategy=strategy, + checkpoint_path=checkpoint_path, + target_ips=target_ips, + input_func=input_func, + output_func=output_func, + ) + session.run() + return session + + +def _default_checkpoint_path() -> str: + return str(Path("runtime") / "checkpoints" / f"chat_{time.strftime('%Y%m%d_%H%M%S')}.json") + + +def _choose_strategy(preference: str, default: ExecutionStrategy) -> ExecutionStrategy: + if preference in ("hybrid_node_mcp", "script_only", "fake"): + return preference # type: ignore[return-value] + return default + + +def _format_redacted_params(params: dict[str, Any]) -> str: + lines = ["当前参数:"] + for key in sorted(params): + lines.append(f"- {key}: {params[key]}") + return "\n".join(lines) diff --git a/pam_deploy_graph/llm/__init__.py b/pam_deploy_graph/llm/__init__.py index ef479b7..7e70597 100644 --- a/pam_deploy_graph/llm/__init__.py +++ b/pam_deploy_graph/llm/__init__.py @@ -1,7 +1,16 @@ """LLM integration surfaces for PAM deploy Agent.""" +from .base import LlmClient +from .factory import build_llm_client +from .openai_compatible import OpenAICompatibleLlmClient from .rule_based import RuleBasedLlmClient from .validators import validate_deploy_plan, validate_intent_result -__all__ = ["RuleBasedLlmClient", "validate_deploy_plan", "validate_intent_result"] - +__all__ = [ + "LlmClient", + "OpenAICompatibleLlmClient", + "RuleBasedLlmClient", + "build_llm_client", + "validate_deploy_plan", + "validate_intent_result", +] diff --git a/pam_deploy_graph/llm/base.py b/pam_deploy_graph/llm/base.py new file mode 100644 index 0000000..2716115 --- /dev/null +++ b/pam_deploy_graph/llm/base.py @@ -0,0 +1,24 @@ +"""Shared LLM client protocol.""" + +from __future__ import annotations + +from typing import Any, Protocol + +from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult + + +class LlmClient(Protocol): + def understand_request(self, text: str) -> LlmIntentResult: + ... + + def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult: + ... + + def generate_plan( + self, + *, + params: dict[str, Any], + intent: str, + strategy: ExecutionStrategy, + ) -> LlmDeployPlan: + ... diff --git a/pam_deploy_graph/llm/factory.py b/pam_deploy_graph/llm/factory.py new file mode 100644 index 0000000..7654943 --- /dev/null +++ b/pam_deploy_graph/llm/factory.py @@ -0,0 +1,39 @@ +"""LLM client factory for CLI and embedding code.""" + +from __future__ import annotations + +import os + +from .base import LlmClient +from .openai_compatible import OpenAICompatibleLlmClient +from .rule_based import RuleBasedLlmClient + + +def build_llm_client( + *, + base_url: str | None = None, + api_key: str | None = None, + model: str | None = None, +) -> LlmClient: + actual_base_url = base_url or os.getenv("PAM_LLM_BASE_URL", "") + actual_api_key = api_key or os.getenv("PAM_LLM_API_KEY", "") + actual_model = model or os.getenv("PAM_LLM_MODEL", "") + + if not actual_base_url and not actual_api_key and not actual_model: + return RuleBasedLlmClient() + + missing = [] + if not actual_base_url: + missing.append("base_url") + if not actual_api_key: + missing.append("api_key") + if not actual_model: + missing.append("model") + if missing: + raise ValueError(f"Incomplete LLM config: missing {', '.join(missing)}") + + return OpenAICompatibleLlmClient( + base_url=actual_base_url, + api_key=actual_api_key, + model=actual_model, + ) diff --git a/pam_deploy_graph/llm/openai_compatible.py b/pam_deploy_graph/llm/openai_compatible.py new file mode 100644 index 0000000..fae48de --- /dev/null +++ b/pam_deploy_graph/llm/openai_compatible.py @@ -0,0 +1,242 @@ +"""OpenAI-compatible HTTP LLM client. + +The client targets providers exposing a `/chat/completions` endpoint with +OpenAI-style request and response shapes. It intentionally uses only the Python +standard library so the runtime can stay dependency-light. +""" + +from __future__ import annotations + +import json +import urllib.request +from collections.abc import Callable +from typing import Any + +from pam_deploy_graph.constants import ( + ALLOWED_ACTIONS, + DEFAULT_PARAMS, + GLOBAL_ACTION_SEQUENCE, + IP_ACTION_SEQUENCE, + REQUIRED_PARAMS, + SENSITIVE_KEYS, +) +from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult + +from .prompts import INTENT_PROMPT, PARAM_PROMPT, PLAN_PROMPT, SYSTEM_PROMPT + +JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]] + + +class OpenAICompatibleLlmClient: + def __init__( + self, + *, + base_url: str, + api_key: str, + model: str, + timeout_sec: float = 30, + temperature: float = 0, + transport: JsonTransport | None = None, + ) -> None: + if not base_url: + raise ValueError("LLM base_url is required") + if not api_key: + raise ValueError("LLM api_key is required") + if not model: + raise ValueError("LLM model is required") + self.base_url = base_url.rstrip("/") + self.api_key = api_key + self.model = model + self.timeout_sec = timeout_sec + self.temperature = temperature + self.transport = transport or _default_transport + + def understand_request(self, text: str) -> LlmIntentResult: + payload = self._complete_json(INTENT_PROMPT, {"user_text": text}) + return LlmIntentResult( + intent=_string(payload, "intent", "deploy"), # type: ignore[arg-type] + mode_preference=_string(payload, "mode_preference", "未指定"), # type: ignore[arg-type] + strategy_preference=_string(payload, "strategy_preference", "未指定"), # type: ignore[arg-type] + confidence=_float(payload, "confidence", 0.0), + reasons=_string_list(payload.get("reasons")), + needs_clarification=bool(payload.get("needs_clarification", False)), + clarification_questions=_string_list(payload.get("clarification_questions")), + ) + + def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult: + original_base = dict(base_params or {}) + safe_base = _redact_sensitive(original_base) + payload = self._complete_json( + PARAM_PROMPT, + { + "user_text": text, + "base_params": safe_base, + "required_params": list(REQUIRED_PARAMS), + "default_params": DEFAULT_PARAMS, + }, + ) + + extracted = _dict(payload.get("extracted_params")) + merged = original_base.copy() + for key, value in extracted.items(): + if key in SENSITIVE_KEYS and value == "***": + continue + merged[key] = value + + control = _dict(payload.get("extracted_control")) + missing = [key for key in REQUIRED_PARAMS if not merged.get(key)] + sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)] + llm_ambiguous = _string_list(payload.get("ambiguous_fields")) + return LlmParamResult( + extracted_params=merged, + extracted_control=control, + missing_required_params=missing, + ambiguous_fields=llm_ambiguous, + sensitive_fields_present=sensitive, + ) + + def generate_plan( + self, + *, + params: dict[str, Any], + intent: str, + strategy: ExecutionStrategy, + ) -> LlmDeployPlan: + payload = self._complete_json( + PLAN_PROMPT, + { + "params": _redact_sensitive(params), + "intent": intent, + "execution_strategy": strategy, + "allowed_actions": list(ALLOWED_ACTIONS), + "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE), + "ip_action_sequence": list(IP_ACTION_SEQUENCE), + }, + ) + planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE) + return LlmDeployPlan( + summary=_string(payload, "summary", "PAM deployment plan"), + risk_notes=_string_list(payload.get("risk_notes")), + planned_actions=planned_actions, + requires_confirmation=bool(payload.get("requires_confirmation", True)), + execution_strategy=_string(payload, "execution_strategy", strategy), # type: ignore[arg-type] + ) + + def _complete_json(self, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]: + request_payload = { + "model": self.model, + "temperature": self.temperature, + "response_format": {"type": "json_object"}, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": instruction + + "\n\n输入 JSON:\n" + + json.dumps(input_payload, ensure_ascii=False, sort_keys=True), + }, + ], + } + response = self.transport( + _chat_completions_url(self.base_url), + { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + request_payload, + self.timeout_sec, + ) + content = _message_content(response) + parsed = _loads_json_object(content) + if not isinstance(parsed, dict): + raise ValueError("LLM response must be a JSON object") + return parsed + + +def _default_transport( + url: str, + headers: dict[str, str], + payload: dict[str, Any], + timeout_sec: float, +) -> dict[str, Any]: + request = urllib.request.Request( + url, + data=json.dumps(payload).encode("utf-8"), + headers=headers, + method="POST", + ) + with urllib.request.urlopen(request, timeout=timeout_sec) as response: + raw = response.read().decode("utf-8") + decoded = json.loads(raw) + if not isinstance(decoded, dict): + raise ValueError("LLM HTTP response must be a JSON object") + return decoded + + +def _chat_completions_url(base_url: str) -> str: + clean = base_url.rstrip("/") + if clean.endswith("/chat/completions"): + return clean + return f"{clean}/chat/completions" + + +def _message_content(response: dict[str, Any]) -> Any: + try: + content = response["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError) as exc: + raise ValueError("LLM response does not contain choices[0].message.content") from exc + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif isinstance(item, str): + parts.append(item) + return "".join(parts) + return content + + +def _loads_json_object(content: Any) -> Any: + if isinstance(content, dict): + return content + if not isinstance(content, str): + raise ValueError("LLM message content must be JSON text") + return json.loads(content) + + +def _redact_sensitive(value: Any) -> Any: + if isinstance(value, dict): + redacted: dict[str, Any] = {} + for key, item in value.items(): + if str(key) in SENSITIVE_KEYS: + redacted[str(key)] = "***" + else: + redacted[str(key)] = _redact_sensitive(item) + return redacted + if isinstance(value, list): + return [_redact_sensitive(item) for item in value] + return value + + +def _string(payload: dict[str, Any], key: str, default: str) -> str: + value = payload.get(key, default) + return str(value) if value is not None else default + + +def _float(payload: dict[str, Any], key: str, default: float) -> float: + try: + return float(payload.get(key, default)) + except (TypeError, ValueError): + return default + + +def _dict(value: Any) -> dict[str, Any]: + return value if isinstance(value, dict) else {} + + +def _string_list(value: Any) -> list[str]: + if isinstance(value, list): + return [str(item) for item in value] + if value in (None, ""): + return [] + return [str(value)] diff --git a/pam_deploy_graph/llm/prompts.py b/pam_deploy_graph/llm/prompts.py new file mode 100644 index 0000000..45a594c --- /dev/null +++ b/pam_deploy_graph/llm/prompts.py @@ -0,0 +1,67 @@ +"""Prompts for structured PAM deployment planning.""" + +SYSTEM_PROMPT = """你是 PAM 智能部署 Agent 的结构化理解与规划组件。 + +必须遵守: +- 只输出一个 JSON 对象,不输出 Markdown、解释文字或代码块。 +- 不生成 shell、PowerShell、bat、curl 等可执行命令。 +- 不回显密钥、token、CLIENT_SECRET、Authorization 等敏感值。 +- 只能在允许的 action 集合中选择部署动作。 +- 真实执行前必须保留人工确认点:参数确认、目标 IP 范围确认、失败回滚确认。 +""" + +INTENT_PROMPT = """根据用户输入识别意图和执行偏好。 + +输出 JSON schema: +{ + "intent": "deploy|show_usage|preview|query_node_ips|rollback", + "mode_preference": "MCP|API脚本|未指定", + "strategy_preference": "hybrid_node_mcp|script_only|fake|未指定", + "confidence": 0.0, + "reasons": ["..."], + "needs_clarification": false, + "clarification_questions": ["..."] +} +""" + +PARAM_PROMPT = """从用户输入中抽取 PAM 部署参数和控制信息。 + +输出 JSON schema: +{ + "extracted_params": { + "HOME_BASE_URL": "...", + "CLIENT_ID": "...", + "AIRPORT_CODE": "...", + "APP_NAME": "...", + "MODULE_NAME": "...", + "VERSION_NUMBER": "...", + "ZIP_FILE_PATH": "...", + "ACTION_TYPE": "...", + "TIMEOUT": "...", + "LOG_NAME": "..." + }, + "extracted_control": { + "user_specified_ips": ["..."] + }, + "missing_required_params": ["..."], + "ambiguous_fields": ["..."], + "sensitive_fields_present": ["..."] +} + +不要输出或猜测 CLIENT_SECRET 的真实值;如果输入里出现敏感字段,只标记字段名。 +""" + +PLAN_PROMPT = """生成 PAM 部署计划。 + +输出 JSON schema: +{ + "summary": "...", + "risk_notes": ["..."], + "planned_actions": ["get-token", "create-version"], + "requires_confirmation": true, + "execution_strategy": "hybrid_node_mcp|script_only|fake|未指定" +} + +计划只能使用允许 action;不要包含可执行脚本命令、命令行参数或密钥。 +PAM_HOME action 仍由脚本 action 执行;PAM_NODE action 在 hybrid_node_mcp 策略下走 MCP。 +""" diff --git a/pam_deploy_graph/mcp_client.py b/pam_deploy_graph/mcp_client.py index 180bf39..5104e80 100644 --- a/pam_deploy_graph/mcp_client.py +++ b/pam_deploy_graph/mcp_client.py @@ -9,9 +9,36 @@ from __future__ import annotations import json from collections.abc import Callable +from dataclasses import dataclass, field +from pathlib import Path from typing import Any +@dataclass(frozen=True) +class McpClientConfig: + """Configuration needed after a real MCP session has been created.""" + + server_name: str = "pam-node" + tool_names: dict[str, str] = field(default_factory=dict) + + @classmethod + def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig": + tool_names = payload.get("tool_names") or payload.get("tools") or {} + if not isinstance(tool_names, dict): + raise ValueError("MCP tool_names must be an object") + return cls( + server_name=str(payload.get("server_name", "pam-node")), + tool_names={str(key): str(value) for key, value in tool_names.items()}, + ) + + +def load_mcp_client_config(path: str | Path) -> McpClientConfig: + payload = json.loads(Path(path).read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError("MCP client config must be a JSON object") + return McpClientConfig.from_mapping(payload) + + class FunctionMcpToolClient: """Wrap a plain Python callable as an MCP tool client.""" diff --git a/pam_deploy_graph/models.py b/pam_deploy_graph/models.py index 4e60be8..94c88bd 100644 --- a/pam_deploy_graph/models.py +++ b/pam_deploy_graph/models.py @@ -99,4 +99,5 @@ class AgentState: pending_confirmation: str = "" last_success_step: str = "" last_failed_step: str = "" + checkpoint_path: str = "" events: list[dict[str, Any]] = field(default_factory=list) diff --git a/tests/test_agent_flow.py b/tests/test_agent_flow.py index 1d0da52..a573b92 100644 --- a/tests/test_agent_flow.py +++ b/tests/test_agent_flow.py @@ -1,6 +1,8 @@ from pathlib import Path from pam_deploy_graph.agent import PamDeployAgent +from pam_deploy_graph.checkpoint_store import load_agent_state +from pam_deploy_graph.constants import GLOBAL_ACTION_SEQUENCE from pam_deploy_graph.fake_runner import FakeActionRunner @@ -56,3 +58,96 @@ def test_run_deploy_flow_stops_on_verify_failure(tmp_path: Path): assert state.ip_states["192.168.1.10"]["rollback_status"] == "PENDING_AGENT_CONFIRMATION" assert "192.168.1.11" not in state.ip_states assert any(event["type"] == "CONFIRMATION_REQUIRED" for event in state.events) + + +def test_confirm_pending_rollback_runs_rollback_and_resume_continues(tmp_path: Path): + fake = FakeActionRunner( + { + "verify-ip:192.168.1.10": { + "ACTION": "verify-ip", + "IP": "192.168.1.10", + "SUCCESS": "false", + "MESSAGE": "health check failed", + } + } + ) + agent = PamDeployAgent(fake_runner=fake) + state = agent.create_state( + params=PARAMS, + execution_strategy="fake", + config_path=str(tmp_path / "config.txt"), + ) + + agent.run_deploy_flow(state) + request = agent.build_confirmation_request(state) + agent.confirm_pending(state, approved=True) + agent.run_deploy_flow(state) + + assert request["type"] == "rollback-ip" + assert state.pending_confirmation == "" + assert state.ip_states["192.168.1.10"]["rollback_status"] == "ROLLBACK_DONE" + assert state.ip_states["192.168.1.11"]["status"] == "SUCCESS" + assert any(call[0] == "rollback-ip" for call in fake.calls) + + +def test_failed_rollback_keeps_confirmation_pending(tmp_path: Path): + fake = FakeActionRunner( + { + "verify-ip:192.168.1.10": { + "ACTION": "verify-ip", + "IP": "192.168.1.10", + "SUCCESS": "false", + "MESSAGE": "health check failed", + }, + "rollback-ip:192.168.1.10": { + "_fail": True, + "ACTION": "rollback-ip", + "IP": "192.168.1.10", + "MESSAGE": "rollback failed", + }, + } + ) + agent = PamDeployAgent(fake_runner=fake) + state = agent.create_state( + params=PARAMS, + execution_strategy="fake", + config_path=str(tmp_path / "config.txt"), + ) + + agent.run_deploy_flow(state) + agent.confirm_pending(state, approved=True) + + assert state.pending_confirmation == "rollback-ip:192.168.1.10" + assert state.ip_states["192.168.1.10"]["rollback_status"] == "ROLLBACK_FAILED" + + +def test_checkpoint_resume_skips_completed_global_and_success_ip(tmp_path: Path): + checkpoint = tmp_path / "checkpoint.json" + fake = FakeActionRunner() + agent = PamDeployAgent(fake_runner=fake) + state = agent.create_state( + params=PARAMS, + execution_strategy="fake", + config_path=str(tmp_path / "config.txt"), + checkpoint_path=str(checkpoint), + ) + state.completed_global_steps = list(GLOBAL_ACTION_SEQUENCE) + state.online_ips = ["192.168.1.10", "192.168.1.11"] + state.target_ips = ["192.168.1.10", "192.168.1.11"] + state.ip_states["192.168.1.10"] = { + "status": "SUCCESS", + "completed_steps": ["upgrade-ip", "poll-upgrade-progress", "start-ip", "verify-ip", "download-log"], + "failed_stage": "", + "failure_reason": "", + "rollback_status": "ROLLBACK_NOT_RUN", + "rollback_stop_first": False, + "log_file": "logs/fake.zip", + } + + agent.run_deploy_flow(state) + loaded = load_agent_state(checkpoint) + + called_actions = [call[0] for call in fake.calls] + assert "get-token" not in called_actions + assert all(call[1].get("ip") != "192.168.1.10" for call in fake.calls) + assert loaded.ip_states["192.168.1.11"]["status"] == "SUCCESS" diff --git a/tests/test_interactive_cli.py b/tests/test_interactive_cli.py new file mode 100644 index 0000000..092ffb5 --- /dev/null +++ b/tests/test_interactive_cli.py @@ -0,0 +1,84 @@ +from pathlib import Path + +from pam_deploy_graph.agent import PamDeployAgent +from pam_deploy_graph.fake_runner import FakeActionRunner +from pam_deploy_graph.interactive import InteractiveCliSession + + +PARAMS = { + "HOME_BASE_URL": "https://pam.home.example.com", + "CLIENT_ID": "client", + "CLIENT_SECRET": "secret", + "AIRPORT_CODE": "HET", + "APP_NAME": "PAM", + "MODULE_NAME": "Node", + "VERSION_NUMBER": "2.0.5", + "ZIP_FILE_PATH": "C:/pkg.zip", +} + + +def run_session(session: InteractiveCliSession, inputs: list[str]) -> list[str]: + output: list[str] = [] + iterator = iter(inputs) + session.input = lambda _prompt: next(iterator) + session.output = output.append + session.run() + return output + + +def test_chat_analyzes_natural_language_and_updates_context(tmp_path: Path): + session = InteractiveCliSession( + agent=PamDeployAgent(), + params=PARAMS, + strategy="fake", + checkpoint_path=str(tmp_path / "checkpoint.json"), + ) + + output = run_session(session, ["analyze please use MCP deploy 192.168.1.10", "exit"]) + + assert session.strategy == "hybrid_node_mcp" + assert session.target_ips == ["192.168.1.10"] + assert any("执行请输 run" in item for item in output) + + +def test_chat_run_executes_fake_deploy_and_writes_checkpoint(tmp_path: Path): + checkpoint = tmp_path / "checkpoint.json" + session = InteractiveCliSession( + agent=PamDeployAgent(fake_runner=FakeActionRunner()), + params=PARAMS, + strategy="fake", + checkpoint_path=str(checkpoint), + ) + + run_session(session, ["run", "yes", "exit"]) + + assert checkpoint.exists() + assert session.state is not None + assert session.state.pending_confirmation == "" + assert all(item["status"] == "SUCCESS" for item in session.state.ip_states.values()) + + +def test_chat_approve_then_resume_continues_after_failed_ip(tmp_path: Path): + fake = FakeActionRunner( + { + "verify-ip:192.168.1.10": { + "ACTION": "verify-ip", + "IP": "192.168.1.10", + "SUCCESS": "false", + "MESSAGE": "health check failed", + } + } + ) + session = InteractiveCliSession( + agent=PamDeployAgent(fake_runner=fake), + params=PARAMS, + strategy="fake", + checkpoint_path=str(tmp_path / "checkpoint.json"), + ) + + run_session(session, ["run", "yes", "approve", "resume", "exit"]) + + assert session.state is not None + assert session.state.pending_confirmation == "" + assert session.state.ip_states["192.168.1.10"]["rollback_status"] == "ROLLBACK_DONE" + assert session.state.ip_states["192.168.1.11"]["status"] == "SUCCESS" diff --git a/tests/test_llm_structured.py b/tests/test_llm_structured.py index 7cb96e8..873aa2f 100644 --- a/tests/test_llm_structured.py +++ b/tests/test_llm_structured.py @@ -2,6 +2,7 @@ from dataclasses import asdict from pam_deploy_graph.agent import PamDeployAgent from pam_deploy_graph.checkpoint_store import redact_mapping +from pam_deploy_graph.llm.openai_compatible import OpenAICompatibleLlmClient from pam_deploy_graph.llm.rule_based import RuleBasedLlmClient from pam_deploy_graph.llm.validators import validate_deploy_plan from pam_deploy_graph.models import LlmDeployPlan @@ -71,3 +72,72 @@ def test_plan_guardrails_reject_executable_text(): assert "forbidden" in str(exc) else: raise AssertionError("expected guardrail failure") + + +def test_openai_compatible_client_uses_base_url_api_key_and_prompt(): + calls = [] + + def transport(url, headers, payload, timeout_sec): + calls.append((url, headers, payload, timeout_sec)) + return { + "choices": [ + { + "message": { + "content": ( + '{"intent":"deploy","mode_preference":"MCP",' + '"strategy_preference":"hybrid_node_mcp","confidence":0.9,' + '"reasons":["ok"]}' + ) + } + } + ] + } + + client = OpenAICompatibleLlmClient( + base_url="https://llm.example/v1", + api_key="secret-key", + model="model-a", + transport=transport, + ) + + result = client.understand_request("请用 MCP 部署") + + assert result.intent == "deploy" + assert calls[0][0] == "https://llm.example/v1/chat/completions" + assert calls[0][1]["Authorization"] == "Bearer secret-key" + assert calls[0][2]["model"] == "model-a" + assert "只输出一个 JSON 对象" in calls[0][2]["messages"][0]["content"] + + +def test_openai_compatible_client_does_not_send_base_secret(): + calls = [] + + def transport(url, headers, payload, timeout_sec): + calls.append(payload) + return { + "choices": [ + { + "message": { + "content": ( + '{"extracted_params":{"AIRPORT_CODE":"HET"},' + '"extracted_control":{},' + '"missing_required_params":[],' + '"ambiguous_fields":[]}' + ) + } + } + ] + } + + client = OpenAICompatibleLlmClient( + base_url="https://llm.example/v1", + api_key="secret-key", + model="model-a", + transport=transport, + ) + + result = client.extract_params("机场 HET", {"CLIENT_SECRET": "real-secret", "CLIENT_ID": "id"}) + + serialized_prompt = str(calls[0]) + assert "real-secret" not in serialized_prompt + assert result.extracted_params["CLIENT_SECRET"] == "real-secret" diff --git a/tests/test_mcp_client.py b/tests/test_mcp_client.py index 7a9f209..4af77aa 100644 --- a/tests/test_mcp_client.py +++ b/tests/test_mcp_client.py @@ -1,5 +1,6 @@ from pam_deploy_graph.mcp_client import ( FunctionMcpToolClient, + load_mcp_client_config, SessionMcpToolClient, normalize_mcp_sdk_result, ) @@ -26,3 +27,15 @@ def test_session_mcp_client_normalizes_text_json_content(): client = SessionMcpToolClient(Session()) assert client.call_tool("tool", {}) == {"ok": True} + +def test_load_mcp_client_config(tmp_path): + path = tmp_path / "mcp.json" + path.write_text( + '{"server_name": "pam-node-prod", "tool_names": {"get-online-ips": "custom_ips"}}', + encoding="utf-8", + ) + + config = load_mcp_client_config(path) + + assert config.server_name == "pam-node-prod" + assert config.tool_names["get-online-ips"] == "custom_ips"