dark d01c4d3d06 feat: 完善交互式部署与 MCP/LLM 配置能力
- 新增 MCP client 配置加载,支持 CLI/chat 通过配置文件接入 MCP
- 完善 chat 交互命令,支持参数查看、事件查看、checkpoint 列表与加载
- 增加 LLM action 后诊断能力,支持真实 LLM 和本地规则兜底
- 将 chat 人工确认点接入 LangGraph interrupt/checkpointer
- 更新 README、流程图、待办文档和打包说明
- 补充相关单元测试
2026-06-01 16:45:52 +08:00

168 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""MCP client 适配器。
Agent 只依赖同步的 `call_tool(name, arguments)` 接口。本模块把普通
callable 或 SDK session 适配成这个接口,避免业务代码绑定具体 MCP SDK。
"""
from __future__ import annotations
import json
from datetime import timedelta
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class McpClientConfig:
"""真实 MCP session 建立后需要传给 runner 的配置。"""
server_name: str = "pam-node"
transport: str = "stdio"
command: str = ""
args: list[str] = field(default_factory=list)
env: dict[str, str] | None = None
cwd: str = ""
timeout_seconds: float = 60
tool_names: dict[str, str] = field(default_factory=dict)
@classmethod
def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig":
"""从 JSON 字典构造 MCP client 配置。"""
tool_names = payload.get("tool_names") or payload.get("tools") or {}
if not isinstance(tool_names, dict):
raise ValueError("MCP tool_names 必须是 JSON object")
args = payload.get("args") or []
if not isinstance(args, list):
raise ValueError("MCP args 必须是数组")
env = payload.get("env")
if env is not None and not isinstance(env, dict):
raise ValueError("MCP env 必须是 JSON object")
return cls(
server_name=str(payload.get("server_name", "pam-node")),
transport=str(payload.get("transport", "stdio")),
command=str(payload.get("command", "")),
args=[str(item) for item in args],
env={str(key): str(value) for key, value in env.items()} if env else None,
cwd=str(payload.get("cwd", "")),
timeout_seconds=float(payload.get("timeout_seconds", 60)),
tool_names={str(key): str(value) for key, value in tool_names.items()},
)
def load_mcp_client_config(path: str | Path) -> McpClientConfig:
    """Load an MCP client configuration from a JSON file.

    Args:
        path: Location of the configuration file; it is read as UTF-8.

    Returns:
        The configuration built by ``McpClientConfig.from_mapping``.

    Raises:
        ValueError: If the top-level JSON value is not an object.

    Errors raised while reading or decoding the file are not caught here.
    """
    raw_text = Path(path).read_text(encoding="utf-8")
    document = json.loads(raw_text)
    if isinstance(document, dict):
        return McpClientConfig.from_mapping(document)
    raise ValueError("MCP client 配置必须是 JSON object")
class FunctionMcpToolClient:
    """Expose a plain Python callable through the MCP tool-client interface."""

    def __init__(self, caller: Callable[[str, dict[str, Any]], Any]) -> None:
        """Remember the callable that performs the actual tool invocation.

        Args:
            caller: Invoked as ``caller(tool_name, arguments)`` for every
                ``call_tool`` request.
        """
        self.caller = caller

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Forward the request to the wrapped callable and return its raw result."""
        delegate = self.caller
        outcome = delegate(tool_name, arguments)
        return outcome
class SessionMcpToolClient:
    """Adapt an MCP SDK session that exposes ``call_tool``.

    Results are passed through ``normalize_mcp_sdk_result``; the commonly
    seen return shapes are:
    - a raw dict/list/string
    - an object carrying ``structuredContent``
    - an object carrying ``content`` whose text items may hold JSON
    """

    def __init__(self, session: Any) -> None:
        """Validate and keep the MCP SDK session.

        Raises:
            TypeError: If ``session`` has no ``call_tool`` attribute.
        """
        supports_tool_calls = hasattr(session, "call_tool")
        if not supports_tool_calls:
            raise TypeError("MCP session 必须暴露 call_tool 方法")
        self.session = session

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call the tool on the session and return the normalized result."""
        # NOTE(review): the call is not awaited, so the session's call_tool
        # presumably has to be synchronous -- confirm against the callers.
        return normalize_mcp_sdk_result(self.session.call_tool(tool_name, arguments))
class StdioMcpToolClient:
"""通过 MCP Python SDK 启动 stdio server 并调用 tool。"""
def __init__(
self,
*,
command: str,
args: list[str] | None = None,
env: dict[str, str] | None = None,
cwd: str | None = None,
timeout_seconds: float = 60,
) -> None:
"""保存 stdio server 启动参数。"""
if not command:
raise ValueError("stdio MCP 配置必须提供 command")
self.command = command
self.args = list(args or [])
self.env = env
self.cwd = cwd or None
self.timeout_seconds = timeout_seconds
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""创建一次 MCP stdio session调用 tool 后关闭 session。"""
try:
import anyio
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
except ImportError as exc: # pragma: no cover - 依赖安装状态
raise RuntimeError("未安装 MCP Python SDK请安装项目的 mcp 可选依赖") from exc
async def call_once() -> Any:
server = StdioServerParameters(
command=self.command,
args=self.args,
env=self.env,
cwd=self.cwd,
)
async with stdio_client(server) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
result = await session.call_tool(
tool_name,
arguments,
read_timeout_seconds=timedelta(seconds=self.timeout_seconds),
)
return normalize_mcp_sdk_result(result)
return anyio.run(call_once)
def normalize_mcp_sdk_result(result: Any) -> Any:
    """Reduce common MCP SDK return structures to a dict/list/string.

    Resolution order:
    1. A non-``None`` ``structuredContent`` attribute is returned directly.
    2. Otherwise the ``text`` attributes of the items in ``content`` are
       joined with newlines; the joined text is returned JSON-decoded when
       it parses, or as the plain string when it does not.
    3. Anything else is returned unchanged.
    """
    structured = getattr(result, "structuredContent", None)
    if structured is not None:
        return structured

    items = getattr(result, "content", None) or []
    candidate_texts = (getattr(item, "text", None) for item in items)
    texts: list[str] = [text for text in candidate_texts if text is not None]
    if not texts:
        # No textual payload to interpret; hand back the value as received.
        return result

    combined = "\n".join(texts)
    try:
        return json.loads(combined)
    except json.JSONDecodeError:
        return combined