1147 lines
48 KiB
Python
1147 lines
48 KiB
Python
"""PAM 部署 Agent 的常驻式交互 CLI 会话。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import time
|
||
import json
|
||
import shlex
|
||
import builtins
|
||
import logging
|
||
import os
|
||
from dataclasses import asdict
|
||
from pathlib import Path
|
||
from typing import Any, Callable
|
||
|
||
from .agent import PamDeployAgent
|
||
from .checkpoint_store import load_agent_state, redact_mapping
|
||
from .langgraph_runtime import LangGraphDeploymentRuntime, LangGraphRunResult
|
||
from .llm import build_llm_client
|
||
from .llm.rule_based import RuleBasedLlmClient
|
||
from .logging_utils import configure_logging, json_for_log, redact_for_log
|
||
from .mcp_factory import build_mcp_runner_from_config
|
||
from .models import AgentState, ExecutionStrategy
|
||
from .params_loader import load_params_file
|
||
|
||
# Module-level logger; the session constructor calls configure_logging().
logger = logging.getLogger(__name__)

# Injectable I/O hooks for the session (defaults: builtin input / print).
InputFunc = Callable[[str], str]  # prompt text -> line entered by the user
OutputFunc = Callable[[str], None]  # text to show the user -> nothing
|
||
|
||
# Help text shown by the `help` / `?` command. Keep it in sync with the
# command dispatch in InteractiveCliSession.handle_line.
COMMAND_HELP = """可用命令:
  help                        显示帮助
  preview                     查看当前参数和执行策略
  analyze <需求>              只做理解和计划,不执行
  params                      脱敏展示当前会话参数
  events [数量]               查看最近 action 事件,默认 10 条
  set KEY=VALUE               修改当前会话参数
  llm config KEY=VALUE        配置真实 LLM,支持 base_url/api_key/model/action_analysis_prompt_file
  llm test [文本]             测试当前 LLM client 是否可正常调用
  llm fallback                切回本地规则 fallback
  llm action-analysis on|off  开关 action 审核详情写入 events
  mcp config <路径>           加载 MCP client JSON 配置
  run                         创建部署任务并执行
  status                      查看当前运行状态
  resume                      从当前 checkpoint 续跑
  approve [备注]              批准当前待确认事项(执行回滚)
  reject [原因]               拒绝当前待确认事项(不回滚)
  rollback [IP]               显式回滚指定 IP;不传 IP 时回滚当前失败 IP
  list checkpoints            列出 checkpoint 目录下的 JSON 文件
  load params <路径>          加载并热更新参数文件
  load checkpoint <路径>      加载指定 checkpoint
  checkpoint                  显示 checkpoint 路径
  exit                        退出

也可以直接输入自然语言需求,Agent 会先分析并更新会话参数;执行仍需输入 run。
执行中可按 Ctrl+C 中断,保存 checkpoint 后再用 resume 继续。
"""
|
||
|
||
|
||
class InteractiveCliSession:
|
||
"""维护一次交互式 CLI 会话的参数、状态和命令处理逻辑。"""
|
||
|
||
def __init__(
|
||
self,
|
||
*,
|
||
agent: PamDeployAgent,
|
||
params: dict[str, Any],
|
||
strategy: ExecutionStrategy = "hybrid_node_mcp",
|
||
checkpoint_path: str | None = None,
|
||
target_ips: list[str] | None = None,
|
||
input_func: InputFunc = input,
|
||
output_func: OutputFunc = print,
|
||
) -> None:
|
||
"""初始化会话上下文和输入输出函数。"""
|
||
self.agent = agent
|
||
self.params = dict(params)
|
||
self.strategy = strategy
|
||
self.checkpoint_path = checkpoint_path or _default_checkpoint_path()
|
||
self.target_ips = list(target_ips or [])
|
||
self.input = _build_prompt_input(input_func)
|
||
self.output = _build_output_func(output_func)
|
||
self.state: AgentState | None = None
|
||
self.last_analysis: dict[str, Any] | None = None
|
||
self.llm_config: dict[str, str] = {}
|
||
self.mcp_config_path: str = ""
|
||
self.graph_runtime: LangGraphDeploymentRuntime | None = None
|
||
self.agent.progress_callback = self._on_progress
|
||
self.log_path = configure_logging()
|
||
logger.info(
|
||
"chat 会话已初始化 strategy=%s checkpoint=%s target_ips=%s llm_client=%s log_path=%s params=%s",
|
||
self.strategy,
|
||
self.checkpoint_path,
|
||
self.target_ips,
|
||
type(self.agent.llm_client).__name__,
|
||
self.log_path,
|
||
json_for_log(self.params),
|
||
)
|
||
|
||
def run(self) -> None:
|
||
"""启动 REPL 循环,直到用户 exit 或输入流结束。"""
|
||
logger.info("chat REPL 启动 checkpoint=%s", self.checkpoint_path)
|
||
self.output("PAM 部署 Agent 交互式会话")
|
||
self.output("输入 help 查看命令,输入 exit 退出。")
|
||
self._load_existing_checkpoint_if_any()
|
||
while True:
|
||
try:
|
||
line = self.input("pam-deploy-agent> ")
|
||
except KeyboardInterrupt:
|
||
logger.info("chat 输入被用户中断")
|
||
self.output("已取消当前输入。输入 exit 退出,或继续输入命令。")
|
||
continue
|
||
except EOFError:
|
||
logger.info("chat 输入流结束")
|
||
self.output("bye")
|
||
return
|
||
if not self.handle_line(line):
|
||
return
|
||
|
||
def handle_line(self, line: str) -> bool:
|
||
"""处理用户输入的一行命令;返回 False 表示退出会话。"""
|
||
text = line.strip()
|
||
if not text:
|
||
return True
|
||
|
||
command, _, rest = text.partition(" ")
|
||
normalized = command.lower()
|
||
logger.info(
|
||
"chat 收到输入 command=%s text=%s",
|
||
normalized,
|
||
redact_for_log(text, max_text_len=500),
|
||
)
|
||
|
||
if normalized in ("exit", "quit", "q"):
|
||
logger.info("chat 会话退出")
|
||
self.output("bye")
|
||
return False
|
||
if normalized in ("help", "?"):
|
||
self.output(COMMAND_HELP.rstrip())
|
||
return True
|
||
if normalized == "preview":
|
||
self.output(self.agent.preview(self.params, self.strategy))
|
||
return True
|
||
if normalized == "params":
|
||
self._show_params()
|
||
return True
|
||
if normalized == "events":
|
||
self._show_events(rest.strip())
|
||
return True
|
||
if normalized == "analyze":
|
||
self._analyze(rest.strip())
|
||
return True
|
||
if normalized == "set":
|
||
self._set_param(rest.strip())
|
||
return True
|
||
if normalized == "llm":
|
||
self._configure_llm(rest.strip())
|
||
return True
|
||
if normalized == "mcp":
|
||
self._configure_mcp(rest.strip())
|
||
return True
|
||
if normalized in ("run", "deploy", "execute"):
|
||
self._run_deploy()
|
||
return True
|
||
if normalized == "resume":
|
||
self._resume()
|
||
return True
|
||
if normalized == "rollback":
|
||
self._rollback(rest.strip())
|
||
return True
|
||
if normalized == "status":
|
||
self._status()
|
||
return True
|
||
if normalized == "approve":
|
||
self._confirm(approved=True, note=rest.strip())
|
||
return True
|
||
if normalized == "reject":
|
||
self._confirm(approved=False, note=rest.strip())
|
||
return True
|
||
if normalized == "checkpoint":
|
||
self.output(f"checkpoint: {self.checkpoint_path}")
|
||
return True
|
||
if normalized == "list" and rest.strip().lower() == "checkpoints":
|
||
self._list_checkpoints()
|
||
return True
|
||
if normalized == "load" and rest.strip().lower().startswith("params"):
|
||
self._load_params(rest.strip()[len("params") :].strip())
|
||
return True
|
||
if normalized == "load" and rest.strip().lower().startswith("checkpoint"):
|
||
self._load_checkpoint(rest.strip()[len("checkpoint") :].strip())
|
||
return True
|
||
|
||
if _is_small_talk(text):
|
||
logger.info("chat 输入识别为寒暄,跳过结构化分析")
|
||
self.output("你好。可以输入 help 查看命令,或直接描述部署需求;执行前仍需输入 run 并确认。")
|
||
return True
|
||
if not _looks_like_deploy_request(text):
|
||
logger.info("chat 输入未命中部署需求粗筛,跳过结构化分析")
|
||
self.output("我没有识别到明确的部署需求。可以输入 help 查看命令,或用 analyze <需求> 明确触发需求分析。")
|
||
return True
|
||
|
||
self.output("正在分析需求...")
|
||
self._analyze(text)
|
||
return True
|
||
|
||
def _analyze(self, text: str) -> None:
|
||
"""分析自然语言需求,并更新会话中的参数、策略和目标 IP。"""
|
||
if not text:
|
||
self.output("请输入要分析的自然语言需求,例如:analyze 请用 MCP 预演部署 HET。")
|
||
return
|
||
|
||
try:
|
||
logger.info("chat 开始需求分析 text_len=%s base_params=%s", len(text), json_for_log(self.params))
|
||
result = self.agent.analyze_request(text, self.params)
|
||
except Exception as exc:
|
||
logger.exception("chat 需求分析失败 text=%s", redact_for_log(text, max_text_len=500))
|
||
self.output(f"需求分析失败: {exc}")
|
||
return
|
||
self.last_analysis = result
|
||
param_result = result["params"]
|
||
intent_result = result["intent"]
|
||
plan = result["plan"]
|
||
self.params = dict(param_result.extracted_params)
|
||
self.strategy = _choose_strategy(intent_result.strategy_preference, self.strategy)
|
||
|
||
user_ips = param_result.extracted_control.get("user_specified_ips")
|
||
if isinstance(user_ips, list):
|
||
self.target_ips = [str(item) for item in user_ips]
|
||
self._sync_params_to_state()
|
||
|
||
safe_payload = redact_mapping({key: asdict(value) for key, value in result.items()})
|
||
self.output("已生成结构化理解:")
|
||
self.output(f"- intent: {intent_result.intent}")
|
||
self.output(f"- strategy: {self.strategy}")
|
||
self.output(f"- summary: {plan.summary}")
|
||
if param_result.missing_required_params:
|
||
self.output("- missing: " + ", ".join(param_result.missing_required_params))
|
||
if self.target_ips:
|
||
self.output("- target_ips: " + ", ".join(self.target_ips))
|
||
self.output("执行请输 run;查看完整 JSON 可用一次性 analyze 命令。")
|
||
self.output(_format_redacted_params(safe_payload["params"]["extracted_params"]))
|
||
logger.info(
|
||
"chat 需求分析完成 intent=%s strategy=%s target_ips=%s result=%s",
|
||
intent_result.intent,
|
||
self.strategy,
|
||
self.target_ips,
|
||
json_for_log(safe_payload),
|
||
)
|
||
|
||
def _set_param(self, assignment: str) -> None:
|
||
"""处理 `set KEY=VALUE` 命令,更新当前会话参数。"""
|
||
if "=" not in assignment:
|
||
self.output("格式:set KEY=VALUE")
|
||
return
|
||
key, value = assignment.split("=", 1)
|
||
key = key.strip()
|
||
if not key:
|
||
self.output("参数名不能为空。")
|
||
return
|
||
self.params[key] = value.strip()
|
||
self._sync_params_to_state()
|
||
self.output(f"已设置 {key}")
|
||
logger.info("chat 参数已设置 key=%s params=%s", key, json_for_log(self.params))
|
||
|
||
def _show_params(self) -> None:
|
||
"""脱敏展示当前会话参数。"""
|
||
self.output(_format_redacted_params(redact_mapping(self.params)))
|
||
|
||
def _show_events(self, count_text: str) -> None:
|
||
"""展示最近若干条事件。"""
|
||
if self.state is None or not self.state.events:
|
||
self.output("当前没有事件。")
|
||
return
|
||
try:
|
||
count = int(count_text) if count_text else 10
|
||
except ValueError:
|
||
self.output("格式:events [数量]")
|
||
return
|
||
events = self.state.events[-max(count, 1) :]
|
||
self.output(json.dumps(redact_mapping(events), ensure_ascii=False, indent=2, default=str))
|
||
|
||
def _configure_llm(self, text: str) -> None:
|
||
"""热加载 LLM 配置,或开关 action 后诊断。"""
|
||
if not text:
|
||
self.output("格式:llm config base_url=... api_key=... model=... action_analysis_prompt_file=... | llm test [文本] | llm fallback | llm action-analysis on|off")
|
||
return
|
||
try:
|
||
parts = shlex.split(text)
|
||
except ValueError as exc:
|
||
logger.exception("chat LLM 命令解析失败 text=%s", redact_for_log(text, max_text_len=500))
|
||
self.output(f"LLM 命令解析失败: {exc}")
|
||
return
|
||
if parts[0] == "fallback":
|
||
self.agent.llm_client = RuleBasedLlmClient()
|
||
self.llm_config = {}
|
||
self.output("已切回本地规则 LLM fallback。")
|
||
logger.info("chat LLM 已切回 fallback")
|
||
return
|
||
if parts[0] == "test":
|
||
test_text = text.partition("test")[2].strip()
|
||
self._test_llm(test_text)
|
||
return
|
||
if parts[0] == "action-analysis":
|
||
if len(parts) < 2 or parts[1] not in ("on", "off"):
|
||
self.output("格式:llm action-analysis on|off")
|
||
return
|
||
self.agent.action_analysis_enabled = parts[1] == "on"
|
||
self.output(f"action 审核详情写入 events 已{'开启' if self.agent.action_analysis_enabled else '关闭'}。")
|
||
logger.info("chat action 审核事件写入开关=%s", self.agent.action_analysis_enabled)
|
||
return
|
||
if parts[0] != "config":
|
||
self.output("未知 llm 命令。")
|
||
return
|
||
updates = _parse_key_values(parts[1:])
|
||
self.llm_config.update(updates)
|
||
try:
|
||
logger.info("chat 开始加载 LLM 配置 updates=%s merged=%s", json_for_log(updates), json_for_log(self.llm_config))
|
||
self.agent.llm_client = build_llm_client(
|
||
base_url=self.llm_config.get("base_url"),
|
||
api_key=self.llm_config.get("api_key"),
|
||
model=self.llm_config.get("model"),
|
||
action_analysis_prompt_path=self.llm_config.get("action_analysis_prompt_file"),
|
||
)
|
||
except Exception as exc:
|
||
logger.exception("chat LLM 配置失败 config=%s", json_for_log(self.llm_config))
|
||
self.output(f"LLM 配置失败: {exc}")
|
||
return
|
||
safe = {**self.llm_config}
|
||
if safe.get("api_key"):
|
||
safe["api_key"] = "***"
|
||
self.output("LLM 配置已加载: " + json.dumps(safe, ensure_ascii=False))
|
||
logger.info("chat LLM 配置已加载 client=%s config=%s", type(self.agent.llm_client).__name__, json_for_log(self.llm_config))
|
||
|
||
def _test_llm(self, text: str) -> None:
|
||
"""使用当前 LLM client 做一次轻量调用,便于确认配置是否生效。"""
|
||
prompt = text or "请判断这是一次 LLM 连通性测试,并返回结构化意图。"
|
||
client_name = type(self.agent.llm_client).__name__
|
||
self.output(f"正在测试 LLM: {client_name}")
|
||
logger.info("chat LLM 测试开始 client=%s prompt=%s", client_name, redact_for_log(prompt, max_text_len=500))
|
||
try:
|
||
result = self.agent.understand_request(prompt)
|
||
except Exception as exc:
|
||
logger.exception("chat LLM 测试失败 client=%s", client_name)
|
||
self.output(f"LLM 测试失败: {exc}")
|
||
return
|
||
self.output("LLM 测试通过")
|
||
self.output(f"- client: {client_name}")
|
||
self.output(f"- intent: {result.intent}")
|
||
self.output(f"- strategy: {result.strategy_preference}")
|
||
self.output(f"- confidence: {result.confidence}")
|
||
if result.reasons:
|
||
self.output("- reasons: " + "; ".join(result.reasons))
|
||
logger.info("chat LLM 测试通过 client=%s result=%s", client_name, json_for_log(asdict(result)))
|
||
|
||
def _configure_mcp(self, text: str) -> None:
|
||
"""热加载 MCP client 配置。"""
|
||
command, _, path = text.partition(" ")
|
||
if command != "config" or not path.strip():
|
||
self.output("格式:mcp config <mcp_client.json>")
|
||
return
|
||
path = path.strip().strip('"')
|
||
try:
|
||
logger.info("chat 开始加载 MCP 配置 path=%s", path)
|
||
runner = build_mcp_runner_from_config(path)
|
||
except Exception as exc:
|
||
logger.exception("chat MCP 配置失败 path=%s", path)
|
||
self.output(f"MCP 配置失败: {exc}")
|
||
return
|
||
self.agent.mcp_runner = runner
|
||
self.agent.router.mcp_runner = runner
|
||
self.mcp_config_path = path
|
||
self.output(f"MCP 配置已加载: {path}")
|
||
logger.info("chat MCP 配置已加载 path=%s runner=%s", path, type(runner).__name__)
|
||
|
||
def _list_checkpoints(self) -> None:
|
||
"""列出当前 checkpoint 目录下的 JSON 文件。"""
|
||
checkpoint_dir = Path(self.checkpoint_path).parent
|
||
logger.info("chat 查询 checkpoint 列表 dir=%s", checkpoint_dir)
|
||
if not checkpoint_dir.exists():
|
||
self.output(f"checkpoint 目录不存在: {checkpoint_dir}")
|
||
return
|
||
files = sorted(checkpoint_dir.glob("*.json"), key=lambda item: item.stat().st_mtime, reverse=True)
|
||
if not files:
|
||
self.output(f"checkpoint 目录没有 JSON 文件: {checkpoint_dir}")
|
||
return
|
||
lines = ["checkpoint 列表:"]
|
||
for file in files[:20]:
|
||
lines.append(f"- {file}")
|
||
self.output("\n".join(lines))
|
||
|
||
def _load_checkpoint(self, path_text: str) -> None:
|
||
"""加载指定 checkpoint 文件。"""
|
||
if not path_text:
|
||
self.output("格式:load checkpoint <路径>")
|
||
return
|
||
checkpoint = Path(path_text)
|
||
logger.info("chat 开始加载 checkpoint path=%s", checkpoint)
|
||
if not checkpoint.exists():
|
||
self.output(f"checkpoint 不存在: {checkpoint}")
|
||
return
|
||
self.state = load_agent_state(checkpoint)
|
||
self.state.checkpoint_path = str(checkpoint)
|
||
self.checkpoint_path = str(checkpoint)
|
||
self.params = dict(self.state.params)
|
||
self.strategy = self.state.execution_strategy
|
||
self.target_ips = list(self.state.target_ips)
|
||
self.graph_runtime = None
|
||
self.output(f"已加载 checkpoint: {checkpoint}")
|
||
logger.info(
|
||
"chat checkpoint 已加载 path=%s run_id=%s strategy=%s paused=%s pending=%s",
|
||
checkpoint,
|
||
self.state.run_id,
|
||
self.strategy,
|
||
self.state.paused,
|
||
self.state.pending_confirmation,
|
||
)
|
||
if self.state.pending_confirmation:
|
||
self._print_confirmation()
|
||
self._print_pause_context()
|
||
|
||
def _load_params(self, path_text: str) -> None:
|
||
"""从参数文件热更新当前会话参数,并同步到已暂停 state。"""
|
||
if not path_text:
|
||
self.output("格式:load params <路径>")
|
||
return
|
||
path = Path(path_text)
|
||
logger.info("chat 开始加载参数文件 path=%s", path)
|
||
if not path.exists():
|
||
self.output(f"参数文件不存在: {path}")
|
||
return
|
||
try:
|
||
updates = load_params_file(path)
|
||
except Exception as exc:
|
||
logger.exception("chat 参数文件加载失败 path=%s", path)
|
||
self.output(f"参数文件加载失败: {exc}")
|
||
return
|
||
self.params.update(updates)
|
||
try:
|
||
self.params = self.agent.normalize_params(self.params)
|
||
except ValueError as exc:
|
||
logger.exception("chat 参数热更新归一化失败 path=%s updates=%s", path, json_for_log(updates))
|
||
self.output(f"参数热更新失败: {exc}")
|
||
return
|
||
self._sync_params_to_state()
|
||
self.output(f"已加载参数文件: {path}")
|
||
self.output(_format_redacted_params(redact_mapping(self.params)))
|
||
logger.info("chat 参数文件已加载 path=%s updates=%s params=%s", path, json_for_log(updates), json_for_log(self.params))
|
||
|
||
def _run_deploy(self) -> None:
|
||
"""在用户确认后创建状态并执行完整部署流程。"""
|
||
logger.info("chat run 请求开始 strategy=%s checkpoint=%s target_ips=%s", self.strategy, self.checkpoint_path, self.target_ips)
|
||
if self.state and self.state.pending_confirmation:
|
||
logger.info("chat run 命中待确认事项 pending=%s", self.state.pending_confirmation)
|
||
self._print_confirmation()
|
||
return
|
||
|
||
if not self._prepare_params_for_run():
|
||
logger.info("chat run 参数准备失败")
|
||
return
|
||
|
||
problems = self._validate_run_prerequisites(self.params)
|
||
if problems:
|
||
logger.info("chat run 前置检查失败 problems=%s", problems)
|
||
self.output("执行前检查未通过:")
|
||
for problem in problems:
|
||
self.output(f"- {problem}")
|
||
self.output("请修正参数或配置后再输入 run。")
|
||
return
|
||
|
||
if not self._confirm_params_and_scope():
|
||
logger.info("chat run 用户取消参数或目标范围确认")
|
||
self.output("已取消执行。")
|
||
return
|
||
|
||
if not self._ask_yes_no("即将执行真实 action;确认执行请输入 yes: "):
|
||
logger.info("chat run 用户取消最终执行确认")
|
||
self.output("已取消执行。")
|
||
return
|
||
|
||
self.state = self.agent.create_state(
|
||
params=self.params,
|
||
execution_strategy=self.strategy,
|
||
checkpoint_path=self.checkpoint_path,
|
||
target_ips=self.target_ips,
|
||
)
|
||
self.graph_runtime = None
|
||
logger.info(
|
||
"chat run state 已创建 run_id=%s strategy=%s checkpoint=%s config=%s target_ips=%s",
|
||
self.state.run_id,
|
||
self.state.execution_strategy,
|
||
self.state.checkpoint_path,
|
||
self.state.config_path,
|
||
self.state.target_ips,
|
||
)
|
||
self._execute_current_state()
|
||
|
||
def _confirm_params_and_scope(self) -> bool:
|
||
"""执行前确认参数和目标 IP 范围。"""
|
||
self.output(_format_redacted_params(redact_mapping(self.params)))
|
||
if not self._ask_yes_no("确认以上参数请输入 yes: "):
|
||
return False
|
||
if self.target_ips:
|
||
self.output("目标 IP: " + ", ".join(self.target_ips))
|
||
else:
|
||
self.output("目标 IP: 未指定,将在 get-online-ips 后使用全部在线 IP。")
|
||
return self._ask_yes_no("确认目标范围请输入 yes: ")
|
||
|
||
def _resume(self) -> None:
|
||
"""从内存状态或 checkpoint 文件继续执行部署流程。"""
|
||
if self.state is None:
|
||
checkpoint = Path(self.checkpoint_path)
|
||
if not checkpoint.exists():
|
||
self.output("当前没有可续跑的 checkpoint。")
|
||
logger.info("chat resume 未找到 checkpoint path=%s", checkpoint)
|
||
return
|
||
logger.info("chat resume 从 checkpoint 加载 path=%s", checkpoint)
|
||
self.state = load_agent_state(checkpoint)
|
||
self.state.checkpoint_path = self.state.checkpoint_path or str(checkpoint)
|
||
if self.state.paused:
|
||
logger.info("chat resume 清理暂停状态 run_id=%s reason=%s", self.state.run_id, self.state.pause_reason)
|
||
self.state = self.agent.resume_state(self.state)
|
||
if self.graph_runtime and self.graph_runtime.waiting_confirmation:
|
||
logger.info("chat resume 停在 LangGraph 确认点 pending=%s", self.state.pending_confirmation if self.state else "")
|
||
self._print_confirmation()
|
||
return
|
||
logger.info("chat resume 开始执行 run_id=%s checkpoint=%s", self.state.run_id if self.state else "", self.checkpoint_path)
|
||
self._execute_current_state()
|
||
|
||
def _execute_current_state(self) -> None:
|
||
"""执行当前 state,并输出报告、确认提示和 checkpoint 路径。"""
|
||
if self.state is None:
|
||
self.output("当前没有运行状态。")
|
||
return
|
||
logger.info(
|
||
"chat 开始执行当前 state run_id=%s strategy=%s checkpoint=%s graph_runtime=%s waiting_confirmation=%s",
|
||
self.state.run_id,
|
||
self.state.execution_strategy,
|
||
self.state.checkpoint_path,
|
||
type(self.graph_runtime).__name__ if self.graph_runtime else "",
|
||
self.graph_runtime.waiting_confirmation if self.graph_runtime else False,
|
||
)
|
||
if self.graph_runtime is None or not self.graph_runtime.waiting_confirmation:
|
||
try:
|
||
self.graph_runtime = LangGraphDeploymentRuntime(agent=self.agent)
|
||
except RuntimeError as exc:
|
||
logger.exception("chat LangGraph runtime 不可用,降级本地执行 run_id=%s", self.state.run_id)
|
||
self.output(f"LangGraph 确认运行器不可用,降级为本地执行: {exc}")
|
||
self.graph_runtime = None
|
||
try:
|
||
self.state = self.agent.run_deploy_flow(self.state)
|
||
except KeyboardInterrupt:
|
||
self._handle_execution_interrupt()
|
||
return
|
||
except Exception as fallback_exc:
|
||
self._handle_execution_error(fallback_exc)
|
||
return
|
||
self._print_state_report_and_checkpoint()
|
||
logger.info("chat 本地执行完成 run_id=%s checkpoint=%s", self.state.run_id, self.state.checkpoint_path)
|
||
return
|
||
try:
|
||
result = self.graph_runtime.start(self.state)
|
||
except KeyboardInterrupt:
|
||
self._handle_execution_interrupt()
|
||
return
|
||
except Exception as exc:
|
||
self._handle_execution_error(exc)
|
||
return
|
||
self._apply_graph_result(result)
|
||
logger.info(
|
||
"chat LangGraph 执行返回 interrupted=%s pending=%s checkpoint=%s",
|
||
result.interrupted,
|
||
self.state.pending_confirmation if self.state else "",
|
||
self.state.checkpoint_path if self.state else self.checkpoint_path,
|
||
)
|
||
|
||
def _prepare_params_for_run(self) -> bool:
|
||
"""执行前归一化参数,确保确认值和实际写入脚本配置一致。"""
|
||
try:
|
||
self.params = self.agent.normalize_params(self.params)
|
||
except ValueError as exc:
|
||
logger.exception("chat 参数检查失败 params=%s", json_for_log(self.params))
|
||
self.output(f"参数检查失败: {exc}")
|
||
return False
|
||
logger.info("chat 参数检查通过 params=%s", json_for_log(self.params))
|
||
return True
|
||
|
||
def _validate_run_prerequisites(self, params: dict[str, Any]) -> list[str]:
|
||
"""在创建 state 前检查可提前发现的运行问题。"""
|
||
problems: list[str] = []
|
||
if self.strategy != "fake":
|
||
zip_path = str(params.get("ZIP_FILE_PATH", "")).strip()
|
||
if not _path_exists(zip_path):
|
||
problems.append(f"ZIP_FILE_PATH 不存在: {zip_path}")
|
||
if self.strategy in ("script_only", "hybrid_node_mcp"):
|
||
script_entry = self.agent.script_base_dir / "deploy.sh"
|
||
ps_entry = self.agent.script_base_dir / "deploy.ps1"
|
||
if not script_entry.exists() and not ps_entry.exists():
|
||
problems.append(f"脚本入口不存在: {script_entry} 或 {ps_entry}")
|
||
if self.strategy == "hybrid_node_mcp" and self.agent.mcp_runner is None:
|
||
problems.append("当前策略需要 MCP runner,请启动时传 --mcp-config 或在 chat 内执行 mcp config <路径>。")
|
||
return problems
|
||
|
||
def _handle_execution_error(self, exc: Exception) -> None:
|
||
"""输出 action 执行失败后的可恢复提示,不再误报 LangGraph 不可用。"""
|
||
self.output(f"执行已停止: {exc}")
|
||
logger.exception(
|
||
"chat 执行停止 run_id=%s checkpoint=%s",
|
||
self.state.run_id if self.state else "",
|
||
self.state.checkpoint_path if self.state else self.checkpoint_path,
|
||
)
|
||
if self.state is None:
|
||
return
|
||
if self.state.last_failed_step:
|
||
self.output(f"最后失败步骤: {self.state.last_failed_step}")
|
||
self._print_pause_context()
|
||
if self.state.pending_confirmation:
|
||
self._print_confirmation()
|
||
self.output(f"checkpoint: {self.state.checkpoint_path or self.checkpoint_path}")
|
||
self.output("请修正参数或外部环境后,使用 load checkpoint <路径> / resume 继续,或重新 run。")
|
||
|
||
def _handle_execution_interrupt(self) -> None:
|
||
"""处理执行中的用户中断,并保留断点。"""
|
||
if self.state is None:
|
||
self.output("执行已中断。")
|
||
logger.info("chat 执行中断时没有 state")
|
||
return
|
||
self.graph_runtime = None
|
||
logger.info("chat 执行被用户中断 run_id=%s checkpoint=%s", self.state.run_id, self.state.checkpoint_path)
|
||
self.state = self.agent.pause_state(
|
||
self.state,
|
||
reason="user_interrupted",
|
||
review_context={"type": "user_interrupt", "message": "用户手动中断执行"},
|
||
)
|
||
self.output("执行已由用户中断,当前 checkpoint 已保存。")
|
||
self._print_pause_context()
|
||
self.output(f"checkpoint: {self.state.checkpoint_path or self.checkpoint_path}")
|
||
|
||
def _apply_graph_result(self, result: LangGraphRunResult) -> None:
|
||
"""把 LangGraph 运行结果同步回 chat 会话并输出用户可见状态。"""
|
||
if result.state is not None:
|
||
self.state = result.state
|
||
if self.state is None:
|
||
self.output("当前没有运行状态。")
|
||
return
|
||
logger.info(
|
||
"chat 应用 LangGraph 结果 run_id=%s interrupted=%s pending=%s paused=%s",
|
||
self.state.run_id,
|
||
result.interrupted,
|
||
self.state.pending_confirmation,
|
||
self.state.paused,
|
||
)
|
||
self.output(result.report or self.agent.render_report(self.state))
|
||
if result.interrupted and result.confirmation:
|
||
self._print_confirmation_request(result.confirmation)
|
||
elif self.state.pending_confirmation:
|
||
self._print_confirmation()
|
||
self._print_pause_context()
|
||
self.output(f"checkpoint: {self.state.checkpoint_path or self.checkpoint_path}")
|
||
|
||
def _print_state_report_and_checkpoint(self) -> None:
|
||
"""输出本地执行路径的状态报告和 checkpoint。"""
|
||
if self.state is None:
|
||
return
|
||
self.output(self.agent.render_report(self.state))
|
||
self._print_pause_context()
|
||
if self.state.pending_confirmation:
|
||
self._print_confirmation()
|
||
self.output(f"checkpoint: {self.state.checkpoint_path or self.checkpoint_path}")
|
||
|
||
def _status(self) -> None:
|
||
"""输出当前运行状态;没有 state 时输出 checkpoint 路径。"""
|
||
if self.state is None:
|
||
self.output("当前还没有运行状态。")
|
||
self.output(f"checkpoint: {self.checkpoint_path}")
|
||
return
|
||
self.output(self.agent.render_report(self.state))
|
||
self._print_pause_context()
|
||
if self.state.pending_confirmation:
|
||
self._print_confirmation()
|
||
|
||
def _confirm(self, *, approved: bool, note: str = "") -> None:
|
||
"""处理 approve/reject 命令。"""
|
||
if self.state is None:
|
||
checkpoint = Path(self.checkpoint_path)
|
||
if checkpoint.exists():
|
||
logger.info("chat confirm 从 checkpoint 加载 path=%s", checkpoint)
|
||
self.state = load_agent_state(checkpoint)
|
||
self.state.checkpoint_path = self.state.checkpoint_path or str(checkpoint)
|
||
else:
|
||
self.output("当前没有待确认任务。")
|
||
logger.info("chat confirm 无 state 且 checkpoint 不存在 path=%s", checkpoint)
|
||
return
|
||
if not self.state.pending_confirmation:
|
||
self.output("当前没有待确认任务。")
|
||
logger.info("chat confirm 无待确认事项 run_id=%s", self.state.run_id)
|
||
return
|
||
|
||
logger.info(
|
||
"chat confirm 开始 run_id=%s approved=%s pending=%s note_len=%s",
|
||
self.state.run_id,
|
||
approved,
|
||
self.state.pending_confirmation,
|
||
len(note),
|
||
)
|
||
if self.graph_runtime and self.graph_runtime.waiting_confirmation:
|
||
try:
|
||
result = self.graph_runtime.resume(approved=approved, note=note)
|
||
except RuntimeError as exc:
|
||
logger.exception("chat LangGraph 确认恢复失败,降级本地确认 run_id=%s", self.state.run_id)
|
||
self.output(f"LangGraph 确认恢复失败,降级为本地确认: {exc}")
|
||
else:
|
||
self._apply_graph_result(result)
|
||
return
|
||
|
||
self.state = self.agent.confirm_pending(self.state, approved=approved, operator_note=note)
|
||
logger.info(
|
||
"chat confirm 完成 run_id=%s pending=%s paused=%s",
|
||
self.state.run_id,
|
||
self.state.pending_confirmation,
|
||
self.state.paused,
|
||
)
|
||
self.output(self.agent.render_report(self.state))
|
||
self._print_pause_context()
|
||
if self.state.pending_confirmation:
|
||
self._print_confirmation()
|
||
|
||
def _rollback(self, text: str) -> None:
|
||
"""显式执行单 IP 回滚;主 workflow 不再自动触发回滚。"""
|
||
if self.state is None:
|
||
checkpoint = Path(self.checkpoint_path)
|
||
if checkpoint.exists():
|
||
logger.info("chat rollback 从 checkpoint 加载 path=%s", checkpoint)
|
||
self.state = load_agent_state(checkpoint)
|
||
self.state.checkpoint_path = self.state.checkpoint_path or str(checkpoint)
|
||
else:
|
||
self.output("当前没有可回滚的运行状态。")
|
||
logger.info("chat rollback 无 state 且 checkpoint 不存在 path=%s", checkpoint)
|
||
return
|
||
try:
|
||
ip, stop_first, note = _parse_rollback_args(text)
|
||
except ValueError as exc:
|
||
self.output(f"rollback 参数错误: {exc}")
|
||
return
|
||
ip = ip or _find_current_failed_ip(self.state)
|
||
if not ip:
|
||
self.output("未找到当前失败 IP,请使用 rollback <IP> 指定。")
|
||
logger.info("chat rollback 未找到可回滚 IP run_id=%s", self.state.run_id)
|
||
return
|
||
logger.info(
|
||
"chat rollback 开始 run_id=%s ip=%s stop_first=%s note_len=%s",
|
||
self.state.run_id,
|
||
ip,
|
||
stop_first,
|
||
len(note),
|
||
)
|
||
self.graph_runtime = None
|
||
try:
|
||
self.state = self.agent.rollback_ip(
|
||
self.state,
|
||
ip,
|
||
stop_first=stop_first,
|
||
operator_note=note,
|
||
)
|
||
except Exception as exc:
|
||
logger.exception("chat rollback 执行失败 run_id=%s ip=%s", self.state.run_id, ip)
|
||
self.output(f"rollback 执行失败: {exc}")
|
||
self._print_pause_context()
|
||
return
|
||
self.output(self.agent.render_report(self.state))
|
||
self._print_pause_context()
|
||
self.output(f"checkpoint: {self.state.checkpoint_path or self.checkpoint_path}")
|
||
if not self.state.paused:
|
||
self.output("回滚已完成;如需继续主流程,输入 resume。")
|
||
logger.info(
|
||
"chat rollback 完成 run_id=%s ip=%s status=%s paused=%s",
|
||
self.state.run_id,
|
||
ip,
|
||
self.state.ip_states.get(ip, {}).get("rollback_status"),
|
||
self.state.paused,
|
||
)
|
||
|
||
def _sync_params_to_state(self) -> None:
|
||
"""若当前已有 state,则把热更新参数同步到 checkpoint/config。"""
|
||
if self.state is None:
|
||
return
|
||
try:
|
||
self.state = self.agent.update_state_params(self.state, self.params)
|
||
except ValueError as exc:
|
||
logger.exception("chat 参数同步到 state 失败 run_id=%s params=%s", self.state.run_id, json_for_log(self.params))
|
||
self.output(f"参数同步到当前任务失败: {exc}")
|
||
return
|
||
self.params = dict(self.state.params)
|
||
if self.target_ips:
|
||
self.state.target_ips = list(self.target_ips)
|
||
logger.info(
|
||
"chat 参数已同步到 state run_id=%s checkpoint=%s params=%s target_ips=%s",
|
||
self.state.run_id,
|
||
self.state.checkpoint_path,
|
||
json_for_log(self.params),
|
||
self.target_ips,
|
||
)
|
||
|
||
def _print_pause_context(self) -> None:
|
||
"""输出暂停原因和审核建议,避免黑盒暂停。"""
|
||
if self.state is None or not self.state.paused:
|
||
return
|
||
context = self.state.review_context or {}
|
||
reason = self.state.pause_reason or "unknown"
|
||
self.output(f"当前流程已暂停: {reason}")
|
||
if context.get("stage"):
|
||
self.output(f"- stage: {context.get('stage')}")
|
||
if context.get("ip"):
|
||
self.output(f"- ip: {context.get('ip')}")
|
||
if context.get("possible_reason"):
|
||
self.output(f"- reason: {context.get('possible_reason')}")
|
||
elif context.get("error_summary"):
|
||
self.output(f"- reason: {context.get('error_summary')}")
|
||
if context.get("suggested_action"):
|
||
self.output(f"- suggestion: {context.get('suggested_action')}")
|
||
if context.get("severity"):
|
||
self.output(f"- severity: {context.get('severity')}")
|
||
if context.get("notes"):
|
||
self.output("- notes: " + "; ".join(str(item) for item in context.get("notes", [])))
|
||
if reason == "user_interrupted":
|
||
self.output("输入 resume 可从当前 checkpoint 继续。")
|
||
elif reason == "llm_review_blocked":
|
||
self.output("请根据以上建议判断后续;如需继续,输入 resume。")
|
||
elif reason == "action_failed":
|
||
ip = context.get("ip")
|
||
rollback_hint = f"rollback {ip}" if ip else "rollback <IP>"
|
||
self.output(f"请修复失败原因后输入 resume 重试当前 action;如需回滚,输入 {rollback_hint}。")
|
||
elif reason == "rollback_failed":
|
||
self.output("请检查回滚失败原因;修复后可再次输入 rollback 重试,或人工处理后再 resume。")
|
||
|
||
def _on_progress(self, payload: dict[str, Any]) -> None:
|
||
"""把 Agent action 进度转成 chat 可见输出。"""
|
||
event_type = str(payload.get("type", ""))
|
||
stage = str(payload.get("stage", ""))
|
||
backend = str(payload.get("backend", ""))
|
||
ip = str(payload.get("ip", ""))
|
||
message = str(payload.get("message", ""))
|
||
suffix_parts = []
|
||
if backend:
|
||
suffix_parts.append(f"backend={backend}")
|
||
if ip:
|
||
suffix_parts.append(f"ip={ip}")
|
||
suffix = f" [{', '.join(suffix_parts)}]" if suffix_parts else ""
|
||
if event_type == "ACTION_START":
|
||
self.output(f"开始执行 action: {stage}{suffix}")
|
||
elif event_type == "ACTION_DONE":
|
||
detail = f": {message}" if message and message != "ok" else ""
|
||
self.output(f"完成 action: {stage}{suffix}{detail}")
|
||
elif event_type == "ACTION_FAIL":
|
||
detail = f": {message}" if message else ""
|
||
self.output(f"失败 action: {stage}{suffix}{detail}")
|
||
elif event_type == "ACTION_REVIEW_START":
|
||
self.output(f"开始分析 action 结果: {stage}{suffix}")
|
||
elif event_type == "ACTION_REVIEW_DONE":
|
||
detail = f": {message}" if message else ""
|
||
self.output(f"分析完成: {stage}{suffix}{detail}")
|
||
elif event_type == "ACTION_REVIEW_FAIL":
|
||
detail = f": {message}" if message else ""
|
||
self.output(f"分析失败: {stage}{suffix}{detail}")
|
||
logger.info("chat progress event=%s", json_for_log(payload))
|
||
|
||
def _print_confirmation(self) -> None:
    """Print the pending manual-confirmation request for the current state.

    Nothing is printed when there is no state yet, or when
    ``self.agent.build_confirmation_request`` returns an empty/falsy request.
    """
    state = self.state
    if state is not None:
        pending = self.agent.build_confirmation_request(state)
        if pending:
            self._print_confirmation_request(pending)
|
||
|
||
def _print_confirmation_request(self, request: dict[str, Any]) -> None:
    """Print the given manual-confirmation request followed by the approve/reject hint.

    ``type`` is always shown (as ``None`` when absent). ``ip``,
    ``failed_stage`` and ``failure_reason`` (labelled ``reason``) are shown
    only when present and truthy.
    """
    emit = self.output
    emit("需要人工确认:")
    emit(f"- type: {request.get('type')}")
    optional_fields = (
        ("ip", "ip"),
        ("failed_stage", "failed_stage"),
        ("failure_reason", "reason"),
    )
    for key, label in optional_fields:
        value = request.get(key)
        if value:
            emit(f"- {label}: {value}")
    emit("输入 approve 执行回滚,或 reject [原因] 拒绝回滚。")
|
||
|
||
def _ask_yes_no(self, prompt: str) -> bool:
    """Ask one yes/no question through ``self.input``.

    Only ``y`` / ``yes`` (case-insensitive, surrounding whitespace ignored)
    count as confirmation. End of input (EOFError) is treated as "no".
    """
    try:
        reply = self.input(prompt)
    except EOFError:
        return False
    return reply.strip().lower() in ("y", "yes")
|
||
|
||
def _load_existing_checkpoint_if_any(self) -> None:
    """On session start, restore state from ``self.checkpoint_path`` if that file already exists.

    When the path exists, the state is loaded with ``load_agent_state``. Its
    ``checkpoint_path`` is filled in from the session path when empty. The
    user is then told what was loaded: the pending confirmation request (if
    any), followed by the pause context. A missing path leaves
    ``self.state`` untouched.
    """
    ckpt = Path(self.checkpoint_path)
    if ckpt.exists():
        logger.info("chat 启动时自动加载已有 checkpoint path=%s", ckpt)
        self.state = load_agent_state(ckpt)
        self.state.checkpoint_path = self.state.checkpoint_path or str(ckpt)
        self.output(f"已加载 checkpoint: {ckpt}")
        if self.state.pending_confirmation:
            self._print_confirmation()
        self._print_pause_context()
|
||
|
||
|
||
def run_interactive_chat(
    *,
    agent: PamDeployAgent,
    params: dict[str, Any],
    strategy: ExecutionStrategy,
    checkpoint_path: str | None = None,
    target_ips: list[str] | None = None,
    input_func: InputFunc = input,
    output_func: OutputFunc = print,
) -> InteractiveCliSession:
    """Build an interactive CLI session, run it until it exits, and return it.

    All arguments are forwarded unchanged to ``InteractiveCliSession``.
    Returning the finished session lets tests inspect its final state.
    """
    session = InteractiveCliSession(
        agent=agent,
        strategy=strategy,
        params=params,
        target_ips=target_ips,
        checkpoint_path=checkpoint_path,
        input_func=input_func,
        output_func=output_func,
    )
    session.run()
    return session
|
||
|
||
|
||
def _default_checkpoint_path() -> str:
    """Return the default chat checkpoint path.

    The path has the form ``runtime/checkpoints/chat_YYYYmmdd_HHMMSS.json``.
    It is stamped with the current local time at one-second resolution.
    """
    stamp = time.strftime("%Y%m%d_%H%M%S")
    target = Path("runtime", "checkpoints", f"chat_{stamp}.json")
    return str(target)
|
||
|
||
|
||
def _choose_strategy(preference: str, default: ExecutionStrategy) -> ExecutionStrategy:
    """Adopt the LLM's preferred execution strategy when it is a known one.

    Returns ``preference`` unchanged if it is ``"hybrid_node_mcp"``,
    ``"script_only"`` or ``"fake"``. Any other value falls back to ``default``.
    """
    recognised = ("hybrid_node_mcp", "script_only", "fake")
    if preference not in recognised:
        return default
    return preference  # type: ignore[return-value]
|
||
|
||
|
||
def _format_redacted_params(params: dict[str, Any]) -> str:
    """Render an (already redacted) parameter dict as a multi-line listing.

    Output is a header line followed by one ``- key: value`` line per entry,
    ordered by sorted key.
    """
    header = "当前参数:"
    entries = [f"- {name}: {params[name]}" for name in sorted(params)]
    return "\n".join([header, *entries])
|
||
|
||
|
||
def _parse_key_values(parts: list[str]) -> dict[str, str]:
    """Parse ``KEY=VALUE`` tokens into a dict.

    Tokens without ``=`` or with an empty key are ignored. Only the first
    ``=`` separates key from value, so values may contain ``=``. If a key
    repeats, the last occurrence wins.
    """
    return {
        key: value
        for key, _sep, value in (token.partition("=") for token in parts if "=" in token)
        if key
    }
|
||
|
||
|
||
def _parse_rollback_args(text: str) -> tuple[str, bool | None, str]:
    """Parse the argument string of the chat ``rollback`` command.

    The text is split shell-style with :func:`shlex.split`. The resulting
    tokens are interpreted as follows:

    * ``--stop-first`` / ``--no-stop-first`` set the stop-first override.
      The last one given wins.
    * ``--note TEXT`` / ``-n TEXT`` append ``TEXT`` to the note.
    * The first remaining token is the target IP. Any further tokens are
      appended to the note.

    Args:
        text: Raw argument string following ``rollback``.

    Returns:
        A tuple ``(ip, stop_first, note)``:

        * ``ip`` is ``""`` when no IP was given.
        * ``stop_first`` is ``None`` when neither flag was given.
        * ``note`` is the collected note tokens joined by single spaces.

    Raises:
        ValueError: If ``text`` has unbalanced quotes, or a note flag is the
            last token and therefore has no value.
    """
    # shlex.split already raises ValueError (e.g. "No closing quotation"),
    # which is what callers catch, so it is not re-wrapped here.
    tokens = iter(shlex.split(text))
    ip = ""
    stop_first: bool | None = None
    note_parts: list[str] = []
    for token in tokens:
        if token == "--stop-first":
            stop_first = True
        elif token == "--no-stop-first":
            stop_first = False
        elif token in ("--note", "-n"):
            # The note value is the next token, whatever it looks like.
            value = next(tokens, None)
            if value is None:
                # Name the flag the user actually typed (--note or -n).
                raise ValueError(f"{token} 需要提供备注内容")
            note_parts.append(value)
        elif not ip:
            ip = token
        else:
            note_parts.append(token)
    return ip, stop_first, " ".join(note_parts)
|
||
|
||
|
||
def _find_current_failed_ip(state: AgentState) -> str:
    """Pick an IP that an explicit ``rollback`` with no argument should target.

    The IP recorded in ``state.review_context`` is preferred when it is a key
    of ``state.ip_states``. Otherwise the first entry of ``state.ip_states``
    (in dict order) whose ``status`` is ``"FAILED"`` is used. Returns ``""``
    when neither exists.
    """
    review = state.review_context or {}
    preferred = str(review.get("ip", ""))
    if preferred and preferred in state.ip_states:
        return preferred
    failed = (
        ip for ip, ip_state in state.ip_states.items()
        if ip_state.get("status") == "FAILED"
    )
    return next(failed, "")
|
||
|
||
|
||
def _is_small_talk(text: str) -> bool:
    """Return True when ``text`` is a bare greeting or thanks.

    Such input should not trigger LLM / structured analysis. The match is
    exact, made after stripping surrounding whitespace and lower-casing, so
    longer sentences containing a greeting are not treated as small talk.
    """
    pleasantries = (
        "你好", "您好",
        "hello", "hi", "hey",
        "在吗",
        "谢谢", "thanks", "thank you",
    )
    return text.strip().lower() in pleasantries
|
||
|
||
|
||
def _looks_like_deploy_request(text: str) -> bool:
    """Cheaply decide whether free text is plausibly a deployment request.

    This is a pre-filter so arbitrary chatter does not trigger slow analysis.
    It returns True in either of two cases:

    * The lower-cased text contains any deployment keyword (Chinese or English).
    * The original text contains one of the upper-case parameter names,
      matched case-sensitively.

    Both checks are plain substring tests, so a keyword also matches inside
    longer words.
    """
    lowered = text.lower()
    keywords = (
        "部署", "发布", "升级", "回滚", "预演", "执行",
        "pam", "mcp", "node", "版本", "机场",
        "deploy", "release", "upgrade", "rollback", "preview",
    )
    if any(word in lowered for word in keywords):
        return True
    param_names = (
        "HOME_BASE_URL",
        "CLIENT_ID",
        "AIRPORT_CODE",
        "APP_NAME",
        "MODULE_NAME",
        "VERSION_NUMBER",
        "ZIP_FILE_PATH",
    )
    return any(name in text for name in param_names)
|
||
|
||
|
||
def _path_exists(path: str) -> bool:
    """Return whether a local filesystem path exists.

    A leading ``~`` is expanded first. An empty or falsy path is reported as
    missing rather than being resolved against the current directory.
    """
    return bool(path) and Path(path).expanduser().exists()
|
||
|
||
|
||
def _build_prompt_input(input_func: InputFunc) -> InputFunc:
    """Return a line reader with history and command completion when possible.

    Only the builtin ``input`` is upgraded; any other callable (e.g. a test
    stub) is returned unchanged.

    If ``prompt_toolkit`` is importable, a ``PromptSession`` is created with:

    * file-backed history at ``runtime/chat_history.txt``, or no history if
      that directory cannot be created;
    * a word completer over the chat commands.

    The session's ``prompt`` method is returned. If ``prompt_toolkit`` is
    missing or the session cannot be created, this falls back to
    ``_build_readline_input``.
    """
    if input_func is not builtins.input:
        return input_func
    try:
        from prompt_toolkit import PromptSession
        from prompt_toolkit.completion import WordCompleter
        from prompt_toolkit.history import FileHistory
    except ImportError:
        return _build_readline_input(input_func)

    commands = [
        "help",
        "preview",
        "analyze",
        "params",
        "events",
        "set",
        "llm config",
        "llm test",
        "llm fallback",
        "llm action-analysis on",
        "llm action-analysis off",
        "mcp config",
        "run",
        "status",
        "resume",
        "rollback",
        "rollback --stop-first",
        # Both override flags are accepted by _parse_rollback_args.
        "rollback --no-stop-first",
        # Answers the confirmation prompt tells the user to type
        # (see _print_confirmation_request).
        "approve",
        "reject",
        "list checkpoints",
        "load params",
        "load checkpoint",
        "checkpoint",
        "exit",
    ]

    try:
        history_path = Path("runtime") / "chat_history.txt"
        history_path.parent.mkdir(parents=True, exist_ok=True)
        history = FileHistory(str(history_path))
    except OSError:
        # History is a convenience; without a writable runtime/ directory
        # the prompt still works, just without persisted history.
        history = None

    try:
        session = PromptSession(
            history=history,
            completer=WordCompleter(commands, ignore_case=True, sentence=True),
        )
    except Exception:
        # prompt_toolkit can fail to attach to the terminal (e.g. stdin is not
        # a console); degrade to plain/readline input instead of aborting.
        logger.exception("chat prompt_toolkit 初始化失败,降级为普通 input")
        return _build_readline_input(input_func)
    return session.prompt
|
||
|
||
|
||
def _build_readline_input(input_func: InputFunc) -> InputFunc:
    """Return ``input_func``, first importing GNU readline when that helps.

    Importing :mod:`readline` purely for its side effect gives the builtin
    ``input`` line editing. This improves Backspace handling in Linux
    terminals when ``prompt_toolkit`` is unavailable. The import is attempted
    only for the builtin ``input`` on non-Windows platforms, and a missing
    module is logged at DEBUG. The same ``input_func`` object is returned in
    every case.
    """
    wants_readline = input_func is builtins.input and os.name != "nt"
    if wants_readline:
        try:
            import readline  # noqa: F401
        except ImportError:
            logger.debug("chat readline 不可用,使用普通 input")
    return input_func
|
||
|
||
|
||
def _build_output_func(output_func: OutputFunc) -> OutputFunc:
    """Return a chat output function, using ``rich`` rendering when available.

    Any output function other than the builtin ``print`` is returned
    unchanged, as is ``print`` itself when ``rich`` is not installed.
    Otherwise the returned printer chooses a renderer per message:

    * JSON-looking text (starting with ``{`` or ``[``) is pretty-printed with
      ``Console.print_json``.
    * Text that starts with ``"## "`` or contains a Markdown table rule is
      rendered as Markdown.
    * Anything else is printed literally.
    """
    if output_func is not builtins.print:
        return output_func
    try:
        from rich.console import Console
        from rich.markdown import Markdown
    except ImportError:
        return output_func

    console = Console()

    def rich_print(value: str) -> None:
        """Render one chat message as pretty JSON, Markdown, or literal text."""
        text = str(value)
        if text.lstrip().startswith(("{", "[")):
            try:
                console.print_json(text)
                return
            except Exception:
                # Not valid JSON (e.g. a "[backend=..., ip=...]" line); fall
                # through and render it as ordinary text instead.
                pass
        if text.startswith("## ") or "\n| ---" in text:
            console.print(Markdown(text))
            return
        # Chat text is arbitrary: progress suffixes like "[backend=x, ip=y]",
        # file paths, LLM replies. With rich markup enabled, a "[word...]" span
        # is consumed as a style tag and vanishes, and an unmatched "[/...]"
        # raises MarkupError. Print it literally, as the plain-print fallback
        # does.
        console.print(text, markup=False)

    return rich_print
|