agent_deply/pam_deploy_graph/llm/openai_compatible.py
dark a11904b7c5 docs/build: 补齐中文注释、流程图和 Linux 解压即用打包脚本
- 为 pam_deploy_graph 生产代码补充中文模块、类、函数/方法文档字符串
- 将原有英文说明和主要英文异常提示改为中文
- 新增当前整体逻辑结构流程图文档,覆盖模块结构、执行链路、action 路由、人工确认和 checkpoint 续跑
- 新增 Linux 自带运行环境打包脚本,使用 PyInstaller 生成解压即用目录和 tar.gz
- 新增 Linux 打包说明,包含构建命令、运行方式、依赖说明和包大小评估
- 同步 README,补充流程图、打包方式、产物路径和大小预估
- 更新相关测试断言以匹配中文错误提示
2026-06-01 11:21:42 +08:00

258 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OpenAI-compatible HTTP LLM client。
该 client 面向暴露 `/chat/completions` 的模型服务,并使用 OpenAI 风格的
请求/响应结构。实现只依赖 Python 标准库,便于控制运行时依赖体积。
"""
from __future__ import annotations

import json
import math
import urllib.request
from collections.abc import Callable
from typing import Any

from pam_deploy_graph.constants import (
    ALLOWED_ACTIONS,
    DEFAULT_PARAMS,
    GLOBAL_ACTION_SEQUENCE,
    IP_ACTION_SEQUENCE,
    REQUIRED_PARAMS,
    SENSITIVE_KEYS,
)
from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult

from .prompts import INTENT_PROMPT, PARAM_PROMPT, PLAN_PROMPT, SYSTEM_PROMPT
# Pluggable HTTP transport signature: called as
# transport(url, headers, json_payload, timeout_sec) and expected to return the
# decoded JSON response body as a dict. Injected into the client so tests can
# replace the real network call.
JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]]
class OpenAICompatibleLlmClient:
    """Obtain structured LLM output from an OpenAI-compatible HTTP endpoint.

    Each public method sends one ``chat/completions`` request asking for a
    JSON object (``response_format={"type": "json_object"}``) and maps the
    parsed object onto one of the project's result models. Individual fields
    that are missing or malformed fall back to defaults; a response that is
    not a JSON object raises ``ValueError``.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
    ) -> None:
        """Store connection settings, model settings and the HTTP transport.

        Args:
            base_url: Service root (or full ``/chat/completions`` endpoint);
                trailing slashes are removed.
            api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
            model: Model name placed in every request body.
            timeout_sec: Per-request timeout handed to the transport.
            temperature: Sampling temperature placed in every request body.
            transport: Optional replacement for the urllib-based default
                transport (e.g. a fake in tests).

        Raises:
            ValueError: If ``base_url``, ``api_key`` or ``model`` is empty.
        """
        if not base_url:
            raise ValueError("必须配置 LLM base_url")
        if not api_key:
            raise ValueError("必须配置 LLM api_key")
        if not model:
            raise ValueError("必须配置 LLM model")
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.timeout_sec = timeout_sec
        self.temperature = temperature
        self.transport = transport or _default_transport

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the LLM to classify the intent behind ``text``.

        Args:
            text: The user's raw request; it is sent to the LLM unmodified.

        Returns:
            An ``LlmIntentResult``. Missing fields default to intent
            ``"deploy"``, ``"未指定"`` preferences, confidence ``0.0``, empty
            lists and ``needs_clarification=False``.
        """
        payload = self._complete_json(INTENT_PROMPT, {"user_text": text})
        return LlmIntentResult(
            intent=_string(payload, "intent", "deploy"),  # type: ignore[arg-type]
            mode_preference=_string(payload, "mode_preference", "未指定"),  # type: ignore[arg-type]
            strategy_preference=_string(payload, "strategy_preference", "未指定"),  # type: ignore[arg-type]
            confidence=_float(payload, "confidence", 0.0),
            reasons=_string_list(payload.get("reasons")),
            # Parsed explicitly: a string "false" from the model must not become True.
            needs_clarification=_coerce_bool(payload.get("needs_clarification"), False),
            clarification_questions=_string_list(payload.get("clarification_questions")),
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Ask the LLM to extract deployment parameters from ``text``.

        ``base_params`` is shown to the LLM only after its sensitive keys are
        masked by ``_redact_sensitive``; the real values are kept locally and
        merged back into the result.

        NOTE(review): ``text`` itself is sent verbatim, so any secret the user
        typed into the request still reaches the LLM.

        Args:
            text: The user's raw request.
            base_params: Already-known parameters; not mutated.

        Returns:
            An ``LlmParamResult`` whose ``extracted_params`` is ``base_params``
            overlaid with the LLM's extracted values, together with the
            required parameters that are still empty and which of
            ``CLIENT_SECRET`` / ``CLIENT_ID`` are present.
        """
        original_base = dict(base_params or {})
        safe_base = _redact_sensitive(original_base)
        payload = self._complete_json(
            PARAM_PROMPT,
            {
                "user_text": text,
                "base_params": safe_base,
                "required_params": list(REQUIRED_PARAMS),
                "default_params": DEFAULT_PARAMS,
            },
        )
        extracted = _dict(payload.get("extracted_params"))
        merged = original_base.copy()
        for key, value in extracted.items():
            # The LLM may echo back the "***" placeholder it was shown; never let
            # it overwrite the real secret held in original_base.
            if key in SENSITIVE_KEYS and value == "***":
                continue
            merged[key] = value
        control = _dict(payload.get("extracted_control"))
        missing = [key for key in REQUIRED_PARAMS if not merged.get(key)]
        sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)]
        llm_ambiguous = _string_list(payload.get("ambiguous_fields"))
        return LlmParamResult(
            extracted_params=merged,
            extracted_control=control,
            missing_required_params=missing,
            ambiguous_fields=llm_ambiguous,
            sensitive_fields_present=sensitive,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
    ) -> LlmDeployPlan:
        """Ask the LLM to draft a deployment plan.

        Args:
            params: Deployment parameters; sensitive keys are masked before
                being sent.
            intent: Intent label forwarded to the prompt.
            strategy: Caller's execution strategy; used as the fallback when
                the LLM does not return one.

        Returns:
            An ``LlmDeployPlan``. An empty ``planned_actions`` falls back to
            ``GLOBAL_ACTION_SEQUENCE``; ``requires_confirmation`` stays
            ``True`` unless the LLM explicitly answers false.
        """
        payload = self._complete_json(
            PLAN_PROMPT,
            {
                "params": _redact_sensitive(params),
                "intent": intent,
                "execution_strategy": strategy,
                "allowed_actions": list(ALLOWED_ACTIONS),
                "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
                "ip_action_sequence": list(IP_ACTION_SEQUENCE),
            },
        )
        # NOTE(review): planned_actions is not checked against ALLOWED_ACTIONS
        # here; confirm a downstream step rejects unknown actions.
        planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE)
        return LlmDeployPlan(
            summary=_string(payload, "summary", "PAM 部署计划"),
            risk_notes=_string_list(payload.get("risk_notes")),
            planned_actions=planned_actions,
            # Parsed explicitly so a string "false" is honoured, while null or
            # unrecognised values keep the safe default of asking for confirmation.
            requires_confirmation=_coerce_bool(payload.get("requires_confirmation"), True),
            execution_strategy=_string(payload, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def _complete_json(self, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send one chat/completions request and return the JSON object it produced.

        The prompt is ``instruction`` followed by ``input_payload`` serialized as
        sorted-key JSON.

        Raises:
            ValueError: If the response lacks message content, or the content is
                not a JSON object.
        """
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": instruction
                    + "\n\n输入 JSON:\n"
                    + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        response = self.transport(
            _chat_completions_url(self.base_url),
            {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            request_payload,
            self.timeout_sec,
        )
        content = _message_content(response)
        parsed = _loads_json_object(content)
        if not isinstance(parsed, dict):
            raise ValueError("LLM 响应必须是 JSON object")
        return parsed


def _coerce_bool(value: Any, default: bool) -> bool:
    """Interpret an LLM-supplied flag as a boolean.

    ``bool("false")`` is ``True``, so string flags are parsed by token.

    Args:
        value: Raw value from the LLM payload.
        default: Returned for ``None`` and for unrecognised strings.

    Returns:
        The value itself for real booleans. For strings, ``True`` for the tokens
        true/yes/1, ``False`` for false/no/0 (case- and whitespace-insensitive),
        otherwise ``default``. Any other type uses Python truthiness.
    """
    if isinstance(value, bool):
        return value
    if value is None:
        return default
    if isinstance(value, str):
        token = value.strip().lower()
        if token in {"true", "yes", "1"}:
            return True
        if token in {"false", "no", "0"}:
            return False
        return default
    return bool(value)
def _default_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` with stdlib ``urllib`` and decode the reply.

    Args:
        url: Full endpoint URL.
        headers: HTTP headers to send unchanged.
        payload: Request body; serialized with ``json.dumps`` and UTF-8 encoded.
        timeout_sec: Timeout handed to ``urllib.request.urlopen``.

    Returns:
        The response body, UTF-8 decoded and parsed as a JSON object.

    Raises:
        ValueError: If the body is JSON but not an object (a non-JSON body
            raises ``json.JSONDecodeError``, itself a ``ValueError``).
        urllib errors such as ``HTTPError``/``URLError`` are not caught here.
    """
    body = json.dumps(payload).encode("utf-8")
    http_request = urllib.request.Request(url, data=body, headers=headers, method="POST")
    with urllib.request.urlopen(http_request, timeout=timeout_sec) as http_response:
        text = http_response.read().decode("utf-8")
    result = json.loads(text)
    if isinstance(result, dict):
        return result
    raise ValueError("LLM HTTP 响应必须是 JSON object")
def _chat_completions_url(base_url: str) -> str:
    """Normalize ``base_url`` into the ``/chat/completions`` endpoint URL.

    Trailing slashes are dropped; the suffix is appended only when it is not
    already present, so both a service root and a full endpoint are accepted.
    """
    endpoint_suffix = "/chat/completions"
    trimmed = base_url.rstrip("/")
    return trimmed if trimmed.endswith(endpoint_suffix) else trimmed + endpoint_suffix
def _message_content(response: dict[str, Any]) -> Any:
    """Extract ``choices[0].message.content`` from a chat/completions response.

    When the content is a list of parts, plain strings and ``{"type": "text"}``
    parts are concatenated in order and every other part is skipped. Any
    other content (string, ``None``, dict, ...) is returned as-is.

    Raises:
        ValueError: If the ``choices[0].message.content`` path is missing.
    """
    try:
        first_message = response["choices"][0]["message"]
        content = first_message["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError("LLM 响应缺少 choices[0].message.content") from exc
    if not isinstance(content, list):
        return content
    pieces: list[str] = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            pieces.append(str(part.get("text", "")))
    return "".join(pieces)
def _loads_json_object(content: Any) -> Any:
    """Parse ``message.content`` into a JSON value.

    Dicts are returned unchanged. Strings are parsed with ``json.loads`` after
    removing one surrounding markdown code fence (```json ... ```): some
    OpenAI-compatible servers ignore ``response_format`` and wrap the JSON
    that way. Unfenced strings are parsed exactly as given.

    Raises:
        ValueError: If ``content`` is neither a dict nor a string, or
            (as ``json.JSONDecodeError``) if the text is not valid JSON.
    """
    if isinstance(content, dict):
        return content
    if not isinstance(content, str):
        raise ValueError("LLM message content 必须是 JSON 文本")
    return json.loads(_strip_code_fence(content))


def _strip_code_fence(text: str) -> str:
    """Remove one surrounding markdown code fence; return ``text`` unchanged if unfenced."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return text
    first_newline = stripped.find("\n")
    if first_newline == -1:
        # Single-line "```...```" is not a recognizable fence; let json.loads report it.
        return text
    # Drop the opening fence line, which may carry a language tag such as "json".
    body = stripped[first_newline + 1 :].rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body
def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with sensitive fields masked before prompting.

    Dict keys are converted to ``str``; any key found in ``SENSITIVE_KEYS``
    has its value, whatever its type, replaced by ``"***"``. Other dict values
    and list elements are processed recursively. Everything else, including
    tuples, is returned unchanged.
    """
    if isinstance(value, list):
        return [_redact_sensitive(element) for element in value]
    if not isinstance(value, dict):
        return value
    masked: dict[str, Any] = {}
    for raw_key, child in value.items():
        name = str(raw_key)
        masked[name] = "***" if name in SENSITIVE_KEYS else _redact_sensitive(child)
    return masked
def _string(payload: dict[str, Any], key: str, default: str) -> str:
    """Read ``payload[key]`` as a string.

    A missing key or a ``None`` value yields ``default``; any other value is
    converted with ``str()``.
    """
    raw = payload.get(key, default)
    if raw is None:
        return default
    return str(raw)
def _float(payload: dict[str, Any], key: str, default: float) -> float:
    """Read ``payload[key]`` as a finite float.

    Numbers and numeric strings are converted with ``float()``. Missing,
    non-numeric, too-large and non-finite values all return ``default``.
    ``json.loads`` accepts the ``NaN``/``Infinity`` literals and arbitrarily
    large integers, so an LLM can produce any of them.

    Returns:
        The parsed finite float, or ``default``.
    """
    try:
        number = float(payload.get(key, default))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int beyond the double range.
        return default
    # NaN compares false against every threshold and is not valid JSON, so treat
    # non-finite values as absent.
    return number if math.isfinite(number) else default
def _dict(value: Any) -> dict[str, Any]:
    """Return ``value`` itself if it is a dict; otherwise return a new empty dict."""
    if isinstance(value, dict):
        return value
    return {}
def _string_list(value: Any) -> list[str]:
    """Coerce ``value`` into a list of strings.

    A list has each element converted with ``str()``; ``None`` or ``""`` gives
    an empty list; any other scalar becomes a one-element list.
    """
    if not isinstance(value, list):
        return [] if value in (None, "") else [str(value)]
    return list(map(str, value))