- 新增统一日志工具,支持日志文件路径和级别配置 - 记录 CLI/chat、Agent、LLM、action、MCP、LangGraph、checkpoint 等关键流程 - 对日志中的 token、secret、api_key、Authorization 等敏感信息做脱敏 - chat 新增 llm test 命令,用于验证当前 LLM client 是否正常加载 - 同步 README、打包文档和 run.sh 帮助说明 - 补充日志脱敏和 llm test 相关测试
514 lines
20 KiB
Python
514 lines
20 KiB
Python
"""MCP client 适配器。
|
||
|
||
Agent 只依赖同步的 `call_tool(name, arguments)` 接口。本模块把普通
|
||
callable 或 SDK session 适配成这个接口,避免业务代码绑定具体 MCP SDK。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
import urllib.parse
|
||
import urllib.request
|
||
from datetime import timedelta
|
||
from collections.abc import Callable
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from .logging_utils import json_for_log
|
||
|
||
# Module-level logger named after this module; every adapter below logs through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass(frozen=True)
class McpAuthConfig:
    """Token-endpoint settings used to authenticate against an MCP server.

    Every field is a plain string (or a string-to-string mapping) so the
    configuration can be loaded directly from a JSON document.
    """

    # URL of the token endpoint that issues the token.
    token_url: str = ""
    client_id: str = ""
    client_secret: str = ""
    grant_type: str = "client_credentials"
    # Name and value prefix of the HTTP header that carries the token.
    header_name: str = "Authorization"
    header_prefix: str = "Bearer"
    # Keys read from the token endpoint's JSON response.
    token_field: str = "access_token"
    expires_in_field: str = "expires_in"
    # Additional form fields sent with the token request.
    extra_form: dict[str, str] = field(default_factory=dict)

    @classmethod
    def from_mapping(cls, payload: dict[str, Any] | None) -> "McpAuthConfig | None":
        """Build an auth configuration from a JSON ``auth`` mapping.

        Args:
            payload: The decoded ``auth`` object, or a falsy value when no
                authentication is configured.

        Returns:
            A populated ``McpAuthConfig``, or ``None`` when ``payload`` is
            falsy (``None``, ``{}``, ...).

        Raises:
            ValueError: If ``payload`` or its ``extra_form`` entry is not a
                mapping.
        """
        if not payload:
            return None
        if not isinstance(payload, dict):
            raise ValueError("MCP auth 必须是 JSON object")

        def text(key: str, default: str = "") -> str:
            # JSON ``null`` must behave like a missing key. A plain
            # ``str(None)`` would yield the literal "None", which defeats the
            # base_url fallback below and any later emptiness validation.
            # An explicit "" is kept as-is (e.g. an empty header_prefix).
            value = payload.get(key)
            return default if value is None else str(value)

        token_url = text("token_url")
        base_url = text("base_url")
        if not token_url and base_url:
            # Derive the token endpoint from the base URL when it is not given.
            token_url = base_url.rstrip("/") + "/oauth/token"

        extra_form = payload.get("extra_form") or {}
        if not isinstance(extra_form, dict):
            raise ValueError("MCP auth.extra_form 必须是 JSON object")

        return cls(
            token_url=token_url,
            client_id=text("client_id"),
            client_secret=text("client_secret"),
            grant_type=text("grant_type", "client_credentials"),
            header_name=text("header_name", "Authorization"),
            header_prefix=text("header_prefix", "Bearer"),
            token_field=text("token_field", "access_token"),
            expires_in_field=text("expires_in_field", "expires_in"),
            extra_form={str(key): str(value) for key, value in extra_form.items()},
        )
|
||
|
||
|
||
@dataclass(frozen=True)
class McpClientConfig:
    """Connection settings for an MCP server, loadable from a JSON mapping."""

    server_name: str = "pam-node"
    transport: str = "streamable_http"
    # Used by the HTTP-based transports.
    server_url: str = ""
    # Used by the stdio transport: the server command and how to launch it.
    command: str = ""
    args: list[str] = field(default_factory=list)
    env: dict[str, str] | None = None
    cwd: str = ""
    # Static HTTP headers plus optional token-based authentication.
    headers: dict[str, str] = field(default_factory=dict)
    auth: McpAuthConfig | None = None
    timeout_seconds: float = 60
    sse_read_timeout_seconds: float = 300
    # String-to-string mapping read from tool_names / action_tools / tools.
    tool_names: dict[str, str] = field(default_factory=dict)

    @classmethod
    def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig":
        """Build a client configuration from a decoded JSON mapping.

        ``tool_names`` may also be supplied as ``action_tools`` or ``tools``,
        and ``server_url`` as ``url``. When ``transport`` is absent it is
        inferred: ``stdio`` if a command is given, else ``streamable_http``.

        Args:
            payload: The decoded JSON object.

        Returns:
            The populated ``McpClientConfig``.

        Raises:
            ValueError: If ``tool_names``, ``args``, ``env`` or ``headers``
                has the wrong JSON type, or a timeout is not numeric.
        """

        def text(key: str, default: str = "") -> str:
            # JSON ``null`` must behave like a missing key; ``str(None)``
            # would produce the literal "None" (e.g. a bogus command that
            # silently switches the inferred transport to stdio).
            value = payload.get(key)
            return default if value is None else str(value)

        def number(key: str, default: float) -> float:
            # Same null handling for numeric settings; float(None) would raise.
            value = payload.get(key)
            return float(default if value is None else value)

        tool_names = payload.get("tool_names") or payload.get("action_tools") or payload.get("tools") or {}
        if not isinstance(tool_names, dict):
            raise ValueError("MCP tool_names 必须是 JSON object")

        args = payload.get("args") or []
        if not isinstance(args, list):
            raise ValueError("MCP args 必须是数组")

        env = payload.get("env")
        if env is not None and not isinstance(env, dict):
            raise ValueError("MCP env 必须是 JSON object")

        headers = payload.get("headers") or {}
        if not isinstance(headers, dict):
            raise ValueError("MCP headers 必须是 JSON object")

        server_url = str(payload.get("server_url") or payload.get("url") or "")
        command = text("command")
        transport = str(payload.get("transport") or ("stdio" if command else "streamable_http"))

        return cls(
            server_name=text("server_name", "pam-node"),
            transport=transport,
            server_url=server_url,
            command=command,
            args=[str(item) for item in args],
            # An empty or missing env stays None rather than becoming {}.
            env={str(key): str(value) for key, value in env.items()} if env else None,
            cwd=text("cwd"),
            headers={str(key): str(value) for key, value in headers.items()},
            auth=McpAuthConfig.from_mapping(payload.get("auth")),
            timeout_seconds=number("timeout_seconds", 60),
            sse_read_timeout_seconds=number("sse_read_timeout_seconds", 300),
            tool_names={str(key): str(value) for key, value in tool_names.items()},
        )
|
||
|
||
|
||
def load_mcp_client_config(path: str | Path) -> McpClientConfig:
    """Read an MCP client configuration from a UTF-8 JSON file.

    Args:
        path: Location of the JSON configuration file.

    Returns:
        The configuration built by ``McpClientConfig.from_mapping``.

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    logger.info("读取 MCP client 配置 path=%s", path)
    raw_text = Path(path).read_text(encoding="utf-8")
    document = json.loads(raw_text)
    if isinstance(document, dict):
        loaded = McpClientConfig.from_mapping(document)
    else:
        raise ValueError("MCP client 配置必须是 JSON object")

    # Summarise the loaded configuration; only the presence of auth is logged.
    summary = (
        path,
        loaded.transport,
        loaded.server_url,
        loaded.command,
        loaded.auth is not None,
        json_for_log(loaded.tool_names),
    )
    logger.info(
        "MCP client 配置读取完成 path=%s transport=%s server_url=%s command=%s has_auth=%s tool_names=%s",
        *summary,
    )
    return loaded
|
||
|
||
|
||
class FunctionMcpToolClient:
    """Expose a plain Python callable through the ``call_tool`` interface."""

    def __init__(self, caller: Callable[[str, dict[str, Any]], Any]) -> None:
        """Remember the callable that actually performs tool invocations."""
        self.caller = caller

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Log the invocation, then delegate to the wrapped callable.

        The callable's return value is passed through unchanged.
        """
        printable_arguments = json_for_log(arguments)
        logger.info("Function MCP tool 调用 tool=%s arguments=%s", tool_name, printable_arguments)
        outcome = self.caller(tool_name, arguments)
        return outcome
|
||
|
||
|
||
class SessionMcpToolClient:
    """Adapt an MCP SDK session object that exposes ``call_tool``.

    Results are passed through ``normalize_mcp_sdk_result``, so the session
    may return any of the commonly seen shapes:

    - a raw dict / list / string
    - an object carrying ``structuredContent``
    - an object carrying ``content`` whose text items may hold JSON
    """

    def __init__(self, session: Any) -> None:
        """Validate and store the SDK session.

        Raises:
            TypeError: If ``session`` has no ``call_tool`` attribute.
        """
        exposes_call_tool = hasattr(session, "call_tool")
        if not exposes_call_tool:
            raise TypeError("MCP session 必须暴露 call_tool 方法")
        self.session = session

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Invoke the tool on the session and return the normalized result."""
        logger.info(
            "Session MCP tool 调用开始 tool=%s arguments=%s",
            tool_name,
            json_for_log(arguments),
        )
        raw_result = self.session.call_tool(tool_name, arguments)
        cleaned = normalize_mcp_sdk_result(raw_result)
        logger.info(
            "Session MCP tool 调用完成 tool=%s result=%s",
            tool_name,
            json_for_log(cleaned, max_text_len=1600),
        )
        return cleaned

    def list_tools(self) -> list[str]:
        """Ask the session for its tools and return their names."""
        logger.info("Session MCP list_tools 开始")
        listing = self.session.list_tools()
        names = normalize_mcp_tool_list(listing)
        logger.info("Session MCP list_tools 完成 tools=%s", names)
        return names
|
||
|
||
|
||
class StdioMcpToolClient:
    """Call tools on a stdio MCP server through the MCP Python SDK.

    Each public call opens a new stdio session, performs one operation and
    closes the session again.
    """

    def __init__(
        self,
        *,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        cwd: str | None = None,
        timeout_seconds: float = 60,
    ) -> None:
        """Store the parameters used to launch the stdio server.

        Raises:
            ValueError: If ``command`` is empty.
        """
        if not command:
            raise ValueError("stdio MCP 配置必须提供 command")
        self.command = command
        self.args = list(args or [])
        self.env = env
        # An empty cwd string is normalized to None.
        self.cwd = cwd or None
        self.timeout_seconds = timeout_seconds
        logger.info(
            "stdio MCP client 初始化 command=%s args=%s cwd=%s env_keys=%s timeout=%s",
            self.command,
            self.args,
            self.cwd or "",
            # Only the variable names are logged, never their values.
            sorted((self.env or {}).keys()),
            self.timeout_seconds,
        )

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Open a stdio session, call one tool and return the normalized result.

        Raises:
            RuntimeError: If the MCP Python SDK is not installed.
        """
        started_at = time.perf_counter()
        logger.info("stdio MCP tool 调用开始 tool=%s arguments=%s", tool_name, json_for_log(arguments))
        # Resolved before the try block so a missing SDK is reported as a
        # plain RuntimeError and not logged as a failed tool call.
        sdk = self._load_sdk()

        async def invoke(session: Any) -> Any:
            raw = await session.call_tool(
                tool_name,
                arguments,
                read_timeout_seconds=timedelta(seconds=self.timeout_seconds),
            )
            return normalize_mcp_sdk_result(raw)

        try:
            result = self._run_session(sdk, invoke)
        except Exception:
            logger.exception(
                "stdio MCP tool 调用失败 tool=%s duration_ms=%s",
                tool_name,
                int((time.perf_counter() - started_at) * 1000),
            )
            raise
        logger.info(
            "stdio MCP tool 调用完成 tool=%s duration_ms=%s result=%s",
            tool_name,
            int((time.perf_counter() - started_at) * 1000),
            json_for_log(result, max_text_len=1600),
        )
        return result

    def list_tools(self) -> list[str]:
        """Open a stdio session and return the names of the server's tools.

        Raises:
            RuntimeError: If the MCP Python SDK is not installed.
        """
        started_at = time.perf_counter()
        logger.info("stdio MCP list_tools 开始")
        sdk = self._load_sdk()

        async def invoke(session: Any) -> list[str]:
            listing = await session.list_tools()
            return normalize_mcp_tool_list(listing)

        try:
            tools = self._run_session(sdk, invoke)
        except Exception:
            logger.exception(
                "stdio MCP list_tools 失败 duration_ms=%s",
                int((time.perf_counter() - started_at) * 1000),
            )
            raise
        logger.info(
            "stdio MCP list_tools 完成 duration_ms=%s tools=%s",
            int((time.perf_counter() - started_at) * 1000),
            tools,
        )
        return tools

    @staticmethod
    def _load_sdk() -> tuple[Any, Any, Any, Any]:
        """Import the optional MCP SDK pieces needed for a stdio session."""
        try:
            import anyio
            from mcp import ClientSession
            from mcp.client.stdio import StdioServerParameters, stdio_client
        except ImportError as exc:  # pragma: no cover - depends on installed extras
            raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc
        return anyio, ClientSession, StdioServerParameters, stdio_client

    def _run_session(self, sdk: tuple[Any, Any, Any, Any], operation: Callable[[Any], Any]) -> Any:
        """Run ``operation`` against a freshly initialized stdio session."""
        anyio, client_session_cls, server_parameters_cls, stdio_client = sdk

        async def run_once() -> Any:
            server = server_parameters_cls(
                command=self.command,
                args=self.args,
                env=self.env,
                cwd=self.cwd,
            )
            async with stdio_client(server) as (read_stream, write_stream):
                async with client_session_cls(read_stream, write_stream) as session:
                    await session.initialize()
                    # The operation finishes (including normalization) while
                    # the session is still open.
                    return await operation(session)

        return anyio.run(run_once)
|
||
|
||
|
||
class OAuthTokenProvider:
    """Fetch and cache an MCP auth token via a form-encoded token request.

    The original note says this follows the same ``client_credentials``
    approach as "HOME" (NOTE(review): presumably another project component;
    confirm).
    """

    def __init__(self, config: McpAuthConfig, *, timeout_seconds: float = 30) -> None:
        """Validate the auth configuration and start with an empty cache.

        Raises:
            ValueError: If the token URL, client id or client secret is empty.
        """
        if not config.token_url:
            raise ValueError("MCP auth 必须提供 token_url 或 auth.base_url")
        if not config.client_id or not config.client_secret:
            raise ValueError("MCP auth 必须提供独立的 client_id 和 client_secret")
        self.config = config
        self.timeout_seconds = timeout_seconds
        self._token = ""
        self._expires_at = 0.0
        logger.info(
            "MCP OAuth token provider 初始化 token_url=%s client_id=%s timeout=%s",
            self.config.token_url,
            self.config.client_id,
            self.timeout_seconds,
        )

    def authorization_headers(self) -> dict[str, str]:
        """Return the single request header that carries the current token."""
        token = self.get_token()
        prefix = self.config.header_prefix.strip()
        # An empty prefix sends the bare token as the header value.
        value = f"{prefix} {token}" if prefix else token
        return {self.config.header_name: value}

    def get_token(self) -> str:
        """Return a usable token, reusing the cached one until it expires.

        Raises:
            RuntimeError: If the token response is not a JSON object or
                contains no token under the configured field.
        """
        now = time.time()
        if self._token and now < self._expires_at:
            logger.info("MCP auth token 使用缓存 expires_in_sec=%s", int(self._expires_at - now))
            return self._token

        logger.info("MCP auth token 开始刷新 token_url=%s client_id=%s", self.config.token_url, self.config.client_id)
        payload = {
            "grant_type": self.config.grant_type,
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            # Extra form fields come last, so they can override the above.
            **self.config.extra_form,
        }
        data = urllib.parse.urlencode(payload).encode("utf-8")
        request = urllib.request.Request(
            self.config.token_url,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            method="POST",
        )
        with urllib.request.urlopen(request, timeout=self.timeout_seconds) as response:
            raw = response.read().decode("utf-8")

        token, expires_in = self._parse_token_response(raw, self.config)
        self._token = token
        # Expire 60 seconds early, but keep the token for at least 1 second.
        # The lifetime is measured from before the request was sent.
        self._expires_at = now + max(expires_in - 60, 1)
        logger.info("MCP auth token 刷新完成 expires_in=%s cached_until=%s", expires_in, int(self._expires_at))
        return token

    @staticmethod
    def _parse_token_response(raw: str, config: Any) -> tuple[str, float]:
        """Extract ``(token, expires_in)`` from the token endpoint's JSON body."""
        result = json.loads(raw)
        if not isinstance(result, dict):
            # Fail with a clear error instead of an AttributeError on .get().
            raise RuntimeError("MCP auth token 响应必须是 JSON object")
        raw_token = result.get(config.token_field)
        # A JSON null must count as "no token"; str(None) would otherwise be
        # cached and sent as the literal token "None".
        token = "" if raw_token is None else str(raw_token)
        if not token:
            raise RuntimeError(f"MCP auth token 响应缺少 {config.token_field}")
        # A missing or unparsable lifetime falls back to one hour.
        return token, _safe_float(result.get(config.expires_in_field), 3600)
|
||
|
||
|
||
class HttpMcpToolClient:
    """Call tools on an MCP server reached through an HTTP or SSE URL."""

    def __init__(
        self,
        *,
        url: str,
        transport: str = "streamable_http",
        headers: dict[str, str] | None = None,
        auth_provider: OAuthTokenProvider | None = None,
        timeout_seconds: float = 60,
        sse_read_timeout_seconds: float = 300,
    ) -> None:
        """Store the connection parameters for the HTTP/SSE MCP server.

        Raises:
            ValueError: If ``url`` is empty or ``transport`` is neither
                ``streamable_http`` nor ``sse``.
        """
        if not url:
            raise ValueError("HTTP/SSE MCP 配置必须提供 server_url")
        supported_transports = ("streamable_http", "sse")
        if transport not in supported_transports:
            raise ValueError(f"不支持的 HTTP MCP transport: {transport}")

        self.url = url
        self.transport = transport
        # Copy the mapping so later changes by the caller do not leak in.
        self.headers = dict(headers or {})
        self.auth_provider = auth_provider
        self.timeout_seconds = timeout_seconds
        self.sse_read_timeout_seconds = sse_read_timeout_seconds
        logger.info(
            "HTTP MCP client 初始化 url=%s transport=%s has_auth=%s headers=%s timeout=%s sse_read_timeout=%s",
            self.url,
            self.transport,
            self.auth_provider is not None,
            json_for_log(self.headers),
            self.timeout_seconds,
            self.sse_read_timeout_seconds,
        )

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Connect to the server, call one tool and return the normalized result."""

        def invoke(session: Any) -> Any:
            return session.call_tool(tool_name, arguments)

        return self._run_session(
            invoke,
            operation_name=f"call_tool:{tool_name}",
            arguments=arguments,
        )

    def list_tools(self) -> list[str]:
        """Connect to the server and return the names of the tools it exposes."""

        def invoke(session: Any) -> Any:
            return session.list_tools()

        return self._run_session(invoke, normalize_tools=True, operation_name="list_tools")

    def _build_headers(self) -> dict[str, str]:
        """Merge the static headers with the auth provider's header, if any."""
        merged = dict(self.headers)
        provider = self.auth_provider
        if provider is not None:
            # Auth headers are applied last, so they win on a name clash.
            merged.update(provider.authorization_headers())
        return merged

    def _run_session(
        self,
        operation: Callable[[Any], Any],
        *,
        normalize_tools: bool = False,
        operation_name: str = "operation",
        arguments: dict[str, Any] | None = None,
    ) -> Any:
        """Open one HTTP/SSE MCP session, run ``operation`` and close it.

        Args:
            operation: Called with the initialized session; its result is awaited.
            normalize_tools: Normalize the result as a tool list instead of
                as a tool-call result.
            operation_name: Label used in log messages.
            arguments: Tool arguments, used for logging only.

        Raises:
            RuntimeError: If the MCP Python SDK is not installed.
        """
        started_at = time.perf_counter()
        logger.info(
            "HTTP MCP session 开始 operation=%s url=%s transport=%s arguments=%s",
            operation_name,
            self.url,
            self.transport,
            json_for_log(arguments or {}),
        )
        try:
            import anyio
            from mcp import ClientSession
            from mcp.client.sse import sse_client
            from mcp.client.streamable_http import streamablehttp_client
        except ImportError as exc:  # pragma: no cover - depends on installed extras
            raise RuntimeError("未安装 MCP Python SDK,请安装项目的 mcp 可选依赖") from exc

        async def call_once() -> Any:
            request_headers = self._build_headers()
            if self.transport == "sse":
                # The SSE client is given its timeouts as plain seconds.
                connection = sse_client(
                    self.url,
                    headers=request_headers,
                    timeout=self.timeout_seconds,
                    sse_read_timeout=self.sse_read_timeout_seconds,
                )
            else:
                # The streamable HTTP client is given timedelta timeouts.
                connection = streamablehttp_client(
                    self.url,
                    headers=request_headers,
                    timeout=timedelta(seconds=self.timeout_seconds),
                    sse_read_timeout=timedelta(seconds=self.sse_read_timeout_seconds),
                )
            async with connection as streams:
                # Only the first two yielded items (read and write stream) are used.
                async with ClientSession(streams[0], streams[1]) as session:
                    await session.initialize()
                    raw = await operation(session)
                    if normalize_tools:
                        return normalize_mcp_tool_list(raw)
                    return normalize_mcp_sdk_result(raw)

        try:
            result = anyio.run(call_once)
        except Exception:
            elapsed_ms = int((time.perf_counter() - started_at) * 1000)
            logger.exception(
                "HTTP MCP session 失败 operation=%s url=%s transport=%s duration_ms=%s",
                operation_name,
                self.url,
                self.transport,
                elapsed_ms,
            )
            raise

        elapsed_ms = int((time.perf_counter() - started_at) * 1000)
        logger.info(
            "HTTP MCP session 完成 operation=%s duration_ms=%s result=%s",
            operation_name,
            elapsed_ms,
            json_for_log(result, max_text_len=1600),
        )
        return result
|
||
|
||
|
||
def normalize_mcp_sdk_result(result: Any) -> Any:
    """Reduce common MCP SDK result objects to a dict, list or string.

    Resolution order:

    1. A non-``None`` ``structuredContent`` attribute is returned as-is.
    2. Otherwise the ``text`` attributes of the ``content`` items are joined
       with newlines; the joined text is returned parsed as JSON when it is
       valid JSON, else as the raw string.
    3. Anything else is returned unchanged.
    """
    structured = getattr(result, "structuredContent", None)
    if structured is not None:
        return structured

    items = getattr(result, "content", None) or []
    # Items without a text attribute are skipped.
    texts: list[str] = [
        text
        for text in (getattr(item, "text", None) for item in items)
        if text is not None
    ]
    if not texts:
        return result

    combined = "\n".join(texts)
    try:
        parsed = json.loads(combined)
    except json.JSONDecodeError:
        return combined
    return parsed
|
||
|
||
|
||
def normalize_mcp_tool_list(result: Any) -> list[str]:
    """Turn an MCP ``list_tools`` response into a list of tool names.

    The tools are read from ``result.tools`` or, for a dict, from
    ``result["tools"]``. Each entry may be a string, a dict with a truthy
    ``name`` key, or an object with a truthy ``name`` attribute; entries
    without a usable name are dropped.
    """

    def name_of(entry: Any) -> str | None:
        # Strings are kept verbatim, even when empty.
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict) and entry.get("name"):
            return str(entry["name"])
        attribute = getattr(entry, "name", None)
        return str(attribute) if attribute else None

    entries = getattr(result, "tools", None)
    if entries is None and isinstance(result, dict):
        entries = result.get("tools")

    return [name for name in map(name_of, entries or []) if name is not None]
|
||
|
||
|
||
def _safe_float(value: Any, default: float) -> float:
|
||
"""把值安全转换为 float。"""
|
||
try:
|
||
return float(value)
|
||
except (TypeError, ValueError):
|
||
return default
|