agent_deply/pam_deploy_graph/llm/openai_compatible.py
dark 1e74ae3cd6 feat: 增加 PAM 部署 Agent 交互式 CLI 与真实 LLM 配置
- 新增 OpenAI-compatible LLM client,支持 base_url、api_key、model 配置
- 固化意图识别、参数抽取、部署计划生成的结构化 JSON 提示词
- 增加 MCP client 配置读取和真实 session 接入说明
- 实现 checkpoint 自动保存、resume 断点续跑和已完成步骤跳过
- 实现人工确认流程,支持失败 IP 回滚 approve/reject
- 新增 chat 常驻式 CLI 对话框,支持自然语言分析、参数设置、执行确认、状态查看、回滚确认和续跑
- 同步 README,补充 LLM、MCP、checkpoint、confirm/resume、chat 使用方式
- 增加相关单元测试,覆盖 LLM client、MCP 配置、确认/续跑和交互式 CLI
2026-06-01 10:26:40 +08:00

243 lines
8.4 KiB
Python

"""OpenAI-compatible HTTP LLM client.
The client targets providers exposing a `/chat/completions` endpoint with
OpenAI-style request and response shapes. It intentionally uses only the Python
standard library so the runtime can stay dependency-light.
"""
from __future__ import annotations
import json
import urllib.request
from collections.abc import Callable
from typing import Any
from pam_deploy_graph.constants import (
ALLOWED_ACTIONS,
DEFAULT_PARAMS,
GLOBAL_ACTION_SEQUENCE,
IP_ACTION_SEQUENCE,
REQUIRED_PARAMS,
SENSITIVE_KEYS,
)
from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult
from .prompts import INTENT_PROMPT, PARAM_PROMPT, PLAN_PROMPT, SYSTEM_PROMPT
# Signature of the pluggable HTTP transport. The client calls it as
# (url, headers, json_payload, timeout_sec) and expects the decoded JSON
# response object (a dict) back.
JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]]
class OpenAICompatibleLlmClient:
    """LLM client for providers exposing an OpenAI-style chat completions API.

    Every public method sends one instruction prompt together with a JSON
    input document and converts the JSON object returned by the model into
    one of the result models imported from ``pam_deploy_graph.models``.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
    ) -> None:
        """Validate and store the connection settings.

        Args:
            base_url: Provider base URL; trailing slashes are removed.
            api_key: Token sent as a ``Bearer`` authorization header.
            model: Model name placed in every request payload.
            timeout_sec: Timeout handed to the transport.
            temperature: Sampling temperature placed in every request payload.
            transport: Callable performing the HTTP exchange; the module's
                default transport is used when omitted.

        Raises:
            ValueError: If ``base_url``, ``api_key`` or ``model`` is empty.
        """
        required_settings = (
            ("base_url", base_url),
            ("api_key", api_key),
            ("model", model),
        )
        for label, supplied in required_settings:
            if not supplied:
                raise ValueError(f"LLM {label} is required")

        self.transport = transport if transport else _default_transport
        self.temperature = temperature
        self.timeout_sec = timeout_sec
        self.model = model
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the model to classify the user's request.

        Missing fields in the model reply fall back to the defaults given
        below (for example intent ``"deploy"`` and confidence ``0.0``).
        """
        reply = self._complete_json(INTENT_PROMPT, {"user_text": text})
        intent = _string(reply, "intent", "deploy")
        mode = _string(reply, "mode_preference", "未指定")
        strategy = _string(reply, "strategy_preference", "未指定")
        return LlmIntentResult(
            intent=intent,  # type: ignore[arg-type]
            mode_preference=mode,  # type: ignore[arg-type]
            strategy_preference=strategy,  # type: ignore[arg-type]
            confidence=_float(reply, "confidence", 0.0),
            reasons=_string_list(reply.get("reasons")),
            needs_clarification=bool(reply.get("needs_clarification", False)),
            clarification_questions=_string_list(reply.get("clarification_questions")),
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Extract deployment parameters from ``text`` and merge them over ``base_params``.

        Sensitive values in ``base_params`` are redacted before being sent to
        the model; the unredacted originals are kept in the merged result.
        """
        known = dict(base_params or {})
        request_doc = {
            "user_text": text,
            "base_params": _redact_sensitive(known),
            "required_params": list(REQUIRED_PARAMS),
            "default_params": DEFAULT_PARAMS,
        }
        reply = self._complete_json(PARAM_PROMPT, request_doc)

        # A sensitive key echoed back as the "***" placeholder must not
        # overwrite the real value held in ``known``.
        accepted = {
            name: value
            for name, value in _dict(reply.get("extracted_params")).items()
            if not (name in SENSITIVE_KEYS and value == "***")
        }
        merged = {**known, **accepted}

        still_missing = [name for name in REQUIRED_PARAMS if not merged.get(name)]
        secrets_present = [name for name in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(name)]
        return LlmParamResult(
            extracted_params=merged,
            extracted_control=_dict(reply.get("extracted_control")),
            missing_required_params=still_missing,
            ambiguous_fields=_string_list(reply.get("ambiguous_fields")),
            sensitive_fields_present=secrets_present,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
    ) -> LlmDeployPlan:
        """Ask the model for a deployment plan.

        ``params`` are redacted before being sent. When the reply contains no
        planned actions, ``GLOBAL_ACTION_SEQUENCE`` is used instead.
        """
        request_doc = {
            "params": _redact_sensitive(params),
            "intent": intent,
            "execution_strategy": strategy,
            "allowed_actions": list(ALLOWED_ACTIONS),
            "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
            "ip_action_sequence": list(IP_ACTION_SEQUENCE),
        }
        reply = self._complete_json(PLAN_PROMPT, request_doc)

        actions = _string_list(reply.get("planned_actions"))
        if not actions:
            actions = list(GLOBAL_ACTION_SEQUENCE)

        return LlmDeployPlan(
            summary=_string(reply, "summary", "PAM deployment plan"),
            risk_notes=_string_list(reply.get("risk_notes")),
            planned_actions=actions,
            requires_confirmation=bool(reply.get("requires_confirmation", True)),
            execution_strategy=_string(reply, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def _complete_json(self, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send one chat completion request and return the parsed JSON object.

        Raises:
            ValueError: If the message content does not decode to a JSON object.
        """
        serialized_input = json.dumps(input_payload, ensure_ascii=False, sort_keys=True)
        user_message = {
            "role": "user",
            "content": "".join((instruction, "\n\n输入 JSON:\n", serialized_input)),
        }
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                user_message,
            ],
        }
        auth_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        response = self.transport(
            _chat_completions_url(self.base_url),
            auth_headers,
            request_payload,
            self.timeout_sec,
        )
        parsed = _loads_json_object(_message_content(response))
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("LLM response must be a JSON object")
def _default_transport(
url: str,
headers: dict[str, str],
payload: dict[str, Any],
timeout_sec: float,
) -> dict[str, Any]:
request = urllib.request.Request(
url,
data=json.dumps(payload).encode("utf-8"),
headers=headers,
method="POST",
)
with urllib.request.urlopen(request, timeout=timeout_sec) as response:
raw = response.read().decode("utf-8")
decoded = json.loads(raw)
if not isinstance(decoded, dict):
raise ValueError("LLM HTTP response must be a JSON object")
return decoded
def _chat_completions_url(base_url: str) -> str:
clean = base_url.rstrip("/")
if clean.endswith("/chat/completions"):
return clean
return f"{clean}/chat/completions"
def _message_content(response: dict[str, Any]) -> Any:
try:
content = response["choices"][0]["message"]["content"]
except (KeyError, IndexError, TypeError) as exc:
raise ValueError("LLM response does not contain choices[0].message.content") from exc
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
elif isinstance(item, str):
parts.append(item)
return "".join(parts)
return content
def _loads_json_object(content: Any) -> Any:
if isinstance(content, dict):
return content
if not isinstance(content, str):
raise ValueError("LLM message content must be JSON text")
return json.loads(content)
def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with sensitive dict entries masked.

    Dicts and lists are walked recursively. Dict keys are converted to
    strings, and any entry whose key is in ``SENSITIVE_KEYS`` has its value
    replaced by ``"***"``. Any other value is returned unchanged.
    """
    if isinstance(value, list):
        return [_redact_sensitive(element) for element in value]
    if isinstance(value, dict):
        return {
            str(name): "***" if str(name) in SENSITIVE_KEYS else _redact_sensitive(entry)
            for name, entry in value.items()
        }
    return value
def _string(payload: dict[str, Any], key: str, default: str) -> str:
value = payload.get(key, default)
return str(value) if value is not None else default
def _float(payload: dict[str, Any], key: str, default: float) -> float:
try:
return float(payload.get(key, default))
except (TypeError, ValueError):
return default
def _dict(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
def _string_list(value: Any) -> list[str]:
if isinstance(value, list):
return [str(item) for item in value]
if value in (None, ""):
return []
return [str(value)]