"""MCP client 适配器。 Agent 只依赖同步的 `call_tool(name, arguments)` 接口。本模块把普通 callable 或 SDK session 适配成这个接口,避免业务代码绑定具体 MCP SDK。 """ from __future__ import annotations import json import logging import time import urllib.parse import urllib.request from datetime import timedelta from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from typing import Any from .logging_utils import json_for_log logger = logging.getLogger(__name__) @dataclass(frozen=True) class McpAuthConfig: """MCP server 鉴权 token 配置。""" token_url: str = "" client_id: str = "" client_secret: str = "" grant_type: str = "client_credentials" header_name: str = "Authorization" header_prefix: str = "Bearer" token_field: str = "access_token" expires_in_field: str = "expires_in" extra_form: dict[str, str] = field(default_factory=dict) @classmethod def from_mapping(cls, payload: dict[str, Any] | None) -> "McpAuthConfig | None": """从 JSON auth 字典构造 MCP 鉴权配置。""" if not payload: return None if not isinstance(payload, dict): raise ValueError("MCP auth 必须是 JSON object") token_url = str(payload.get("token_url", "")) base_url = str(payload.get("base_url", "")) if not token_url and base_url: token_url = base_url.rstrip("/") + "/oauth/token" extra_form = payload.get("extra_form") or {} if not isinstance(extra_form, dict): raise ValueError("MCP auth.extra_form 必须是 JSON object") return cls( token_url=token_url, client_id=str(payload.get("client_id", "")), client_secret=str(payload.get("client_secret", "")), grant_type=str(payload.get("grant_type", "client_credentials")), header_name=str(payload.get("header_name", "Authorization")), header_prefix=str(payload.get("header_prefix", "Bearer")), token_field=str(payload.get("token_field", "access_token")), expires_in_field=str(payload.get("expires_in_field", "expires_in")), extra_form={str(key): str(value) for key, value in extra_form.items()}, ) @dataclass(frozen=True) class McpClientConfig: """真实 MCP session 建立后需要传给 runner 的配置。""" server_name: str = "pam-node" transport: str = "streamable_http" server_url: str = "" command: str = "" args: list[str] = field(default_factory=list) env: dict[str, str] | None = None cwd: str = "" headers: dict[str, str] = field(default_factory=dict) auth: McpAuthConfig | None = None timeout_seconds: float = 60 sse_read_timeout_seconds: float = 300 tool_names: dict[str, str] = field(default_factory=dict) @classmethod def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig": """从 JSON 字典构造 MCP client 配置。""" tool_names = payload.get("tool_names") or payload.get("action_tools") or payload.get("tools") or {} if not isinstance(tool_names, dict): raise ValueError("MCP tool_names 必须是 JSON object") args = payload.get("args") or [] if not isinstance(args, list): raise ValueError("MCP args 必须是数组") env = payload.get("env") if env is not None and not isinstance(env, dict): raise ValueError("MCP env 必须是 JSON object") headers = payload.get("headers") or {} if not isinstance(headers, dict): raise ValueError("MCP headers 必须是 JSON object") server_url = str(payload.get("server_url") or payload.get("url") or "") command = str(payload.get("command", "")) transport = str(payload.get("transport") or ("stdio" if command else "streamable_http")) return cls( server_name=str(payload.get("server_name", "pam-node")), transport=transport, server_url=server_url, command=command, args=[str(item) for item in args], env={str(key): str(value) for key, value in env.items()} if env else None, cwd=str(payload.get("cwd", "")), headers={str(key): str(value) for key, value in headers.items()}, auth=McpAuthConfig.from_mapping(payload.get("auth")), timeout_seconds=float(payload.get("timeout_seconds", 60)), sse_read_timeout_seconds=float(payload.get("sse_read_timeout_seconds", 300)), tool_names={str(key): str(value) for key, value in tool_names.items()}, ) def load_mcp_client_config(path: str | Path) -> McpClientConfig: """读取 MCP client JSON 配置文件。""" logger.info("读取 MCP client 配置 path=%s", path) payload = json.loads(Path(path).read_text(encoding="utf-8")) if not isinstance(payload, dict): raise ValueError("MCP client 配置必须是 JSON object") config = McpClientConfig.from_mapping(payload) logger.info( "MCP client 配置读取完成 path=%s transport=%s server_url=%s command=%s has_auth=%s tool_names=%s", path, config.transport, config.server_url, config.command, config.auth is not None, json_for_log(config.tool_names), ) return config class FunctionMcpToolClient: """把普通 Python callable 包装为 MCP tool client。""" def __init__(self, caller: Callable[[str, dict[str, Any]], Any]) -> None: """保存实际执行工具调用的函数。""" self.caller = caller def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """调用底层函数并返回原始结果。""" logger.info("Function MCP tool 调用 tool=%s arguments=%s", tool_name, json_for_log(arguments)) return self.caller(tool_name, arguments) class SessionMcpToolClient: """适配暴露 `call_tool` 的 MCP SDK session。 适配器接受常见返回形态: - 原始 dict/list/string - 带有 `structuredContent` 的对象 - 带有 `content` 的对象,其中 text 内容可能是 JSON """ def __init__(self, session: Any) -> None: """校验并保存 MCP SDK session。""" if not hasattr(session, "call_tool"): raise TypeError("MCP session 必须暴露 call_tool 方法") self.session = session def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """调用 SDK session,并把 SDK 返回值归一化。""" logger.info("Session MCP tool 调用开始 tool=%s arguments=%s", tool_name, json_for_log(arguments)) result = self.session.call_tool(tool_name, arguments) normalized = normalize_mcp_sdk_result(result) logger.info("Session MCP tool 调用完成 tool=%s result=%s", tool_name, json_for_log(normalized, max_text_len=1600)) return normalized def list_tools(self) -> list[str]: """从 SDK session 获取 tool 名称列表。""" logger.info("Session MCP list_tools 开始") result = self.session.list_tools() tools = normalize_mcp_tool_list(result) logger.info("Session MCP list_tools 完成 tools=%s", tools) return tools class StdioMcpToolClient: """通过 MCP Python SDK 启动 stdio server 并调用 tool。""" def __init__( self, *, command: str, args: list[str] | None = None, env: dict[str, str] | None = None, cwd: str | None = None, timeout_seconds: float = 60, ) -> None: """保存 stdio server 启动参数。""" if not command: raise ValueError("stdio MCP 配置必须提供 command") self.command = command self.args = list(args or []) self.env = env self.cwd = cwd or None self.timeout_seconds = timeout_seconds logger.info( "stdio MCP client 初始化 command=%s args=%s cwd=%s env_keys=%s timeout=%s", self.command, self.args, self.cwd or "", sorted((self.env or {}).keys()), self.timeout_seconds, ) def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """创建一次 MCP stdio session,调用 tool 后关闭 session。""" started_at = time.perf_counter() logger.info("stdio MCP tool 调用开始 tool=%s arguments=%s", tool_name, json_for_log(arguments)) try: import anyio from mcp import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client except ImportError as exc: # pragma: no cover - 依赖安装状态 raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc async def call_once() -> Any: server = StdioServerParameters( command=self.command, args=self.args, env=self.env, cwd=self.cwd, ) async with stdio_client(server) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() result = await session.call_tool( tool_name, arguments, read_timeout_seconds=timedelta(seconds=self.timeout_seconds), ) return normalize_mcp_sdk_result(result) try: result = anyio.run(call_once) except Exception: logger.exception("stdio MCP tool 调用失败 tool=%s duration_ms=%s", tool_name, int((time.perf_counter() - started_at) * 1000)) raise logger.info( "stdio MCP tool 调用完成 tool=%s duration_ms=%s result=%s", tool_name, int((time.perf_counter() - started_at) * 1000), json_for_log(result, max_text_len=1600), ) return result def list_tools(self) -> list[str]: """创建一次 MCP stdio session,读取 server 暴露的 tool 列表。""" started_at = time.perf_counter() logger.info("stdio MCP list_tools 开始") try: import anyio from mcp import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client except ImportError as exc: # pragma: no cover - 依赖安装状态 raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc async def list_once() -> list[str]: server = StdioServerParameters( command=self.command, args=self.args, env=self.env, cwd=self.cwd, ) async with stdio_client(server) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() result = await session.list_tools() return normalize_mcp_tool_list(result) try: tools = anyio.run(list_once) except Exception: logger.exception("stdio MCP list_tools 失败 duration_ms=%s", int((time.perf_counter() - started_at) * 1000)) raise logger.info("stdio MCP list_tools 完成 duration_ms=%s tools=%s", int((time.perf_counter() - started_at) * 1000), tools) return tools class OAuthTokenProvider: """按 HOME 相同的 client_credentials 方式获取 MCP 鉴权 token。""" def __init__(self, config: McpAuthConfig, *, timeout_seconds: float = 30) -> None: """保存鉴权配置和 token 缓存。""" if not config.token_url: raise ValueError("MCP auth 必须提供 token_url 或 auth.base_url") if not config.client_id or not config.client_secret: raise ValueError("MCP auth 必须提供独立的 client_id 和 client_secret") self.config = config self.timeout_seconds = timeout_seconds self._token = "" self._expires_at = 0.0 logger.info( "MCP OAuth token provider 初始化 token_url=%s client_id=%s timeout=%s", self.config.token_url, self.config.client_id, self.timeout_seconds, ) def authorization_headers(self) -> dict[str, str]: """返回带 token 的请求头。""" token = self.get_token() prefix = self.config.header_prefix.strip() value = f"{prefix} {token}" if prefix else token return {self.config.header_name: value} def get_token(self) -> str: """获取可用 token,未过期时复用缓存。""" now = time.time() if self._token and now < self._expires_at: logger.info("MCP auth token 使用缓存 expires_in_sec=%s", int(self._expires_at - now)) return self._token logger.info("MCP auth token 开始刷新 token_url=%s client_id=%s", self.config.token_url, self.config.client_id) payload = { "grant_type": self.config.grant_type, "client_id": self.config.client_id, "client_secret": self.config.client_secret, **self.config.extra_form, } data = urllib.parse.urlencode(payload).encode("utf-8") request = urllib.request.Request( self.config.token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"}, method="POST", ) with urllib.request.urlopen(request, timeout=self.timeout_seconds) as response: raw = response.read().decode("utf-8") result = json.loads(raw) token = str(result.get(self.config.token_field, "")) if not token: raise RuntimeError("MCP auth token 响应缺少 access_token") expires_in = _safe_float(result.get(self.config.expires_in_field), 3600) self._token = token self._expires_at = now + max(expires_in - 60, 1) logger.info("MCP auth token 刷新完成 expires_in=%s cached_until=%s", expires_in, int(self._expires_at)) return token class HttpMcpToolClient: """通过 MCP HTTP/SSE server URL 调用 tool。""" def __init__( self, *, url: str, transport: str = "streamable_http", headers: dict[str, str] | None = None, auth_provider: OAuthTokenProvider | None = None, timeout_seconds: float = 60, sse_read_timeout_seconds: float = 300, ) -> None: """保存 HTTP/SSE MCP server 连接参数。""" if not url: raise ValueError("HTTP/SSE MCP 配置必须提供 server_url") if transport not in ("streamable_http", "sse"): raise ValueError(f"不支持的 HTTP MCP transport: {transport}") self.url = url self.transport = transport self.headers = dict(headers or {}) self.auth_provider = auth_provider self.timeout_seconds = timeout_seconds self.sse_read_timeout_seconds = sse_read_timeout_seconds logger.info( "HTTP MCP client 初始化 url=%s transport=%s has_auth=%s headers=%s timeout=%s sse_read_timeout=%s", self.url, self.transport, self.auth_provider is not None, json_for_log(self.headers), self.timeout_seconds, self.sse_read_timeout_seconds, ) def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """连接 MCP server,调用 tool 后关闭 session。""" return self._run_session(lambda session: session.call_tool(tool_name, arguments), operation_name=f"call_tool:{tool_name}", arguments=arguments) def list_tools(self) -> list[str]: """连接 MCP server,读取 server 暴露的 tool 名称。""" return self._run_session(lambda session: session.list_tools(), normalize_tools=True, operation_name="list_tools") def _build_headers(self) -> dict[str, str]: """合并静态 headers 和动态鉴权 token。""" headers = dict(self.headers) if self.auth_provider is not None: headers.update(self.auth_provider.authorization_headers()) return headers def _run_session( self, operation: Callable[[Any], Any], *, normalize_tools: bool = False, operation_name: str = "operation", arguments: dict[str, Any] | None = None, ) -> Any: """创建一次 HTTP/SSE MCP session 并执行指定操作。""" started_at = time.perf_counter() logger.info( "HTTP MCP session 开始 operation=%s url=%s transport=%s arguments=%s", operation_name, self.url, self.transport, json_for_log(arguments or {}), ) try: import anyio from mcp import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client except ImportError as exc: # pragma: no cover - 依赖安装状态 raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc async def call_once() -> Any: headers = self._build_headers() if self.transport == "sse": async with sse_client( self.url, headers=headers, timeout=self.timeout_seconds, sse_read_timeout=self.sse_read_timeout_seconds, ) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() result = await operation(session) return normalize_mcp_tool_list(result) if normalize_tools else normalize_mcp_sdk_result(result) async with streamablehttp_client( self.url, headers=headers, timeout=timedelta(seconds=self.timeout_seconds), sse_read_timeout=timedelta(seconds=self.sse_read_timeout_seconds), ) as streams: read_stream, write_stream = streams[0], streams[1] async with ClientSession(read_stream, write_stream) as session: await session.initialize() result = await operation(session) return normalize_mcp_tool_list(result) if normalize_tools else normalize_mcp_sdk_result(result) try: result = anyio.run(call_once) except Exception: logger.exception( "HTTP MCP session 失败 operation=%s url=%s transport=%s duration_ms=%s", operation_name, self.url, self.transport, int((time.perf_counter() - started_at) * 1000), ) raise logger.info( "HTTP MCP session 完成 operation=%s duration_ms=%s result=%s", operation_name, int((time.perf_counter() - started_at) * 1000), json_for_log(result, max_text_len=1600), ) return result def normalize_mcp_sdk_result(result: Any) -> Any: """把常见 MCP SDK 返回结构归一化成 dict/list/string。""" if hasattr(result, "structuredContent"): structured = getattr(result, "structuredContent") if structured is not None: return structured if hasattr(result, "content"): content = getattr(result, "content") text_parts: list[str] = [] for item in content or []: text = getattr(item, "text", None) if text is not None: text_parts.append(text) if text_parts: joined = "\n".join(text_parts) try: return json.loads(joined) except json.JSONDecodeError: return joined return result def normalize_mcp_tool_list(result: Any) -> list[str]: """把 MCP list_tools 返回值归一化为 tool name 列表。""" tools = getattr(result, "tools", None) if tools is None and isinstance(result, dict): tools = result.get("tools") names: list[str] = [] for item in tools or []: if isinstance(item, str): names.append(item) continue if isinstance(item, dict) and item.get("name"): names.append(str(item["name"])) continue name = getattr(item, "name", None) if name: names.append(str(name)) return names def _safe_float(value: Any, default: float) -> float: """把值安全转换为 float。""" try: return float(value) except (TypeError, ValueError): return default