- 扩展 LLM client 协议,支持普通对话、日志分析和单 action 解析 - chat 非内置输入默认进入 LLM 普通对话,不再本地拦截问候 - 新增 ask、log analyze、action propose、action run 等交互命令 - 单 action 执行前强制人工确认,并复用现有 ActionRouter、审核、事件和 checkpoint - 日志分析默认读取尾部内容并脱敏后再提交给 LLM - 更新 README、发布包 README 和 run.sh help - 补充 LLM 与 chat 交互相关测试
533 lines
19 KiB
Python
533 lines
19 KiB
Python
"""OpenAI-compatible HTTP LLM client。
|
||
|
||
该 client 面向暴露 `/chat/completions` 的模型服务,并使用 OpenAI 风格的
|
||
请求/响应结构。实现只依赖 Python 标准库,便于控制运行时依赖体积。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
from pathlib import Path
|
||
import urllib.request
|
||
from collections.abc import Callable
|
||
from typing import Any
|
||
|
||
from pam_deploy_graph.constants import (
|
||
ALLOWED_ACTIONS,
|
||
DEFAULT_PARAMS,
|
||
GLOBAL_ACTION_SEQUENCE,
|
||
IP_ACTION_SEQUENCE,
|
||
REQUIRED_PARAMS,
|
||
SENSITIVE_KEYS,
|
||
)
|
||
from pam_deploy_graph.logging_utils import json_for_log, redact_for_log
|
||
from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult, LlmSingleActionProposal
|
||
from pam_deploy_graph.models import ActionResult, LlmActionAnalysis
|
||
|
||
from .prompts import (
|
||
ACTION_ANALYSIS_PROMPT,
|
||
CHAT_PROMPT,
|
||
INTENT_PROMPT,
|
||
LOG_ANALYSIS_PROMPT,
|
||
PARAM_PROMPT,
|
||
PLAN_PROMPT,
|
||
SINGLE_ACTION_PROMPT,
|
||
SYSTEM_PROMPT,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)

# Signature of the pluggable HTTP transport:
# (url, headers, json_payload, timeout_sec) -> decoded JSON response object.
JsonTransport = Callable[
    [str, dict[str, str], dict[str, Any], float],
    dict[str, Any],
]
|
||
|
||
|
||
class OpenAICompatibleLlmClient:
    """Obtain structured LLM output through an OpenAI-compatible HTTP API.

    Every call targets the ``/chat/completions`` endpoint derived from
    ``base_url``. Structured operations (intent, params, plan, action
    analysis, single-action proposal) go through ``_complete_json`` and ask
    for a JSON object; free-text operations (chat, log analysis) go through
    ``_complete_text``. Parameter/context dicts are passed through
    ``_redact_sensitive`` (values of ``SENSITIVE_KEYS`` become ``"***"``)
    before being placed in a prompt.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        action_analysis_prompt: str | None = None,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
    ) -> None:
        """Store connection settings, model settings and the pluggable HTTP transport.

        Args:
            base_url: Service base URL; trailing ``/`` characters are stripped.
            api_key: Bearer token; when empty no ``Authorization`` header is sent.
            model: Model name placed in every request.
            action_analysis_prompt: Optional override of the action-analysis
                instruction; falls back to ``ACTION_ANALYSIS_PROMPT``.
            timeout_sec: Per-request timeout forwarded to the transport.
            temperature: Sampling temperature placed in every request.
            transport: Callable ``(url, headers, payload, timeout) -> dict``;
                defaults to ``_default_transport`` (urllib based).

        Raises:
            ValueError: If ``base_url`` or ``model`` is empty.
        """
        if not base_url:
            raise ValueError("必须配置 LLM base_url")
        if not model:
            raise ValueError("必须配置 LLM model")
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.action_analysis_prompt = action_analysis_prompt or ACTION_ANALYSIS_PROMPT
        self.timeout_sec = timeout_sec
        self.temperature = temperature
        self.transport = transport or _default_transport
        # Only whether a key is configured is logged, never the key itself.
        logger.info(
            "OpenAI-compatible LLM client 初始化 base_url=%s endpoint=%s model=%s has_api_key=%s timeout=%s temperature=%s custom_transport=%s",
            self.base_url,
            _chat_completions_url(self.base_url),
            self.model,
            bool(self.api_key),
            self.timeout_sec,
            self.temperature,
            transport is not None,
        )

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the LLM to classify the user's intent.

        Missing fields fall back to defaults: intent ``"deploy"``, preferences
        ``"未指定"``, confidence ``0.0``, no clarification needed.
        """
        payload = self._complete_json("understand_request", INTENT_PROMPT, {"user_text": text})
        return LlmIntentResult(
            intent=_string(payload, "intent", "deploy"),  # type: ignore[arg-type]
            mode_preference=_string(payload, "mode_preference", "未指定"),  # type: ignore[arg-type]
            strategy_preference=_string(payload, "strategy_preference", "未指定"),  # type: ignore[arg-type]
            confidence=_float(payload, "confidence", 0.0),
            reasons=_string_list(payload.get("reasons")),
            needs_clarification=_bool_field(payload, "needs_clarification", False),
            clarification_questions=_string_list(payload.get("clarification_questions")),
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Ask the LLM to extract deployment parameters without leaking secrets.

        ``base_params`` is sent redacted; the LLM's extracted values are merged
        over the *unredacted* copy, so echoed ``"***"`` placeholders never
        replace real sensitive values.

        Returns:
            The merged params, extracted control options, required params still
            missing, LLM-reported ambiguous fields, and which of
            ``CLIENT_SECRET``/``CLIENT_ID`` are set.
        """
        original_base = dict(base_params or {})
        payload = self._complete_json(
            "extract_params",
            PARAM_PROMPT,
            {
                "user_text": text,
                "base_params": _redact_sensitive(original_base),
                "required_params": list(REQUIRED_PARAMS),
                "default_params": DEFAULT_PARAMS,
            },
        )

        merged = original_base.copy()
        for key, value in _dict(payload.get("extracted_params")).items():
            if _is_redaction_placeholder(key, value):
                continue
            merged[key] = value

        return LlmParamResult(
            extracted_params=merged,
            extracted_control=_dict(payload.get("extracted_control")),
            missing_required_params=[key for key in REQUIRED_PARAMS if not merged.get(key)],
            ambiguous_fields=_string_list(payload.get("ambiguous_fields")),
            sensitive_fields_present=[key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)],
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
    ) -> LlmDeployPlan:
        """Ask the LLM for a deployment plan.

        An empty ``planned_actions`` falls back to ``GLOBAL_ACTION_SEQUENCE``.
        ``requires_confirmation`` defaults to True when missing or null.
        """
        payload = self._complete_json(
            "generate_plan",
            PLAN_PROMPT,
            {
                "params": _redact_sensitive(params),
                "intent": intent,
                "execution_strategy": strategy,
                "allowed_actions": list(ALLOWED_ACTIONS),
                "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
                "ip_action_sequence": list(IP_ACTION_SEQUENCE),
            },
        )
        planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE)
        return LlmDeployPlan(
            summary=_string(payload, "summary", "PAM 部署计划"),
            risk_notes=_string_list(payload.get("risk_notes")),
            planned_actions=planned_actions,
            requires_confirmation=_bool_field(payload, "requires_confirmation", True),
            execution_strategy=_string(payload, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def analyze_action_result(
        self,
        *,
        action: str,
        result: ActionResult,
    ) -> LlmActionAnalysis:
        """Ask the LLM to review one action result and return structured advice.

        Boolean flags accept string spellings such as ``"false"``; missing
        flags default to: no anomaly, no confirmation needed, continue.
        """
        payload = self._complete_json(
            "analyze_action_result",
            self.action_analysis_prompt,
            {
                "action": action,
                "result": _action_review_result_payload(action, result),
            },
        )
        return LlmActionAnalysis(
            action=_string(payload, "action", action),
            has_anomaly=_bool_field(payload, "has_anomaly", False),
            severity=_string(payload, "severity", "info"),  # type: ignore[arg-type]
            possible_reason=_string(payload, "possible_reason", ""),
            suggested_action=_string(payload, "suggested_action", ""),
            requires_confirmation=_bool_field(payload, "requires_confirmation", False),
            should_continue=_bool_field(payload, "should_continue", True),
            progress_complete=_optional_bool(payload.get("progress_complete")),
            notes=_string_list(payload.get("notes")),
        )

    def chat(self, text: str, context: dict[str, Any] | None = None) -> str:
        """Free-form conversation; no JSON reply is required.

        ``context`` is redacted with ``_redact_sensitive`` before sending.
        """
        return self._complete_text(
            "chat",
            CHAT_PROMPT,
            {
                "user_text": text,
                "context": _redact_sensitive(context or {}),
            },
        )

    def analyze_log(self, log_text: str, question: str | None = None, source_path: str = "") -> str:
        """Ask the LLM to analyze a log excerpt and return its free-text answer.

        ``log_text`` goes through ``redact_for_log(max_text_len=64000)``;
        presumably that masks secrets and caps the length (TODO confirm in
        logging_utils which end it keeps). ``question`` defaults to a generic
        "find anomalies, causes and next steps" request.
        """
        return self._complete_text(
            "analyze_log",
            LOG_ANALYSIS_PROMPT,
            {
                "source_path": source_path,
                "question": question or "请分析日志中的异常、可能原因和下一步建议。",
                "log_tail": redact_for_log(log_text, max_text_len=64000),
            },
        )

    def propose_action(
        self,
        text: str,
        allowed_actions: list[str],
        params: dict[str, Any],
        state_summary: dict[str, Any] | None = None,
    ) -> LlmSingleActionProposal:
        """Turn natural language into a single-action call proposal.

        Args:
            text: The user's request.
            allowed_actions: Allow-list; any other action name becomes ``""``.
            params: Current params, sent redacted.
            state_summary: Optional state snapshot, sent redacted.

        Returns:
            A proposal that always has ``requires_confirmation=True``. Sensitive
            kwargs whose value is the ``"***"`` placeholder are dropped so the
            placeholder is never executed as a real value.
        """
        payload = self._complete_json(
            "propose_action",
            SINGLE_ACTION_PROMPT,
            {
                "user_text": text,
                "allowed_actions": allowed_actions,
                "params": _redact_sensitive(params),
                "state_summary": _redact_sensitive(state_summary or {}),
            },
        )
        action = _string(payload, "action", "").strip()
        if action not in allowed_actions:
            # Never trust an action name outside the caller's allow-list.
            action = ""
        kwargs = {
            key: value
            for key, value in _dict(payload.get("kwargs")).items()
            if not _is_redaction_placeholder(key, value)
        }
        return LlmSingleActionProposal(
            action=action,
            ip=_string(payload, "ip", "").strip(),
            kwargs=kwargs,
            reason=_string(payload, "reason", ""),
            risk_level=_risk_level(payload.get("risk_level")),
            requires_confirmation=True,
        )

    def _request_headers(self) -> dict[str, str]:
        """Build JSON request headers, adding a Bearer token only when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _complete_json(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send a chat/completions request and parse the reply as a JSON object.

        ``SYSTEM_PROMPT`` is the system message; ``instruction`` plus the
        serialized ``input_payload`` form the user message, and
        ``response_format`` asks the server for a JSON object.

        Raises:
            ValueError: If the reply lacks content or is not a JSON object
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
            Exception: Anything the transport raises is logged and re-raised.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": instruction
                    + "\n\n输入 JSON:\n"
                    + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = self._request_headers()
        logger.info(
            "LLM 请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            response = self.transport(endpoint, headers, request_payload, self.timeout_sec)
            content = _message_content(response)
            logger.info(
                "LLM 原始响应 operation=%s duration_ms=%s content=%s",
                operation,
                _elapsed_ms(started_at),
                redact_for_log(content, max_text_len=1600),
            )
            parsed = _loads_json_object(content)
            if not isinstance(parsed, dict):
                raise ValueError("LLM 响应必须是 JSON object")
        except Exception:
            logger.exception(
                "LLM 请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                _elapsed_ms(started_at),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 请求完成 operation=%s duration_ms=%s response_keys=%s response=%s",
            operation,
            _elapsed_ms(started_at),
            sorted(parsed.keys()),
            json_for_log(parsed, max_text_len=1600),
        )
        return parsed

    def _complete_text(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> str:
        """Send a chat/completions request and return the plain-text reply.

        Unlike ``_complete_json``, ``instruction`` itself is the system message,
        the user message carries only the serialized ``input_payload``, and no
        ``response_format`` is requested.

        Returns:
            The reply text with surrounding whitespace stripped.

        Raises:
            ValueError: If ``choices[0].message.content`` is missing or null.
            Exception: Anything the transport raises is logged and re-raised.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "messages": [
                {"role": "system", "content": instruction},
                {
                    "role": "user",
                    "content": "输入 JSON:\n" + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = self._request_headers()
        logger.info(
            "LLM 文本请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            response = self.transport(endpoint, headers, request_payload, self.timeout_sec)
            raw_content = _message_content(response)
            if raw_content is None:
                # A null content (e.g. a refusal) must not surface as the text "None".
                raise ValueError("LLM 响应 message.content 为空")
            content = str(raw_content)
        except Exception:
            logger.exception(
                "LLM 文本请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                _elapsed_ms(started_at),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 文本请求完成 operation=%s duration_ms=%s content=%s",
            operation,
            _elapsed_ms(started_at),
            redact_for_log(content, max_text_len=1600),
        )
        return content.strip()


def _bool_field(payload: dict[str, Any], key: str, default: bool) -> bool:
    """Read a boolean flag from an LLM payload, tolerating string spellings.

    ``bool("false")`` is True, so values go through ``_optional_bool``, which
    maps ``"false"``/``"no"``/``"0"`` to False. A missing key, ``null`` or a
    blank string yields ``default``.
    """
    parsed = _optional_bool(payload.get(key))
    return default if parsed is None else parsed


def _is_redaction_placeholder(key: Any, value: Any) -> bool:
    """Return True when ``value`` is the ``"***"`` mask for a sensitive ``key``.

    The LLM only ever sees masked sensitive values, so an echoed mask must not
    be treated as real data.
    """
    return key in SENSITIVE_KEYS and value == "***"


def _elapsed_ms(started_at: float) -> int:
    """Return whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
    return int((time.perf_counter() - started_at) * 1000)
|
||
|
||
|
||
def _default_transport(
|
||
url: str,
|
||
headers: dict[str, str],
|
||
payload: dict[str, Any],
|
||
timeout_sec: float,
|
||
) -> dict[str, Any]:
|
||
"""使用标准库 urllib 发送 JSON POST 请求。"""
|
||
request = urllib.request.Request(
|
||
url,
|
||
data=json.dumps(payload).encode("utf-8"),
|
||
headers=headers,
|
||
method="POST",
|
||
)
|
||
with urllib.request.urlopen(request, timeout=timeout_sec) as response:
|
||
raw = response.read().decode("utf-8")
|
||
decoded = json.loads(raw)
|
||
if not isinstance(decoded, dict):
|
||
raise ValueError("LLM HTTP 响应必须是 JSON object")
|
||
return decoded
|
||
|
||
|
||
def load_prompt_text(path: str | None) -> str:
    """Load a custom action-analysis prompt from a UTF-8 text file.

    Returns ``ACTION_ANALYSIS_PROMPT`` when ``path`` is empty/None or the file
    contains only whitespace; otherwise returns the file text stripped.
    Errors from reading the file (e.g. a missing path) propagate.
    """
    if path:
        text = Path(path).read_text(encoding="utf-8").strip()
        if text:
            return text
    return ACTION_ANALYSIS_PROMPT
|
||
|
||
|
||
def _chat_completions_url(base_url: str) -> str:
|
||
"""把 base_url 规范化为 chat/completions endpoint。"""
|
||
clean = base_url.rstrip("/")
|
||
if clean.endswith("/chat/completions"):
|
||
return clean
|
||
return f"{clean}/chat/completions"
|
||
|
||
|
||
def _message_content(response: dict[str, Any]) -> Any:
|
||
"""从 OpenAI-compatible 响应中提取 message.content。"""
|
||
try:
|
||
content = response["choices"][0]["message"]["content"]
|
||
except (KeyError, IndexError, TypeError) as exc:
|
||
raise ValueError("LLM 响应缺少 choices[0].message.content") from exc
|
||
if isinstance(content, list):
|
||
parts: list[str] = []
|
||
for item in content:
|
||
if isinstance(item, dict) and item.get("type") == "text":
|
||
parts.append(str(item.get("text", "")))
|
||
elif isinstance(item, str):
|
||
parts.append(item)
|
||
return "".join(parts)
|
||
return content
|
||
|
||
|
||
def _loads_json_object(content: Any) -> Any:
|
||
"""把 message.content 解析为 JSON 对象。"""
|
||
if isinstance(content, dict):
|
||
return content
|
||
if not isinstance(content, str):
|
||
raise ValueError("LLM message content 必须是 JSON 文本")
|
||
return json.loads(content)
|
||
|
||
|
||
def _redact_sensitive(value: Any) -> Any:
|
||
"""递归脱敏 prompt 输入中的敏感字段。"""
|
||
if isinstance(value, dict):
|
||
redacted: dict[str, Any] = {}
|
||
for key, item in value.items():
|
||
if str(key) in SENSITIVE_KEYS:
|
||
redacted[str(key)] = "***"
|
||
else:
|
||
redacted[str(key)] = _redact_sensitive(item)
|
||
return redacted
|
||
if isinstance(value, list):
|
||
return [_redact_sensitive(item) for item in value]
|
||
return value
|
||
|
||
|
||
def _action_review_result_payload(action: str, result: ActionResult) -> dict[str, Any]:
    """Build the action-review input sent to the LLM.

    Always includes the result's status fields, with ``values`` redacted.
    A ``diagnostic_log`` excerpt is added only when ``_needs_diagnostic_log``
    reports a failure/anomaly and a non-empty excerpt exists. This keeps
    normal script output from being presented to the LLM as an error.
    """
    review: dict[str, Any] = {
        "backend": result.backend,
        "ok": result.ok,
        "exit_code": result.exit_code,
        "tool_name": result.tool_name,
        "values": _redact_sensitive(result.values),
        "error_summary": result.error_summary,
    }
    excerpt = _diagnostic_log_text(result) if _needs_diagnostic_log(action, result) else ""
    if excerpt:
        review["diagnostic_log"] = excerpt
    return review
|
||
|
||
|
||
def _needs_diagnostic_log(action: str, result: ActionResult) -> bool:
|
||
"""仅在失败或业务异常时把少量诊断日志交给 LLM。"""
|
||
if not result.ok or result.error_summary or result.values.get("PENDING_AGENT_CONFIRMATION"):
|
||
return True
|
||
if action == "verify-ip":
|
||
success = result.values.get("SUCCESS")
|
||
if success is not None and str(success).lower() not in ("true", "1", "yes"):
|
||
return True
|
||
return False
|
||
|
||
|
||
def _diagnostic_log_text(result: ActionResult) -> str:
|
||
"""优先使用错误摘要;必要时取 stderr/stdout/raw_output 的尾部作为诊断上下文。"""
|
||
if result.error_summary:
|
||
return _truncate_text(result.error_summary)
|
||
for text in (result.stderr, result.stdout, result.raw_output):
|
||
stripped = text.strip()
|
||
if stripped:
|
||
return _tail_text(stripped)
|
||
return ""
|
||
|
||
|
||
def _truncate_text(value: str, limit: int = 1000) -> str:
|
||
"""截断发送给 LLM 的长文本,避免传入完整日志。"""
|
||
if len(value) <= limit:
|
||
return value
|
||
return value[:limit] + "...[已截断]"
|
||
|
||
|
||
def _tail_text(value: str, limit: int = 1000) -> str:
|
||
"""保留长诊断日志尾部,通常错误原因更靠近末尾。"""
|
||
if len(value) <= limit:
|
||
return value
|
||
marker = "[已截断]..."
|
||
return marker + value[-(limit - len(marker)) :]
|
||
|
||
|
||
def _string(payload: dict[str, Any], key: str, default: str) -> str:
|
||
"""安全读取字符串字段。"""
|
||
value = payload.get(key, default)
|
||
return str(value) if value is not None else default
|
||
|
||
|
||
def _float(payload: dict[str, Any], key: str, default: float) -> float:
|
||
"""安全读取浮点数字段。"""
|
||
try:
|
||
return float(payload.get(key, default))
|
||
except (TypeError, ValueError):
|
||
return default
|
||
|
||
|
||
def _optional_bool(value: Any) -> bool | None:
|
||
"""解析可选布尔值,字段缺失时保留 None。"""
|
||
if value is None:
|
||
return None
|
||
if isinstance(value, bool):
|
||
return value
|
||
if isinstance(value, str):
|
||
lowered = value.strip().lower()
|
||
if lowered in ("", "null", "none"):
|
||
return None
|
||
if lowered in ("true", "1", "yes", "y"):
|
||
return True
|
||
if lowered in ("false", "0", "no", "n"):
|
||
return False
|
||
return bool(value)
|
||
|
||
|
||
def _risk_level(value: Any) -> str:
|
||
"""解析单 action 风险等级,非法值降级为 medium。"""
|
||
text = str(value or "").strip().lower()
|
||
if text in ("low", "medium", "high"):
|
||
return text
|
||
return "medium"
|
||
|
||
|
||
def _dict(value: Any) -> dict[str, Any]:
|
||
"""确保返回 dict,非法值降级为空 dict。"""
|
||
return value if isinstance(value, dict) else {}
|
||
|
||
|
||
def _string_list(value: Any) -> list[str]:
|
||
"""确保返回字符串列表。"""
|
||
if isinstance(value, list):
|
||
return [str(item) for item in value]
|
||
if value in (None, ""):
|
||
return []
|
||
return [str(value)]
|