dark 05ece1bffc feat: 标准化 LangGraph 运行链路并完善 MCP 接入
- 将 CLI/chat 部署执行切换为 action 级 LangGraph runtime
- 接入 LangGraph interrupt/checkpointer 处理人工确认与恢复
- 保留业务 checkpoint JSON 用于跨进程断点续跑
- 增加 MCP HTTP/SSE server_url 配置支持
- 增加 MCP 独立 OAuth token 鉴权,复用 HOME 的 client_credentials 方式
- 支持从 MCP server list_tools 自动发现 tools,action_tools 仅作为可选覆盖
- 更新 MCP 配置示例、README、打包说明和整体流程图
- 补充 MCP 配置、鉴权和 tool 自动发现测试
2026-06-02 10:44:42 +08:00

412 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""MCP client 适配器。
Agent 只依赖同步的 `call_tool(name, arguments)` 接口。本模块把普通
callable 或 SDK session 适配成这个接口,避免业务代码绑定具体 MCP SDK。
"""
from __future__ import annotations
import json
import time
import urllib.parse
import urllib.request
from datetime import timedelta
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class McpAuthConfig:
    """Token-based authentication settings for an MCP server.

    All values are stored as strings so the config can be built straight
    from a JSON document. ``extra_form`` holds additional string key/value
    pairs supplied under the ``extra_form`` key.
    """

    token_url: str = ""
    client_id: str = ""
    client_secret: str = ""
    grant_type: str = "client_credentials"
    header_name: str = "Authorization"
    header_prefix: str = "Bearer"
    token_field: str = "access_token"
    expires_in_field: str = "expires_in"
    extra_form: dict[str, str] = field(default_factory=dict)

    @classmethod
    def from_mapping(cls, payload: dict[str, Any] | None) -> "McpAuthConfig | None":
        """Build the auth config from the ``auth`` object of a JSON config.

        A key whose value is JSON ``null`` is treated exactly like a missing
        key, so it falls back to the field default instead of becoming the
        literal string ``"None"``.

        Args:
            payload: Parsed ``auth`` object, or ``None``/empty when no
                authentication is configured.

        Returns:
            The parsed config, or ``None`` when ``payload`` is falsy.

        Raises:
            ValueError: If ``payload`` or its ``extra_form`` entry is not a
                JSON object.
        """
        if not payload:
            return None
        if not isinstance(payload, dict):
            raise ValueError("MCP auth 必须是 JSON object")

        def text(key: str, default: str = "") -> str:
            # str(None) would produce "None", which is truthy and would both
            # block the base_url fallback and pass as a credential.
            value = payload.get(key)
            return default if value is None else str(value)

        token_url = text("token_url")
        base_url = text("base_url")
        if not token_url and base_url:
            # Derive the token endpoint from the base URL when none is given.
            token_url = base_url.rstrip("/") + "/oauth/token"

        extra_form = payload.get("extra_form") or {}
        if not isinstance(extra_form, dict):
            raise ValueError("MCP auth.extra_form 必须是 JSON object")

        return cls(
            token_url=token_url,
            client_id=text("client_id"),
            client_secret=text("client_secret"),
            grant_type=text("grant_type", "client_credentials"),
            header_name=text("header_name", "Authorization"),
            header_prefix=text("header_prefix", "Bearer"),
            token_field=text("token_field", "access_token"),
            expires_in_field=text("expires_in_field", "expires_in"),
            extra_form={str(key): str(value) for key, value in extra_form.items()},
        )
@dataclass(frozen=True)
class McpClientConfig:
    """Connection settings for an MCP client, parsed from a JSON document.

    Holds both stdio fields (``command``, ``args``, ``env``, ``cwd``) and
    HTTP/SSE fields (``server_url``, ``headers``, ``auth``, timeouts), plus
    the ``tool_names`` mapping.
    """

    server_name: str = "pam-node"
    transport: str = "streamable_http"
    server_url: str = ""
    command: str = ""
    args: list[str] = field(default_factory=list)
    env: dict[str, str] | None = None
    cwd: str = ""
    headers: dict[str, str] = field(default_factory=dict)
    auth: McpAuthConfig | None = None
    timeout_seconds: float = 60
    sse_read_timeout_seconds: float = 300
    tool_names: dict[str, str] = field(default_factory=dict)

    @classmethod
    def from_mapping(cls, payload: dict[str, Any]) -> "McpClientConfig":
        """Build the client config from a parsed JSON object.

        Keys whose value is JSON ``null`` are treated like missing keys.
        ``tool_names`` may also be supplied as ``action_tools`` or ``tools``,
        and ``server_url`` as ``url``. When ``transport`` is not given it
        defaults to ``"stdio"`` if a command is configured, otherwise to
        ``"streamable_http"``.

        Args:
            payload: Parsed JSON object holding the client settings.

        Returns:
            The populated configuration.

        Raises:
            ValueError: If a structured entry has the wrong JSON type or a
                timeout is not numeric.
        """
        tool_names = payload.get("tool_names") or payload.get("action_tools") or payload.get("tools") or {}
        if not isinstance(tool_names, dict):
            raise ValueError("MCP tool_names 必须是 JSON object")
        args = payload.get("args") or []
        if not isinstance(args, list):
            raise ValueError("MCP args 必须是数组")
        env = payload.get("env")
        if env is not None and not isinstance(env, dict):
            raise ValueError("MCP env 必须是 JSON object")
        headers = payload.get("headers") or {}
        if not isinstance(headers, dict):
            raise ValueError("MCP headers 必须是 JSON object")

        def text(key: str, default: str = "") -> str:
            # str(None) would yield "None"; for ``command`` that string is
            # truthy and would wrongly select the stdio transport.
            value = payload.get(key)
            return default if value is None else str(value)

        def seconds(key: str, default: float) -> float:
            value = payload.get(key)
            if value is None:
                return float(default)
            try:
                return float(value)
            except (TypeError, ValueError) as exc:
                raise ValueError(f"MCP {key} 必须是数字") from exc

        server_url = str(payload.get("server_url") or payload.get("url") or "")
        command = text("command")
        transport = str(payload.get("transport") or ("stdio" if command else "streamable_http"))

        # No auth entry means no auth config; skip the parser call entirely.
        auth_payload = payload.get("auth")
        auth = McpAuthConfig.from_mapping(auth_payload) if auth_payload is not None else None

        return cls(
            server_name=text("server_name", "pam-node"),
            transport=transport,
            server_url=server_url,
            command=command,
            args=[str(item) for item in args],
            # An empty env object is stored as None, same as an absent one.
            env={str(key): str(value) for key, value in env.items()} if env else None,
            cwd=text("cwd"),
            headers={str(key): str(value) for key, value in headers.items()},
            auth=auth,
            timeout_seconds=seconds("timeout_seconds", 60),
            sse_read_timeout_seconds=seconds("sse_read_timeout_seconds", 300),
            tool_names={str(key): str(value) for key, value in tool_names.items()},
        )
def load_mcp_client_config(path: str | Path) -> McpClientConfig:
    """Load an MCP client configuration from a UTF-8 encoded JSON file.

    Args:
        path: Location of the JSON configuration file.

    Returns:
        The configuration built from the file's top-level object.

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    raw_text = Path(path).read_text(encoding="utf-8")
    document = json.loads(raw_text)
    if isinstance(document, dict):
        return McpClientConfig.from_mapping(document)
    raise ValueError("MCP client 配置必须是 JSON object")
class FunctionMcpToolClient:
    """Expose a plain Python callable through the ``call_tool`` interface."""

    def __init__(self, caller: Callable[[str, dict[str, Any]], Any]) -> None:
        """Remember the callable that performs the actual tool invocation.

        Args:
            caller: Function invoked as ``caller(tool_name, arguments)``.
        """
        self.caller = caller

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Forward the call to the wrapped callable and return its result unchanged."""
        delegate = self.caller
        return delegate(tool_name, arguments)
class SessionMcpToolClient:
    """Adapter around an MCP SDK session object that exposes ``call_tool``.

    Results are passed through the module's normalisation helpers, which
    accept these common shapes:
    - a raw dict/list/string
    - an object carrying ``structuredContent``
    - an object carrying ``content`` whose text parts may be JSON
    """

    def __init__(self, session: Any) -> None:
        """Store the session after checking it offers a ``call_tool`` attribute.

        Raises:
            TypeError: If ``session`` has no ``call_tool`` attribute.
        """
        if hasattr(session, "call_tool"):
            self.session = session
            return
        raise TypeError("MCP session 必须暴露 call_tool 方法")

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Invoke the tool on the session and return the normalised result."""
        return normalize_mcp_sdk_result(self.session.call_tool(tool_name, arguments))

    def list_tools(self) -> list[str]:
        """Return the tool names reported by the session's ``list_tools``."""
        return normalize_mcp_tool_list(self.session.list_tools())
class StdioMcpToolClient:
    """Tool client that reaches an MCP server over stdio via the MCP Python SDK.

    Each public call opens a fresh stdio session, runs one operation and
    closes the session again.
    """

    def __init__(
        self,
        *,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        cwd: str | None = None,
        timeout_seconds: float = 60,
    ) -> None:
        """Record the parameters used to start the stdio server.

        Raises:
            ValueError: If ``command`` is empty.
        """
        if not command:
            raise ValueError("stdio MCP 配置必须提供 command")
        self.command = command
        # Copy so later changes to the caller's list do not leak in.
        self.args = list(args) if args else []
        self.env = env
        # An empty cwd string is stored as None.
        self.cwd = cwd if cwd else None
        self.timeout_seconds = timeout_seconds

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Open a one-off stdio session, call the tool and return the normalised result."""

        def invoke(session: Any) -> Any:
            return session.call_tool(
                tool_name,
                arguments,
                read_timeout_seconds=timedelta(seconds=self.timeout_seconds),
            )

        return self._with_session(invoke, normalize_mcp_sdk_result)

    def list_tools(self) -> list[str]:
        """Open a one-off stdio session and return the tool names the server exposes."""
        return self._with_session(lambda session: session.list_tools(), normalize_mcp_tool_list)

    def _with_session(self, operation: Callable[[Any], Any], normalize: Callable[[Any], Any]) -> Any:
        """Run ``operation`` against a new initialised session and normalise its result.

        Raises:
            RuntimeError: If the MCP Python SDK (or anyio) cannot be imported.
        """
        try:
            import anyio
            from mcp import ClientSession
            from mcp.client.stdio import StdioServerParameters, stdio_client
        except ImportError as exc:  # pragma: no cover - depends on installed extras
            raise RuntimeError("未安装 MCP Python SDK请安装项目的 mcp 可选依赖") from exc

        async def session_once() -> Any:
            params = StdioServerParameters(
                command=self.command,
                args=self.args,
                env=self.env,
                cwd=self.cwd,
            )
            async with stdio_client(params) as (reader, writer):
                async with ClientSession(reader, writer) as session:
                    await session.initialize()
                    outcome = await operation(session)
                    return normalize(outcome)

        return anyio.run(session_once)
class OAuthTokenProvider:
    """Fetch and cache an MCP auth token using a form-encoded token request.

    The request carries ``grant_type``, ``client_id``, ``client_secret`` and
    any ``extra_form`` entries from the config. A fetched token is reused
    until shortly before its reported expiry.
    """

    def __init__(self, config: McpAuthConfig, *, timeout_seconds: float = 30) -> None:
        """Validate the auth config and initialise an empty token cache.

        Args:
            config: Auth settings; must provide a token URL and credentials.
            timeout_seconds: Timeout passed to the token HTTP request.

        Raises:
            ValueError: If the token URL, client id or client secret is empty.
        """
        if not config.token_url:
            raise ValueError("MCP auth 必须提供 token_url 或 auth.base_url")
        if not config.client_id or not config.client_secret:
            raise ValueError("MCP auth 必须提供独立的 client_id 和 client_secret")
        self.config = config
        self.timeout_seconds = timeout_seconds
        self._token = ""
        self._expires_at = 0.0

    def authorization_headers(self) -> dict[str, str]:
        """Return a single-entry header dict carrying the current token.

        The value is ``"<prefix> <token>"``, or just the token when the
        configured prefix is blank.
        """
        token = self.get_token()
        prefix = self.config.header_prefix.strip()
        value = f"{prefix} {token}" if prefix else token
        return {self.config.header_name: value}

    def get_token(self) -> str:
        """Return a usable token, requesting a new one when the cache is stale.

        Returns:
            The cached or freshly fetched token string.

        Raises:
            RuntimeError: If the token response is not a JSON object or does
                not contain a non-empty token under the configured field.
        """
        now = time.time()
        if self._token and now < self._expires_at:
            return self._token

        payload = {
            "grant_type": self.config.grant_type,
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            **self.config.extra_form,
        }
        data = urllib.parse.urlencode(payload).encode("utf-8")
        request = urllib.request.Request(
            self.config.token_url,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            method="POST",
        )
        with urllib.request.urlopen(request, timeout=self.timeout_seconds) as response:
            raw = response.read().decode("utf-8")

        result = json.loads(raw)
        if not isinstance(result, dict):
            raise RuntimeError("MCP auth token 响应必须是 JSON object")

        # A JSON null must count as "missing"; str(None) would otherwise be
        # cached and sent as the literal token "None".
        raw_token = result.get(self.config.token_field)
        token = "" if raw_token is None else str(raw_token)
        if not token:
            raise RuntimeError(f"MCP auth token 响应缺少 {self.config.token_field}")

        expires_in = _safe_float(result.get(self.config.expires_in_field), 3600)
        self._token = token
        # Refresh 60 seconds early, but keep the token for at least 1 second.
        self._expires_at = now + max(expires_in - 60, 1)
        return token
class HttpMcpToolClient:
    """Tool client that reaches an MCP server through an HTTP or SSE URL.

    Each public call opens a fresh session over the configured transport,
    runs one operation and closes the session again.
    """

    def __init__(
        self,
        *,
        url: str,
        transport: str = "streamable_http",
        headers: dict[str, str] | None = None,
        auth_provider: OAuthTokenProvider | None = None,
        timeout_seconds: float = 60,
        sse_read_timeout_seconds: float = 300,
    ) -> None:
        """Record the connection parameters for the HTTP/SSE server.

        Raises:
            ValueError: If ``url`` is empty or ``transport`` is neither
                ``"streamable_http"`` nor ``"sse"``.
        """
        if not url:
            raise ValueError("HTTP/SSE MCP 配置必须提供 server_url")
        supported = ("streamable_http", "sse")
        if transport not in supported:
            raise ValueError(f"不支持的 HTTP MCP transport: {transport}")
        self.url = url
        self.transport = transport
        # Copy so the caller's mapping is never shared with this client.
        self.headers = dict(headers) if headers else {}
        self.auth_provider = auth_provider
        self.timeout_seconds = timeout_seconds
        self.sse_read_timeout_seconds = sse_read_timeout_seconds

    def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Connect to the server, call the tool and return the normalised result."""

        def invoke(session: Any) -> Any:
            return session.call_tool(tool_name, arguments)

        return self._run_session(invoke)

    def list_tools(self) -> list[str]:
        """Connect to the server and return the tool names it exposes."""

        def invoke(session: Any) -> Any:
            return session.list_tools()

        return self._run_session(invoke, normalize_tools=True)

    def _build_headers(self) -> dict[str, str]:
        """Return the static headers merged with the auth provider's headers, if any."""
        merged = dict(self.headers)
        if self.auth_provider is None:
            return merged
        # Auth headers win over static headers of the same name.
        merged.update(self.auth_provider.authorization_headers())
        return merged

    def _run_session(self, operation: Callable[[Any], Any], *, normalize_tools: bool = False) -> Any:
        """Open one session over the configured transport and run ``operation``.

        Args:
            operation: Callable that receives the initialised session and
                returns an awaitable.
            normalize_tools: When true the result is normalised as a tool
                list, otherwise as a tool call result.

        Raises:
            RuntimeError: If the MCP Python SDK (or anyio) cannot be imported.
        """
        try:
            import anyio
            from mcp import ClientSession
            from mcp.client.sse import sse_client
            from mcp.client.streamable_http import streamablehttp_client
        except ImportError as exc:  # pragma: no cover - depends on installed extras
            raise RuntimeError("未安装 MCP Python SDK请安装项目的 mcp 可选依赖") from exc

        async def drive(reader: Any, writer: Any) -> Any:
            # Shared by both transports: initialise, run, normalise.
            async with ClientSession(reader, writer) as session:
                await session.initialize()
                outcome = await operation(session)
                if normalize_tools:
                    return normalize_mcp_tool_list(outcome)
                return normalize_mcp_sdk_result(outcome)

        async def session_once() -> Any:
            request_headers = self._build_headers()
            if self.transport == "sse":
                async with sse_client(
                    self.url,
                    headers=request_headers,
                    timeout=self.timeout_seconds,
                    sse_read_timeout=self.sse_read_timeout_seconds,
                ) as (reader, writer):
                    return await drive(reader, writer)
            async with streamablehttp_client(
                self.url,
                headers=request_headers,
                timeout=timedelta(seconds=self.timeout_seconds),
                sse_read_timeout=timedelta(seconds=self.sse_read_timeout_seconds),
            ) as streams:
                # Only the first two yielded items (read/write streams) are used.
                return await drive(streams[0], streams[1])

        return anyio.run(session_once)
def normalize_mcp_sdk_result(result: Any) -> Any:
    """Reduce a tool call result to its payload.

    Resolution order:
    1. a non-``None`` ``structuredContent`` attribute is returned as is;
    2. otherwise the ``text`` attributes of the ``content`` items are joined
       with newlines and parsed as JSON, falling back to the joined string
       when parsing fails;
    3. otherwise ``result`` itself is returned unchanged.
    """
    structured = getattr(result, "structuredContent", None)
    if structured is not None:
        return structured

    items = getattr(result, "content", None) or []
    candidates = (getattr(item, "text", None) for item in items)
    fragments = [fragment for fragment in candidates if fragment is not None]
    if not fragments:
        return result

    combined = "\n".join(fragments)
    try:
        return json.loads(combined)
    except json.JSONDecodeError:
        # Not JSON: hand back the plain text.
        return combined
def normalize_mcp_tool_list(result: Any) -> list[str]:
    """Extract tool names from a ``list_tools`` result.

    The tool collection is taken from, in order: a ``tools`` attribute, a
    ``"tools"`` key when ``result`` is a dict, or ``result`` itself when it
    is already a list or tuple. Each entry may be a plain string, a dict
    with a ``"name"`` key, or an object with a ``name`` attribute; entries
    without a usable name are skipped.

    Args:
        result: Value returned by a ``list_tools`` call.

    Returns:
        Tool names in their original order; empty when none are found.
    """
    tools = getattr(result, "tools", None)
    if tools is None:
        if isinstance(result, dict):
            tools = result.get("tools")
        elif isinstance(result, (list, tuple)):
            # A bare sequence is the tool list itself; dropping it would
            # silently report that the server exposes no tools.
            tools = result

    names: list[str] = []
    for item in tools or []:
        if isinstance(item, str):
            names.append(item)
            continue
        if isinstance(item, dict) and item.get("name"):
            names.append(str(item["name"]))
            continue
        name = getattr(item, "name", None)
        if name:
            names.append(str(name))
    return names
def _safe_float(value: Any, default: float) -> float:
    """Convert ``value`` to a finite float, returning ``default`` on failure.

    Args:
        value: Anything accepted by ``float()``.
        default: Returned when conversion fails or the result is not finite.

    Returns:
        The converted finite float, or ``default``.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: integers too large for a float, e.g. 10 ** 400.
        return default
    # nan/inf are not usable numbers for arithmetic such as expiry times.
    return number if math.isfinite(number) else default