- 新增 MCP client 配置加载,支持 CLI/chat 通过配置文件接入 MCP - 完善 chat 交互命令,支持参数查看、事件查看、checkpoint 列表与加载 - 增加 LLM action 后诊断能力,支持真实 LLM 和本地规则兜底 - 将 chat 人工确认点接入 LangGraph interrupt/checkpointer - 更新 README、流程图、待办文档和打包说明 - 补充相关单元测试
168 lines
6.1 KiB
Python
168 lines
6.1 KiB
Python
"""MCP client 适配器。
|
||
|
||
Agent 只依赖同步的 `call_tool(name, arguments)` 接口。本模块把普通
|
||
callable 或 SDK session 适配成这个接口,避免业务代码绑定具体 MCP SDK。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from datetime import timedelta
|
||
from collections.abc import Callable
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
|
||
@dataclass(frozen=True)
class McpClientConfig:
    """Configuration passed to the runner once a real MCP session is set up.

    Instances are normally built from a decoded JSON object via
    :meth:`from_mapping`.

    NOTE(review): the dataclass is frozen, but ``args``, ``env`` and
    ``tool_names`` are mutable containers, so ``hash()`` on an instance
    raises ``TypeError``.
    """

    server_name: str = "pam-node"
    transport: str = "stdio"
    command: str = ""
    args: list[str] = field(default_factory=list)
    env: dict[str, str] | None = None
    cwd: str = ""
    timeout_seconds: float = 60
    tool_names: dict[str, str] = field(default_factory=dict)

    @classmethod
    def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig":
        """Build an MCP client config from a decoded JSON object.

        Keys that are absent *or* explicitly ``null`` fall back to the field
        defaults, so ``"cwd": null`` yields ``""`` rather than the string
        ``"None"``.

        Args:
            payload: Decoded JSON object. ``tool_names`` may also be given
                under the alias ``tools``.

        Returns:
            A new ``McpClientConfig``; scalar values are coerced with
            ``str``/``float`` and container items are stringified.

        Raises:
            ValueError: If ``tool_names``/``env`` is not an object, ``args``
                is not an array, or ``timeout_seconds`` is not a number.
        """

        def pick(key: str, default: Any) -> Any:
            # JSON ``null`` means "not set": use the default instead of
            # letting str(None) / float(None) leak through.
            value = payload.get(key)
            return default if value is None else value

        tool_names = payload.get("tool_names") or payload.get("tools") or {}
        if not isinstance(tool_names, dict):
            raise ValueError("MCP tool_names 必须是 JSON object")
        args = payload.get("args") or []
        if not isinstance(args, list):
            raise ValueError("MCP args 必须是数组")
        env = payload.get("env")
        if env is not None and not isinstance(env, dict):
            raise ValueError("MCP env 必须是 JSON object")
        raw_timeout = pick("timeout_seconds", 60)
        try:
            timeout_seconds = float(raw_timeout)
        except (TypeError, ValueError) as exc:
            # Report every malformed config value as ValueError, like the
            # other validations above.
            raise ValueError("MCP timeout_seconds 必须是数字") from exc
        return cls(
            server_name=str(pick("server_name", "pam-node")),
            transport=str(pick("transport", "stdio")),
            command=str(pick("command", "")),
            args=[str(item) for item in args],
            # An empty env object is treated the same as "not provided".
            env={str(key): str(value) for key, value in env.items()} if env else None,
            cwd=str(pick("cwd", "")),
            timeout_seconds=timeout_seconds,
            tool_names={str(key): str(value) for key, value in tool_names.items()},
        )
|
||
|
||
|
||
def load_mcp_client_config(path: str | Path) -> McpClientConfig:
    """Read an MCP client configuration from a JSON file.

    The file is decoded as ``utf-8-sig``: it behaves exactly like UTF-8 for
    plain files, but also strips a leading byte-order mark. Without that,
    ``json.loads`` rejects BOM-prefixed files such as those saved by some
    Windows editors.

    Args:
        path: Path to the JSON configuration file.

    Returns:
        The parsed ``McpClientConfig``.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the content is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass), the top level is not a JSON object,
            or ``McpClientConfig.from_mapping`` rejects it.
    """
    text = Path(path).read_text(encoding="utf-8-sig")
    payload = json.loads(text)
    if not isinstance(payload, dict):
        raise ValueError("MCP client 配置必须是 JSON object")
    return McpClientConfig.from_mapping(payload)
|
||
|
||
|
||
class FunctionMcpToolClient:
    """Expose a plain Python callable through the ``call_tool`` interface.

    The wrapped callable is invoked as ``caller(tool_name, arguments)`` and
    whatever it returns is handed back unchanged.
    """

    def __init__(self, caller: Callable[[str, dict[str, Any]], Any]) -> None:
        """Remember the callable that performs the actual tool invocation."""
        self.caller = caller

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Forward the call to the wrapped callable and return its raw result."""
        invoke = self.caller
        raw_result = invoke(tool_name, arguments)
        return raw_result
|
||
|
||
|
||
class SessionMcpToolClient:
    """Adapt a session object that exposes a synchronous ``call_tool``.

    Results are passed through ``normalize_mcp_sdk_result``, which accepts
    these common shapes:

    - raw dict/list/string
    - objects carrying ``structuredContent``
    - objects carrying ``content`` whose text items may hold JSON

    Only synchronous sessions are supported. If the session's ``call_tool``
    returns an awaitable (e.g. it is an ``async def``), the call is rejected
    with ``TypeError`` instead of leaking an un-awaited coroutine to callers.
    """

    def __init__(self, session: Any) -> None:
        """Validate and store the session.

        Raises:
            TypeError: If ``session`` has no callable ``call_tool`` attribute.
        """
        if not callable(getattr(session, "call_tool", None)):
            raise TypeError("MCP session 必须暴露 call_tool 方法")
        self.session = session

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call the session's ``call_tool`` and normalize the returned value.

        Raises:
            TypeError: If the session returned an awaitable; this synchronous
                adapter cannot drive an async session.
        """
        import inspect

        result = self.session.call_tool(tool_name, arguments)
        if inspect.isawaitable(result):
            # Close the coroutine so Python does not emit a "never awaited"
            # RuntimeWarning, then fail loudly instead of returning it.
            if inspect.iscoroutine(result):
                result.close()
            raise TypeError(
                "MCP session.call_tool 返回了 awaitable;同步适配器不支持异步 session,"
                "请改用 StdioMcpToolClient 或在事件循环中直接调用"
            )
        return normalize_mcp_sdk_result(result)
|
||
|
||
|
||
class StdioMcpToolClient:
    """Launch a stdio MCP server through the MCP Python SDK and call a tool.

    Each ``call_tool`` invocation opens a fresh stdio session, performs the
    ``initialize`` handshake, makes one tool call and then closes the session.
    """

    def __init__(
        self,
        *,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        cwd: str | None = None,
        timeout_seconds: float = 60,
    ) -> None:
        """Store the stdio server launch parameters.

        Args:
            command: Executable that starts the MCP server; must be non-empty.
            args: Command-line arguments; copied into a new list.
            env: Environment handed to ``StdioServerParameters``; copied so
                later mutation of the caller's dict has no effect. ``None``
                is passed through unchanged.
            cwd: Working directory; an empty string is normalized to ``None``.
            timeout_seconds: Read timeout in seconds, applied to every request
                of the session (handshake and tool call).

        Raises:
            ValueError: If ``command`` is empty.
        """
        if not command:
            raise ValueError("stdio MCP 配置必须提供 command")
        self.command = command
        self.args = list(args or [])
        self.env = dict(env) if env is not None else None
        self.cwd = cwd or None
        self.timeout_seconds = timeout_seconds

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Open one MCP stdio session, call ``tool_name``, then close it.

        This uses ``anyio.run``, which starts its own event loop, so it cannot
        be called from code that is already running inside an event loop.

        Returns:
            The tool result after ``normalize_mcp_sdk_result``.

        Raises:
            RuntimeError: If the MCP Python SDK is not installed.
        """
        try:
            import anyio
            from mcp import ClientSession
            from mcp.client.stdio import StdioServerParameters, stdio_client
        except ImportError as exc:  # pragma: no cover - depends on installed extras
            raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc

        read_timeout = timedelta(seconds=self.timeout_seconds)

        async def call_once() -> Any:
            server = StdioServerParameters(
                command=self.command,
                args=self.args,
                env=self.env,
                cwd=self.cwd,
            )
            async with stdio_client(server) as (read_stream, write_stream):
                # The session-level read timeout also bounds initialize(),
                # so an unresponsive server cannot hang the handshake forever.
                async with ClientSession(
                    read_stream,
                    write_stream,
                    read_timeout_seconds=read_timeout,
                ) as session:
                    await session.initialize()
                    result = await session.call_tool(
                        tool_name,
                        arguments,
                        read_timeout_seconds=read_timeout,
                    )
                    return normalize_mcp_sdk_result(result)

        return anyio.run(call_once)
|
||
|
||
|
||
def normalize_mcp_sdk_result(result: Any) -> Any:
    """Reduce common MCP SDK result shapes to plain Python data.

    Resolution order:

    1. A non-``None`` ``structuredContent`` attribute is returned as-is.
    2. Otherwise, if the object has a ``content`` attribute, the ``text`` of
       every item whose ``text`` is not ``None`` is joined with newlines. The
       joined string is parsed as JSON when possible, else returned verbatim.
    3. Anything else, including a ``content`` with no text items, is
       returned unchanged.

    NOTE(review): any error flag the SDK result may carry (e.g. ``isError``)
    is not inspected here; confirm callers detect failed tool calls.
    """
    structured = getattr(result, "structuredContent", None)
    if structured is not None:
        return structured

    if not hasattr(result, "content"):
        return result

    texts = [
        text
        for text in (getattr(item, "text", None) for item in (result.content or []))
        if text is not None
    ]
    if not texts:
        return result

    joined = "\n".join(texts)
    try:
        return json.loads(joined)
    except json.JSONDecodeError:
        # Plain-text tool output: hand back the raw string.
        return joined
|