agent_deply/pam_deploy_graph/llm/openai_compatible.py
2026-06-04 10:04:23 +08:00

343 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OpenAI-compatible HTTP LLM client。
该 client 面向暴露 `/chat/completions` 的模型服务,并使用 OpenAI 风格的
请求/响应结构。实现只依赖 Python 标准库,便于控制运行时依赖体积。
"""
from __future__ import annotations
import json
from pathlib import Path
import urllib.request
from collections.abc import Callable
from typing import Any
from pam_deploy_graph.constants import (
ALLOWED_ACTIONS,
DEFAULT_PARAMS,
GLOBAL_ACTION_SEQUENCE,
IP_ACTION_SEQUENCE,
REQUIRED_PARAMS,
SENSITIVE_KEYS,
)
from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmModeDecision, LlmParamResult
from pam_deploy_graph.models import ActionResult, LlmActionAnalysis
from .prompts import ACTION_ANALYSIS_PROMPT, INTENT_PROMPT, MODE_PROMPT, PARAM_PROMPT, PLAN_PROMPT, SYSTEM_PROMPT
# Signature of the pluggable HTTP transport. The client calls it with
# (endpoint URL, request headers, JSON request payload, timeout in seconds)
# and expects the response body back as a dict.
JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]]
class OpenAICompatibleLlmClient:
    """Fetch structured LLM output through an OpenAI-compatible HTTP API.

    Each public method builds a JSON input payload, sends it together with one
    of the module's prompts to the ``/chat/completions`` endpoint, and maps the
    JSON object returned by the model onto the matching result model.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        action_analysis_prompt: str | None = None,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
    ) -> None:
        """Store connection settings, model settings and the HTTP transport.

        Args:
            base_url: Service base URL; trailing slashes are stripped.
            api_key: Bearer token; when empty no Authorization header is sent.
            model: Model name placed in every request.
            action_analysis_prompt: Optional override for the action-analysis
                prompt; falls back to ``ACTION_ANALYSIS_PROMPT``.
            timeout_sec: Timeout handed to the transport on every call.
            temperature: Sampling temperature placed in every request.
            transport: Replaceable transport callable; defaults to the
                urllib-based ``_default_transport``.

        Raises:
            ValueError: If ``base_url`` or ``model`` is empty.
        """
        if not base_url:
            raise ValueError("必须配置 LLM base_url")
        if not model:
            raise ValueError("必须配置 LLM model")
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.action_analysis_prompt = action_analysis_prompt or ACTION_ANALYSIS_PROMPT
        self.timeout_sec = timeout_sec
        self.temperature = temperature
        self.transport = transport or _default_transport

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the LLM to classify the intent of the user's request text."""
        payload = self._complete_json(INTENT_PROMPT, {"user_text": text})
        return LlmIntentResult(
            intent=_string(payload, "intent", "deploy"),  # type: ignore[arg-type]
            mode_preference=_string(payload, "mode_preference", "未指定"),  # type: ignore[arg-type]
            strategy_preference=_string(payload, "strategy_preference", "未指定"),  # type: ignore[arg-type]
            confidence=_float(payload, "confidence", 0.0),
            reasons=_string_list(payload.get("reasons")),
            needs_clarification=_bool(payload, "needs_clarification", False),
            clarification_questions=_string_list(payload.get("clarification_questions")),
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Ask the LLM to extract parameters without sending sensitive values.

        ``base_params`` is redacted before it is placed in the prompt. The
        values extracted by the model are then merged over the *unredacted*
        base parameters, so callers get real values back.
        """
        original_base = dict(base_params or {})
        safe_base = _redact_sensitive(original_base)
        payload = self._complete_json(
            PARAM_PROMPT,
            {
                "user_text": text,
                "base_params": safe_base,
                "required_params": list(REQUIRED_PARAMS),
                "default_params": DEFAULT_PARAMS,
            },
        )
        extracted = _dict(payload.get("extracted_params"))
        merged = original_base.copy()
        for key, value in extracted.items():
            # The model may echo the redaction placeholder back; never let it
            # overwrite the real sensitive value held in the base parameters.
            if key in SENSITIVE_KEYS and value == "***":
                continue
            merged[key] = value
        control = _dict(payload.get("extracted_control"))
        missing = [key for key in REQUIRED_PARAMS if not merged.get(key)]
        sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)]
        llm_ambiguous = _string_list(payload.get("ambiguous_fields"))
        return LlmParamResult(
            extracted_params=merged,
            extracted_control=control,
            missing_required_params=missing,
            ambiguous_fields=llm_ambiguous,
            sensitive_fields_present=sensitive,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
        skill_policy: dict[str, Any],
        tool_summaries: list[dict[str, Any]],
    ) -> LlmDeployPlan:
        """Ask the LLM to produce a deployment plan.

        Falls back to ``GLOBAL_ACTION_SEQUENCE`` when the model returns no
        planned actions, and to requiring confirmation when the model does not
        give a usable ``requires_confirmation`` value.
        """
        payload = self._complete_json(
            PLAN_PROMPT,
            {
                "params": _redact_sensitive(params),
                "intent": intent,
                "execution_strategy": strategy,
                "skill_policy": skill_policy,
                "tool_summaries": tool_summaries,
                "allowed_actions": list(ALLOWED_ACTIONS),
                "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
                "ip_action_sequence": list(IP_ACTION_SEQUENCE),
            },
        )
        planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE)
        return LlmDeployPlan(
            summary=_string(payload, "summary", "PAM 部署计划"),
            risk_notes=_string_list(payload.get("risk_notes")),
            planned_actions=planned_actions,
            requires_confirmation=_bool(payload, "requires_confirmation", True),
            execution_strategy=_string(payload, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def decide_execution_mode(
        self,
        *,
        text: str,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
        allowed_modes: list[str],
        tool_summaries: list[dict[str, Any]],
    ) -> LlmModeDecision:
        """Ask the LLM whether the task runs as fixed runtime or agentic skill."""
        payload = self._complete_json(
            MODE_PROMPT,
            {
                "user_text": text,
                "params": _redact_sensitive(params),
                "intent": intent,
                "execution_strategy": strategy,
                "allowed_modes": allowed_modes,
                "tool_summaries": tool_summaries,
            },
        )
        return LlmModeDecision(
            mode=_string(payload, "mode", "fixed_runtime"),  # type: ignore[arg-type]
            reason=_string(payload, "reason", ""),
            risk_level=_string(payload, "risk_level", "medium"),  # type: ignore[arg-type]
            requires_confirmation=_bool(payload, "requires_confirmation", True),
        )

    def analyze_action_result(
        self,
        *,
        action: str,
        result: ActionResult,
        state_summary: dict[str, Any],
    ) -> LlmActionAnalysis:
        """Ask the LLM to analyse an action result and return a diagnosis.

        Only a summary of ``result`` is sent: its values and the state summary
        are redacted, and stderr is truncated.
        """
        payload = self._complete_json(
            self.action_analysis_prompt,
            {
                "action": action,
                "result": {
                    "backend": result.backend,
                    "ok": result.ok,
                    "exit_code": result.exit_code,
                    "tool_name": result.tool_name,
                    "values": _redact_sensitive(result.values),
                    "stderr": _truncate_text(result.stderr),
                    "error_summary": result.error_summary,
                },
                "state_summary": _redact_sensitive(state_summary),
            },
        )
        return LlmActionAnalysis(
            action=_string(payload, "action", action),
            has_anomaly=_bool(payload, "has_anomaly", False),
            severity=_string(payload, "severity", "info"),  # type: ignore[arg-type]
            possible_reason=_string(payload, "possible_reason", ""),
            suggested_action=_string(payload, "suggested_action", ""),
            requires_confirmation=_bool(payload, "requires_confirmation", False),
            should_continue=_bool(payload, "should_continue", True),
            notes=_string_list(payload.get("notes")),
        )

    def _complete_json(self, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send one chat/completions request and parse the JSON object reply.

        Args:
            instruction: Prompt text placed at the start of the user message.
            input_payload: Data serialised as JSON and appended to the prompt.

        Returns:
            The JSON object contained in the first choice's message content.

        Raises:
            ValueError: If the response lacks message content or the content
                is not a JSON object.
        """
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": instruction
                    + "\n\n输入 JSON:\n"
                    + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = self.transport(
            _chat_completions_url(self.base_url),
            headers,
            request_payload,
            self.timeout_sec,
        )
        content = _message_content(response)
        parsed = _loads_json_object(content)
        if not isinstance(parsed, dict):
            raise ValueError("LLM 响应必须是 JSON object")
        return parsed


def _bool(payload: dict[str, Any], key: str, default: bool) -> bool:
    """Read a boolean field from an LLM payload without mis-coercing it.

    A plain ``bool(value)`` turns an explicit JSON ``null`` into ``False`` and
    the string ``"false"`` into ``True``. Here a missing key, ``null`` and
    unrecognised or empty strings all yield ``default``; the usual textual
    spellings of true/false are honoured; any other value keeps ``bool(value)``.
    """
    value = payload.get(key)
    if value is None:
        return default
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in ("true", "yes", "1"):
            return True
        if normalized in ("false", "no", "0"):
            return False
        return default
    return bool(value)
def _default_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON with the standard-library urllib.

    Returns the decoded response body, which must be a JSON object; any other
    JSON value raises ``ValueError``.
    """
    body = json.dumps(payload).encode("utf-8")
    http_request = urllib.request.Request(url, data=body, headers=headers, method="POST")
    with urllib.request.urlopen(http_request, timeout=timeout_sec) as http_response:
        text = http_response.read().decode("utf-8")
    result = json.loads(text)
    if isinstance(result, dict):
        return result
    raise ValueError("LLM HTTP 响应必须是 JSON object")
def load_prompt_text(path: str | None) -> str:
    """Load a custom prompt from a file.

    Returns ``ACTION_ANALYSIS_PROMPT`` when no path is given or when the file
    contains only whitespace; otherwise the stripped UTF-8 file content.
    """
    if not path:
        return ACTION_ANALYSIS_PROMPT
    custom_prompt = Path(path).read_text(encoding="utf-8").strip()
    if custom_prompt:
        return custom_prompt
    return ACTION_ANALYSIS_PROMPT
def _chat_completions_url(base_url: str) -> str:
"""把 base_url 规范化为 chat/completions endpoint。"""
clean = base_url.rstrip("/")
if clean.endswith("/chat/completions"):
return clean
return f"{clean}/chat/completions"
def _message_content(response: dict[str, Any]) -> Any:
"""从 OpenAI-compatible 响应中提取 message.content。"""
try:
content = response["choices"][0]["message"]["content"]
except (KeyError, IndexError, TypeError) as exc:
raise ValueError("LLM 响应缺少 choices[0].message.content") from exc
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
elif isinstance(item, str):
parts.append(item)
return "".join(parts)
return content
def _loads_json_object(content: Any) -> Any:
"""把 message.content 解析为 JSON 对象。"""
if isinstance(content, dict):
return content
if not isinstance(content, str):
raise ValueError("LLM message content 必须是 JSON 文本")
return json.loads(content)
def _redact_sensitive(value: Any) -> Any:
    """Recursively mask sensitive fields in data destined for a prompt.

    Dicts are copied with string keys, and every key listed in
    ``SENSITIVE_KEYS`` has its value replaced by ``"***"``. Lists and tuples
    are traversed element by element and keep their container type; tuples are
    included because ``json.dumps`` serialises them as arrays, so a dict nested
    in a tuple would otherwise reach the prompt unredacted. Any other value is
    returned unchanged.
    """
    if isinstance(value, dict):
        redacted: dict[str, Any] = {}
        for key, item in value.items():
            name = str(key)
            redacted[name] = "***" if name in SENSITIVE_KEYS else _redact_sensitive(item)
        return redacted
    if isinstance(value, list):
        return [_redact_sensitive(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_redact_sensitive(item) for item in value)
    return value
def _truncate_text(value: str, limit: int = 1000) -> str:
"""截断发送给 LLM 的长文本,避免传入完整日志。"""
if len(value) <= limit:
return value
return value[:limit] + "...[已截断]"
def _string(payload: dict[str, Any], key: str, default: str) -> str:
"""安全读取字符串字段。"""
value = payload.get(key, default)
return str(value) if value is not None else default
def _float(payload: dict[str, Any], key: str, default: float) -> float:
"""安全读取浮点数字段。"""
try:
return float(payload.get(key, default))
except (TypeError, ValueError):
return default
def _dict(value: Any) -> dict[str, Any]:
"""确保返回 dict非法值降级为空 dict。"""
return value if isinstance(value, dict) else {}
def _string_list(value: Any) -> list[str]:
"""确保返回字符串列表。"""
if isinstance(value, list):
return [str(item) for item in value]
if value in (None, ""):
return []
return [str(value)]