agent_deply/pam_deploy_graph/llm/openai_compatible.py
dark d3f5c82d98 feat: 补充 Agent 运行日志并增加 LLM 测试命令
- 新增统一日志工具,支持日志文件路径和级别配置
- 记录 CLI/chat、Agent、LLM、action、MCP、LangGraph、checkpoint 等关键流程
- 对日志中的 token、secret、api_key、Authorization 等敏感信息做脱敏
- chat 新增 llm test 命令,用于验证当前 LLM client 是否正常加载
- 同步 README、打包文档和 run.sh 帮助说明
- 补充日志脱敏和 llm test 相关测试
2026-06-04 10:51:59 +08:00

361 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OpenAI-compatible HTTP LLM client。
该 client 面向暴露 `/chat/completions` 的模型服务,并使用 OpenAI 风格的
请求/响应结构。实现只依赖 Python 标准库,便于控制运行时依赖体积。
"""
from __future__ import annotations
import json
import logging
import time
from pathlib import Path
import urllib.request
from collections.abc import Callable
from typing import Any
from pam_deploy_graph.constants import (
ALLOWED_ACTIONS,
DEFAULT_PARAMS,
GLOBAL_ACTION_SEQUENCE,
IP_ACTION_SEQUENCE,
REQUIRED_PARAMS,
SENSITIVE_KEYS,
)
from pam_deploy_graph.logging_utils import json_for_log, redact_for_log
from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult
from pam_deploy_graph.models import ActionResult, LlmActionAnalysis
from .prompts import ACTION_ANALYSIS_PROMPT, INTENT_PROMPT, PARAM_PROMPT, PLAN_PROMPT, SYSTEM_PROMPT
# Call shape of a pluggable HTTP transport. The client invokes it as
# transport(url, headers, json_payload, timeout_sec) and uses the returned
# dict as the decoded JSON response body.
JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class OpenAICompatibleLlmClient:
    """Fetch structured LLM output through an OpenAI-compatible HTTP API.

    Each public method sends a single chat-completions request through
    ``_complete_json`` and maps the returned JSON object onto a result model.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        action_analysis_prompt: str | None = None,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
    ) -> None:
        """Store connection settings, model settings and the HTTP transport.

        A falsy ``action_analysis_prompt`` falls back to
        ``ACTION_ANALYSIS_PROMPT``; a falsy ``transport`` falls back to
        ``_default_transport``.

        Raises:
            ValueError: If ``base_url`` or ``model`` is empty.
        """
        if not base_url:
            raise ValueError("必须配置 LLM base_url")
        if not model:
            raise ValueError("必须配置 LLM model")

        uses_custom_transport = transport is not None

        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.action_analysis_prompt = (
            action_analysis_prompt if action_analysis_prompt else ACTION_ANALYSIS_PROMPT
        )
        self.timeout_sec = timeout_sec
        self.temperature = temperature
        self.transport = transport if transport else _default_transport

        # Only the presence of the API key is logged, never its value.
        logger.info(
            "OpenAI-compatible LLM client 初始化 base_url=%s endpoint=%s model=%s has_api_key=%s timeout=%s temperature=%s custom_transport=%s",
            self.base_url,
            _chat_completions_url(self.base_url),
            self.model,
            bool(self.api_key),
            self.timeout_sec,
            self.temperature,
            uses_custom_transport,
        )

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the LLM to classify the intent behind the user's text.

        Fields missing from the reply fall back to the defaults visible in
        the body below.
        """
        reply = self._complete_json("understand_request", INTENT_PROMPT, {"user_text": text})

        intent = _string(reply, "intent", "deploy")
        mode_preference = _string(reply, "mode_preference", "未指定")
        strategy_preference = _string(reply, "strategy_preference", "未指定")
        confidence = _float(reply, "confidence", 0.0)
        reasons = _string_list(reply.get("reasons"))
        needs_clarification = bool(reply.get("needs_clarification", False))
        questions = _string_list(reply.get("clarification_questions"))

        return LlmIntentResult(
            intent=intent,  # type: ignore[arg-type]
            mode_preference=mode_preference,  # type: ignore[arg-type]
            strategy_preference=strategy_preference,  # type: ignore[arg-type]
            confidence=confidence,
            reasons=reasons,
            needs_clarification=needs_clarification,
            clarification_questions=questions,
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Ask the LLM to extract parameters without sending secrets in the prompt.

        ``base_params`` is masked with ``_redact_sensitive`` before it enters
        the prompt. Extracted values are then merged over an unmasked copy of
        ``base_params``; the caller's dict is not modified.
        """
        base_snapshot = dict(base_params or {})
        prompt_input = {
            "user_text": text,
            "base_params": _redact_sensitive(base_snapshot),
            "required_params": list(REQUIRED_PARAMS),
            "default_params": DEFAULT_PARAMS,
        }
        reply = self._complete_json("extract_params", PARAM_PROMPT, prompt_input)

        merged = dict(base_snapshot)
        for name, value in _dict(reply.get("extracted_params")).items():
            # A masked placeholder echoed back for a sensitive key must not
            # overwrite the real value kept in the base parameters.
            echoed_mask = name in SENSITIVE_KEYS and value == "***"
            if not echoed_mask:
                merged[name] = value

        control = _dict(reply.get("extracted_control"))
        still_missing = [name for name in REQUIRED_PARAMS if not merged.get(name)]
        present_secrets = [name for name in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(name)]
        ambiguous = _string_list(reply.get("ambiguous_fields"))

        return LlmParamResult(
            extracted_params=merged,
            extracted_control=control,
            missing_required_params=still_missing,
            ambiguous_fields=ambiguous,
            sensitive_fields_present=present_secrets,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
    ) -> LlmDeployPlan:
        """Ask the LLM to produce a deployment plan.

        ``params`` is masked with ``_redact_sensitive`` before it enters the
        prompt. An empty or missing ``planned_actions`` in the reply falls
        back to a copy of ``GLOBAL_ACTION_SEQUENCE``.
        """
        prompt_input = {
            "params": _redact_sensitive(params),
            "intent": intent,
            "execution_strategy": strategy,
            "allowed_actions": list(ALLOWED_ACTIONS),
            "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
            "ip_action_sequence": list(IP_ACTION_SEQUENCE),
        }
        reply = self._complete_json("generate_plan", PLAN_PROMPT, prompt_input)

        actions = _string_list(reply.get("planned_actions"))
        if not actions:
            actions = list(GLOBAL_ACTION_SEQUENCE)

        return LlmDeployPlan(
            summary=_string(reply, "summary", "PAM 部署计划"),
            risk_notes=_string_list(reply.get("risk_notes")),
            planned_actions=actions,
            requires_confirmation=bool(reply.get("requires_confirmation", True)),
            execution_strategy=_string(reply, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def analyze_action_result(
        self,
        *,
        action: str,
        result: ActionResult,
        state_summary: dict[str, Any],
    ) -> LlmActionAnalysis:
        """Ask the LLM to analyse an action result and return a structured diagnosis.

        The prompt receives a reduced view of ``result``: its values are
        masked with ``_redact_sensitive`` and its stderr is shortened with
        ``_truncate_text``. ``state_summary`` is masked the same way.
        """
        result_view = {
            "backend": result.backend,
            "ok": result.ok,
            "exit_code": result.exit_code,
            "tool_name": result.tool_name,
            "values": _redact_sensitive(result.values),
            "stderr": _truncate_text(result.stderr),
            "error_summary": result.error_summary,
        }
        prompt_input = {
            "action": action,
            "result": result_view,
            "state_summary": _redact_sensitive(state_summary),
        }
        reply = self._complete_json("analyze_action_result", self.action_analysis_prompt, prompt_input)

        return LlmActionAnalysis(
            action=_string(reply, "action", action),
            has_anomaly=bool(reply.get("has_anomaly", False)),
            severity=_string(reply, "severity", "info"),  # type: ignore[arg-type]
            possible_reason=_string(reply, "possible_reason", ""),
            suggested_action=_string(reply, "suggested_action", ""),
            requires_confirmation=bool(reply.get("requires_confirmation", False)),
            should_continue=bool(reply.get("should_continue", True)),
            notes=_string_list(reply.get("notes")),
        )

    def _complete_json(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send one chat/completions request and return the reply as a JSON object.

        The user message is ``instruction`` followed by ``input_payload``
        serialized as JSON. ``operation`` is only used to label log records.

        Raises:
            ValueError: If the reply lacks message content or the content
                does not decode to a JSON object.
            Exception: Anything raised by the transport; every failure is
                logged with its traceback and re-raised unchanged.
        """
        started_at = time.perf_counter()

        def elapsed_ms() -> int:
            """Milliseconds elapsed since this request started."""
            return int((time.perf_counter() - started_at) * 1000)

        endpoint = _chat_completions_url(self.base_url)

        prompt_head = instruction + "\n\n输入 JSON:\n"
        user_content = prompt_head + json.dumps(input_payload, ensure_ascii=False, sort_keys=True)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": messages,
        }

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        logger.info(
            "LLM 请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )

        try:
            response = self.transport(endpoint, headers, request_payload, self.timeout_sec)
            content = _message_content(response)
            logger.info(
                "LLM 原始响应 operation=%s duration_ms=%s content=%s",
                operation,
                elapsed_ms(),
                redact_for_log(content, max_text_len=1600),
            )
            parsed = _loads_json_object(content)
            # Raised inside the try so a wrong reply shape is logged like any
            # other failure.
            if not isinstance(parsed, dict):
                raise ValueError("LLM 响应必须是 JSON object")
        except Exception:
            logger.exception(
                "LLM 请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                elapsed_ms(),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise

        logger.info(
            "LLM 请求完成 operation=%s duration_ms=%s response_keys=%s response=%s",
            operation,
            elapsed_ms(),
            sorted(parsed),
            json_for_log(parsed, max_text_len=1600),
        )
        return parsed
def _default_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON using the standard-library ``urllib``.

    The response body is decoded as UTF-8 and parsed as JSON.

    Raises:
        ValueError: If the parsed response body is not a JSON object.
    """
    body = json.dumps(payload).encode("utf-8")
    http_request = urllib.request.Request(url, data=body, headers=headers, method="POST")
    with urllib.request.urlopen(http_request, timeout=timeout_sec) as http_response:
        text = http_response.read().decode("utf-8")
    result = json.loads(text)
    if isinstance(result, dict):
        return result
    raise ValueError("LLM HTTP 响应必须是 JSON object")
def load_prompt_text(path: str | None) -> str:
    """Load a custom prompt from ``path``.

    Returns the file's UTF-8 text with surrounding whitespace stripped.
    Falls back to ``ACTION_ANALYSIS_PROMPT`` when ``path`` is empty or the
    stripped file content is empty.
    """
    if path:
        custom_prompt = Path(path).read_text(encoding="utf-8").strip()
        if custom_prompt:
            return custom_prompt
    return ACTION_ANALYSIS_PROMPT
def _chat_completions_url(base_url: str) -> str:
    """Normalize ``base_url`` into a chat/completions endpoint URL.

    Trailing slashes are removed; the ``/chat/completions`` suffix is appended
    unless the URL already ends with it.
    """
    suffix = "/chat/completions"
    trimmed = base_url.rstrip("/")
    return trimmed if trimmed.endswith(suffix) else trimmed + suffix
def _message_content(response: dict[str, Any]) -> Any:
    """Extract ``choices[0].message.content`` from an OpenAI-style response.

    When the content is a list, the ``text`` of every ``{"type": "text"}``
    item and every plain string item are concatenated; other items are
    skipped. Any other content value is returned unchanged.

    Raises:
        ValueError: If the response has no ``choices[0].message.content``.
    """
    try:
        first_choice = response["choices"][0]
        content = first_choice["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError("LLM 响应缺少 choices[0].message.content") from exc

    if not isinstance(content, list):
        return content

    def text_of(part: Any) -> str:
        """Return the text carried by one content part, or an empty string."""
        if isinstance(part, dict) and part.get("type") == "text":
            return str(part.get("text", ""))
        if isinstance(part, str):
            return part
        return ""

    return "".join(text_of(part) for part in content)
def _strip_code_fence(text: str) -> str:
    """Return ``text`` without a surrounding Markdown code fence.

    Handles fences with or without a language tag on the opening line. Text
    that is not wrapped in a fence is returned unchanged.
    """
    fence = "```"
    body = text.strip()
    if len(body) < 2 * len(fence) or not (body.startswith(fence) and body.endswith(fence)):
        return text
    inner = body[len(fence):-len(fence)]
    first_line, newline, remainder = inner.partition("\n")
    tag = first_line.strip()
    # Drop the opening line only when it is empty or a language tag such as
    # "json"; otherwise it already belongs to the payload.
    if newline and (not tag or tag.isalnum()):
        inner = remainder
    return inner.strip()


def _loads_json_object(content: Any) -> Any:
    """Parse ``message.content`` into a JSON value.

    A dict is returned as-is. A string is parsed with ``json.loads``; if that
    fails and the text is wrapped in a Markdown code fence, the fence is
    removed and parsing is retried. Some OpenAI-compatible servers ignore
    ``response_format`` and return fenced JSON, which is why the retry exists.

    Raises:
        ValueError: If ``content`` is neither a dict nor a string.
        json.JSONDecodeError: If the text is not valid JSON, with or without
            a fence. This is a ``ValueError`` subclass.
    """
    if isinstance(content, dict):
        return content
    if not isinstance(content, str):
        raise ValueError("LLM message content 必须是 JSON 文本")
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        unfenced = _strip_code_fence(content)
        if unfenced == content:
            # Nothing to strip: surface the original decoding error.
            raise
        return json.loads(unfenced)
def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with sensitive fields masked for prompt input.

    Dicts are rebuilt with every key converted to ``str``. A key found in
    ``SENSITIVE_KEYS`` has its value replaced by ``"***"``; other values are
    processed recursively. Lists and tuples are processed element by element.
    Tuples are included because ``json.dumps`` serializes them as arrays, so
    a secret nested inside a tuple would otherwise reach the prompt unmasked.
    Any other value is returned unchanged.
    """
    if isinstance(value, dict):
        return {
            str(key): "***" if str(key) in SENSITIVE_KEYS else _redact_sensitive(item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_redact_sensitive(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_redact_sensitive(item) for item in value)
    return value
def _truncate_text(value: str | None, limit: int = 1000) -> str:
    """Shorten long text before it is sent to the LLM.

    Args:
        value: Text to shorten. ``None`` (for example, output that was never
            captured) is treated as empty text.
        limit: Maximum number of characters kept; negative values are treated
            as 0.

    Returns:
        ``value`` unchanged when it fits within ``limit``, otherwise its first
        ``limit`` characters followed by a truncation marker.
    """
    if value is None:
        return ""
    # A negative limit would otherwise slice from the end of the text.
    limit = max(limit, 0)
    if len(value) <= limit:
        return value
    return value[:limit] + "...[已截断]"
def _string(payload: dict[str, Any], key: str, default: str) -> str:
    """Read ``payload[key]`` as a string.

    Returns ``default`` when the key is missing or its value is ``None``;
    any other value is converted with ``str``.
    """
    raw = payload.get(key, default)
    if raw is None:
        return default
    return str(raw)
def _float(payload: dict[str, Any], key: str, default: float) -> float:
    """Read ``payload[key]`` as a float, falling back to ``default``.

    ``default`` is returned when the key is missing or the value cannot be
    converted. That includes integers too large for a float: JSON allows
    arbitrarily large integers, and ``float()`` raises ``OverflowError`` for
    them.
    """
    try:
        return float(payload.get(key, default))
    except (TypeError, ValueError, OverflowError):
        return default
def _dict(value: Any) -> dict[str, Any]:
    """Return ``value`` when it is a dict, otherwise a new empty dict."""
    if isinstance(value, dict):
        return value
    return {}
def _string_list(value: Any) -> list[str]:
    """Coerce ``value`` into a list of strings.

    A list has each item converted with ``str``; ``None`` and the empty
    string yield an empty list; any other value becomes a one-item list.
    """
    if isinstance(value, list):
        return list(map(str, value))
    is_blank = value is None or value == ""
    return [] if is_blank else [str(value)]