"""MCP client 适配器。 Agent 只依赖同步的 `call_tool(name, arguments)` 接口。本模块把普通 callable 或 SDK session 适配成这个接口,避免业务代码绑定具体 MCP SDK。 """ from __future__ import annotations import json from datetime import timedelta from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from typing import Any @dataclass(frozen=True) class McpClientConfig: """真实 MCP session 建立后需要传给 runner 的配置。""" server_name: str = "pam-node" transport: str = "stdio" command: str = "" args: list[str] = field(default_factory=list) env: dict[str, str] | None = None cwd: str = "" timeout_seconds: float = 60 tool_names: dict[str, str] = field(default_factory=dict) @classmethod def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig": """从 JSON 字典构造 MCP client 配置。""" tool_names = payload.get("tool_names") or payload.get("tools") or {} if not isinstance(tool_names, dict): raise ValueError("MCP tool_names 必须是 JSON object") args = payload.get("args") or [] if not isinstance(args, list): raise ValueError("MCP args 必须是数组") env = payload.get("env") if env is not None and not isinstance(env, dict): raise ValueError("MCP env 必须是 JSON object") return cls( server_name=str(payload.get("server_name", "pam-node")), transport=str(payload.get("transport", "stdio")), command=str(payload.get("command", "")), args=[str(item) for item in args], env={str(key): str(value) for key, value in env.items()} if env else None, cwd=str(payload.get("cwd", "")), timeout_seconds=float(payload.get("timeout_seconds", 60)), tool_names={str(key): str(value) for key, value in tool_names.items()}, ) def load_mcp_client_config(path: str | Path) -> McpClientConfig: """读取 MCP client JSON 配置文件。""" payload = json.loads(Path(path).read_text(encoding="utf-8")) if not isinstance(payload, dict): raise ValueError("MCP client 配置必须是 JSON object") return McpClientConfig.from_mapping(payload) class FunctionMcpToolClient: """把普通 Python callable 包装为 MCP tool client。""" def __init__(self, caller: Callable[[str, dict[str, Any]], Any]) -> None: """保存实际执行工具调用的函数。""" self.caller = caller def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """调用底层函数并返回原始结果。""" return self.caller(tool_name, arguments) class SessionMcpToolClient: """适配暴露 `call_tool` 的 MCP SDK session。 适配器接受常见返回形态: - 原始 dict/list/string - 带有 `structuredContent` 的对象 - 带有 `content` 的对象,其中 text 内容可能是 JSON """ def __init__(self, session: Any) -> None: """校验并保存 MCP SDK session。""" if not hasattr(session, "call_tool"): raise TypeError("MCP session 必须暴露 call_tool 方法") self.session = session def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """调用 SDK session,并把 SDK 返回值归一化。""" result = self.session.call_tool(tool_name, arguments) return normalize_mcp_sdk_result(result) class StdioMcpToolClient: """通过 MCP Python SDK 启动 stdio server 并调用 tool。""" def __init__( self, *, command: str, args: list[str] | None = None, env: dict[str, str] | None = None, cwd: str | None = None, timeout_seconds: float = 60, ) -> None: """保存 stdio server 启动参数。""" if not command: raise ValueError("stdio MCP 配置必须提供 command") self.command = command self.args = list(args or []) self.env = env self.cwd = cwd or None self.timeout_seconds = timeout_seconds def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """创建一次 MCP stdio session,调用 tool 后关闭 session。""" try: import anyio from mcp import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client except ImportError as exc: # pragma: no cover - 依赖安装状态 raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc async def call_once() -> Any: server = StdioServerParameters( command=self.command, args=self.args, env=self.env, cwd=self.cwd, ) async with stdio_client(server) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() result = await session.call_tool( tool_name, arguments, read_timeout_seconds=timedelta(seconds=self.timeout_seconds), ) return normalize_mcp_sdk_result(result) return anyio.run(call_once) def normalize_mcp_sdk_result(result: Any) -> Any: """把常见 MCP SDK 返回结构归一化成 dict/list/string。""" if hasattr(result, "structuredContent"): structured = getattr(result, "structuredContent") if structured is not None: return structured if hasattr(result, "content"): content = getattr(result, "content") text_parts: list[str] = [] for item in content or []: text = getattr(item, "text", None) if text is not None: text_parts.append(text) if text_parts: joined = "\n".join(text_parts) try: return json.loads(joined) except json.JSONDecodeError: return joined return result