agent_deply/pam_deploy_graph/llm/openai_compatible.py
dark 4250a7b221 LLM action 结果分析不再传 state_summary
调整了 agent.py 和 LLM client 协议/实现。
现在只传当前 action 的结构化结果和必要诊断日志,避免历史运行态影响判断。
提示词和文档也已同步说明。

verify-ip 增加健康检查重试
默认 VERIFY_INTERVAL_SEC=10、VERIFY_MAX_ATTEMPTS=12,约 2 分钟。
verify-ip 未通过但未达到最大次数时,会播报进度、保存 checkpoint,并继续从当前 verify-ip 重试,不会进入 download-log。
参数已加入 config.txt.example、脚本配置读取、README、打包 README、Skill 文档和流程图。
2026-06-04 16:57:16 +08:00

416 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OpenAI-compatible HTTP LLM client。
该 client 面向暴露 `/chat/completions` 的模型服务,并使用 OpenAI 风格的
请求/响应结构。实现只依赖 Python 标准库,便于控制运行时依赖体积。
"""
from __future__ import annotations
import json
import logging
import time
from pathlib import Path
import urllib.request
from collections.abc import Callable
from typing import Any
from pam_deploy_graph.constants import (
ALLOWED_ACTIONS,
DEFAULT_PARAMS,
GLOBAL_ACTION_SEQUENCE,
IP_ACTION_SEQUENCE,
REQUIRED_PARAMS,
SENSITIVE_KEYS,
)
from pam_deploy_graph.logging_utils import json_for_log, redact_for_log
from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult
from pam_deploy_graph.models import ActionResult, LlmActionAnalysis
from .prompts import ACTION_ANALYSIS_PROMPT, INTENT_PROMPT, PARAM_PROMPT, PLAN_PROMPT, SYSTEM_PROMPT
# Signature of a pluggable HTTP transport. It is called with
# (endpoint URL, request headers, JSON request payload, timeout in seconds)
# and must return the decoded JSON response as a dict.
JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class OpenAICompatibleLlmClient:
    """Obtain structured LLM output through an OpenAI-compatible HTTP API.

    Each public method sends one prompt plus a JSON input document through
    ``_complete_json`` and maps the JSON object that comes back onto a result
    model, filling in defaults for fields the model left out.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        action_analysis_prompt: str | None = None,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
    ) -> None:
        """Store connection settings, model settings and the HTTP transport.

        Args:
            base_url: Service base URL; trailing slashes are stripped.
            api_key: Bearer token. When empty, no Authorization header is sent.
            model: Model name placed in every request.
            action_analysis_prompt: Prompt used by ``analyze_action_result``;
                falls back to ``ACTION_ANALYSIS_PROMPT`` when empty or None.
            timeout_sec: Timeout handed to the transport for each request.
            temperature: Sampling temperature placed in every request.
            transport: Callable performing the HTTP POST; defaults to
                ``_default_transport``.

        Raises:
            ValueError: If ``base_url`` or ``model`` is empty.
        """
        if not base_url:
            raise ValueError("必须配置 LLM base_url")
        if not model:
            raise ValueError("必须配置 LLM model")
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.action_analysis_prompt = action_analysis_prompt or ACTION_ANALYSIS_PROMPT
        self.timeout_sec = timeout_sec
        self.temperature = temperature
        self.transport = transport or _default_transport
        # Only the presence of the API key is logged, never its value.
        logger.info(
            "OpenAI-compatible LLM client 初始化 base_url=%s endpoint=%s model=%s has_api_key=%s timeout=%s temperature=%s custom_transport=%s",
            self.base_url,
            _chat_completions_url(self.base_url),
            self.model,
            bool(self.api_key),
            self.timeout_sec,
            self.temperature,
            transport is not None,
        )

    @staticmethod
    def _coerce_flag(value: Any) -> bool:
        """Convert an LLM-provided flag to ``bool`` without misreading "false".

        Plain ``bool("false")`` is ``True``, which would silently invert a
        flag when the model returns booleans as strings. The recognised
        false tokens are the same ones ``_optional_bool`` accepts; every
        other value keeps the result of ``bool(value)``.
        """
        if isinstance(value, str) and value.strip().lower() in ("false", "0", "no", "n"):
            return False
        return bool(value)

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the LLM to classify the intent of the user's request.

        Args:
            text: Raw user request; sent to the model as ``user_text``.

        Returns:
            An ``LlmIntentResult`` built from the response, with defaults
            applied to any missing field.
        """
        payload = self._complete_json("understand_request", INTENT_PROMPT, {"user_text": text})
        return LlmIntentResult(
            intent=_string(payload, "intent", "deploy"),  # type: ignore[arg-type]
            mode_preference=_string(payload, "mode_preference", "未指定"),  # type: ignore[arg-type]
            strategy_preference=_string(payload, "strategy_preference", "未指定"),  # type: ignore[arg-type]
            confidence=_float(payload, "confidence", 0.0),
            reasons=_string_list(payload.get("reasons")),
            needs_clarification=self._coerce_flag(payload.get("needs_clarification", False)),
            clarification_questions=_string_list(payload.get("clarification_questions")),
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Ask the LLM to extract parameters without sending secrets to it.

        Values of sensitive keys in ``base_params`` are masked before they
        are placed in the prompt input. Extracted values are merged over the
        original (unmasked) base parameters.

        Args:
            text: Raw user request; sent to the model as ``user_text``.
            base_params: Already-known parameters; not mutated.

        Returns:
            An ``LlmParamResult`` holding the merged parameters, extracted
            control values, missing required keys, ambiguous fields reported
            by the model, and which credential keys are present.
        """
        original_base = dict(base_params or {})
        safe_base = _redact_sensitive(original_base)
        payload = self._complete_json(
            "extract_params",
            PARAM_PROMPT,
            {
                "user_text": text,
                "base_params": safe_base,
                "required_params": list(REQUIRED_PARAMS),
                "default_params": DEFAULT_PARAMS,
            },
        )
        extracted = _dict(payload.get("extracted_params"))
        merged = original_base.copy()
        for key, value in extracted.items():
            # The model may echo the mask back; keep the real secret then.
            if key in SENSITIVE_KEYS and value == "***":
                continue
            merged[key] = value
        control = _dict(payload.get("extracted_control"))
        missing = [key for key in REQUIRED_PARAMS if not merged.get(key)]
        sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)]
        llm_ambiguous = _string_list(payload.get("ambiguous_fields"))
        return LlmParamResult(
            extracted_params=merged,
            extracted_control=control,
            missing_required_params=missing,
            ambiguous_fields=llm_ambiguous,
            sensitive_fields_present=sensitive,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
    ) -> LlmDeployPlan:
        """Ask the LLM to produce a deployment plan.

        Args:
            params: Deployment parameters; sensitive values are masked
                before being sent.
            intent: Intent string forwarded to the model.
            strategy: Execution strategy forwarded to the model and used as
                the default when the response omits ``execution_strategy``.

        Returns:
            An ``LlmDeployPlan``. ``planned_actions`` falls back to
            ``GLOBAL_ACTION_SEQUENCE`` when the response has none, and
            ``requires_confirmation`` defaults to True.
        """
        payload = self._complete_json(
            "generate_plan",
            PLAN_PROMPT,
            {
                "params": _redact_sensitive(params),
                "intent": intent,
                "execution_strategy": strategy,
                "allowed_actions": list(ALLOWED_ACTIONS),
                "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
                "ip_action_sequence": list(IP_ACTION_SEQUENCE),
            },
        )
        planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE)
        return LlmDeployPlan(
            summary=_string(payload, "summary", "PAM 部署计划"),
            risk_notes=_string_list(payload.get("risk_notes")),
            planned_actions=planned_actions,
            requires_confirmation=self._coerce_flag(payload.get("requires_confirmation", True)),
            execution_strategy=_string(payload, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def analyze_action_result(
        self,
        *,
        action: str,
        result: ActionResult,
    ) -> LlmActionAnalysis:
        """Ask the LLM to review one action result and return a diagnosis.

        Only the current action name and the payload built by
        ``_action_review_result_payload`` are sent.

        Args:
            action: Name of the action that produced ``result``.
            result: Outcome of that action.

        Returns:
            An ``LlmActionAnalysis`` built from the response, with defaults
            applied to any missing field.
        """
        payload = self._complete_json(
            "analyze_action_result",
            self.action_analysis_prompt,
            {
                "action": action,
                "result": _action_review_result_payload(action, result),
            },
        )
        return LlmActionAnalysis(
            action=_string(payload, "action", action),
            has_anomaly=self._coerce_flag(payload.get("has_anomaly", False)),
            severity=_string(payload, "severity", "info"),  # type: ignore[arg-type]
            possible_reason=_string(payload, "possible_reason", ""),
            suggested_action=_string(payload, "suggested_action", ""),
            requires_confirmation=self._coerce_flag(payload.get("requires_confirmation", False)),
            should_continue=self._coerce_flag(payload.get("should_continue", True)),
            progress_complete=_optional_bool(payload.get("progress_complete")),
            notes=_string_list(payload.get("notes")),
        )

    def _complete_json(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send one chat/completions request and parse the JSON object reply.

        Args:
            operation: Label used only in log messages.
            instruction: Prompt text placed at the start of the user message.
            input_payload: JSON-serialisable input appended to the prompt.

        Returns:
            The parsed JSON object from the first choice's message content.

        Raises:
            ValueError: If the response has no message content or the
                content is not a JSON object.
            Exception: Anything the transport or JSON parsing raises; every
                failure is logged with its traceback and re-raised.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": instruction
                    + "\n\n输入 JSON:\n"
                    + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        logger.info(
            "LLM 请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            response = self.transport(
                endpoint,
                headers,
                request_payload,
                self.timeout_sec,
            )
            content = _message_content(response)
            logger.info(
                "LLM 原始响应 operation=%s duration_ms=%s content=%s",
                operation,
                int((time.perf_counter() - started_at) * 1000),
                redact_for_log(content, max_text_len=1600),
            )
            parsed = _loads_json_object(content)
            if not isinstance(parsed, dict):
                raise ValueError("LLM 响应必须是 JSON object")
        except Exception:
            # Log with traceback for diagnosis, then let the caller decide.
            logger.exception(
                "LLM 请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                int((time.perf_counter() - started_at) * 1000),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 请求完成 operation=%s duration_ms=%s response_keys=%s response=%s",
            operation,
            int((time.perf_counter() - started_at) * 1000),
            sorted(parsed.keys()),
            json_for_log(parsed, max_text_len=1600),
        )
        return parsed
def _default_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON with the standard-library ``urllib``.

    Args:
        url: Target endpoint.
        headers: HTTP headers sent with the request.
        payload: Object serialised to a UTF-8 JSON request body.
        timeout_sec: Timeout passed to ``urlopen``.

    Returns:
        The response body decoded as UTF-8 and parsed as JSON.

    Raises:
        ValueError: If the parsed response is not a JSON object.
    """
    body = json.dumps(payload).encode("utf-8")
    http_request = urllib.request.Request(url, data=body, headers=headers, method="POST")
    with urllib.request.urlopen(http_request, timeout=timeout_sec) as http_response:
        text = http_response.read().decode("utf-8")
    parsed = json.loads(text)
    if isinstance(parsed, dict):
        return parsed
    raise ValueError("LLM HTTP 响应必须是 JSON object")
def load_prompt_text(path: str | None) -> str:
    """Read a custom prompt file, falling back to the built-in prompt.

    Args:
        path: Path of a UTF-8 text file, or an empty value for "none".

    Returns:
        The file content with surrounding whitespace stripped, or
        ``ACTION_ANALYSIS_PROMPT`` when no path is given or the file is blank.
    """
    if path:
        custom_text = Path(path).read_text(encoding="utf-8").strip()
        if custom_text:
            return custom_text
    return ACTION_ANALYSIS_PROMPT
def _chat_completions_url(base_url: str) -> str:
"""把 base_url 规范化为 chat/completions endpoint。"""
clean = base_url.rstrip("/")
if clean.endswith("/chat/completions"):
return clean
return f"{clean}/chat/completions"
def _message_content(response: dict[str, Any]) -> Any:
"""从 OpenAI-compatible 响应中提取 message.content。"""
try:
content = response["choices"][0]["message"]["content"]
except (KeyError, IndexError, TypeError) as exc:
raise ValueError("LLM 响应缺少 choices[0].message.content") from exc
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
elif isinstance(item, str):
parts.append(item)
return "".join(parts)
return content
def _loads_json_object(content: Any) -> Any:
"""把 message.content 解析为 JSON 对象。"""
if isinstance(content, dict):
return content
if not isinstance(content, str):
raise ValueError("LLM message content 必须是 JSON 文本")
return json.loads(content)
def _redact_sensitive(value: Any) -> Any:
"""递归脱敏 prompt 输入中的敏感字段。"""
if isinstance(value, dict):
redacted: dict[str, Any] = {}
for key, item in value.items():
if str(key) in SENSITIVE_KEYS:
redacted[str(key)] = "***"
else:
redacted[str(key)] = _redact_sensitive(item)
return redacted
if isinstance(value, list):
return [_redact_sensitive(item) for item in value]
return value
def _action_review_result_payload(action: str, result: ActionResult) -> dict[str, Any]:
    """Build the review input for one action result.

    The structured fields are always included, with sensitive values
    masked. A ``diagnostic_log`` entry is added only when
    ``_needs_diagnostic_log`` asks for it and a non-empty excerpt exists,
    so routine script output is not presented to the LLM as an error.
    """
    review_input: dict[str, Any] = {
        "backend": result.backend,
        "ok": result.ok,
        "exit_code": result.exit_code,
        "tool_name": result.tool_name,
        "values": _redact_sensitive(result.values),
        "error_summary": result.error_summary,
    }
    log_excerpt = _diagnostic_log_text(result) if _needs_diagnostic_log(action, result) else ""
    if log_excerpt:
        review_input["diagnostic_log"] = log_excerpt
    return review_input
def _needs_diagnostic_log(action: str, result: ActionResult) -> bool:
"""仅在失败或业务异常时把少量诊断日志交给 LLM。"""
if not result.ok or result.error_summary or result.values.get("PENDING_AGENT_CONFIRMATION"):
return True
if action == "verify-ip":
success = result.values.get("SUCCESS")
if success is not None and str(success).lower() not in ("true", "1", "yes"):
return True
return False
def _diagnostic_log_text(result: ActionResult) -> str:
"""优先使用错误摘要;必要时取 stderr/stdout/raw_output 的尾部作为诊断上下文。"""
if result.error_summary:
return _truncate_text(result.error_summary)
for text in (result.stderr, result.stdout, result.raw_output):
stripped = text.strip()
if stripped:
return _tail_text(stripped)
return ""
def _truncate_text(value: str, limit: int = 1000) -> str:
"""截断发送给 LLM 的长文本,避免传入完整日志。"""
if len(value) <= limit:
return value
return value[:limit] + "...[已截断]"
def _tail_text(value: str, limit: int = 1000) -> str:
"""保留长诊断日志尾部,通常错误原因更靠近末尾。"""
if len(value) <= limit:
return value
marker = "[已截断]..."
return marker + value[-(limit - len(marker)) :]
def _string(payload: dict[str, Any], key: str, default: str) -> str:
"""安全读取字符串字段。"""
value = payload.get(key, default)
return str(value) if value is not None else default
def _float(payload: dict[str, Any], key: str, default: float) -> float:
"""安全读取浮点数字段。"""
try:
return float(payload.get(key, default))
except (TypeError, ValueError):
return default
def _optional_bool(value: Any) -> bool | None:
"""解析可选布尔值,字段缺失时保留 None。"""
if value is None:
return None
if isinstance(value, bool):
return value
if isinstance(value, str):
lowered = value.strip().lower()
if lowered in ("", "null", "none"):
return None
if lowered in ("true", "1", "yes", "y"):
return True
if lowered in ("false", "0", "no", "n"):
return False
return bool(value)
def _dict(value: Any) -> dict[str, Any]:
"""确保返回 dict非法值降级为空 dict。"""
return value if isinstance(value, dict) else {}
def _string_list(value: Any) -> list[str]:
"""确保返回字符串列表。"""
if isinstance(value, list):
return [str(item) for item in value]
if value in (None, ""):
return []
return [str(value)]