"""OpenAI-compatible HTTP LLM client.

This client targets model services that expose a ``/chat/completions``
endpoint and use OpenAI-style request/response structures. HTTP handling
relies only on the Python standard library (``urllib``) to keep the runtime
dependency footprint small.
"""

from __future__ import annotations

import json
import logging
import time
import urllib.request
from collections.abc import Callable, Iterable, Iterator
from pathlib import Path
from typing import Any

from pam_deploy_graph.constants import (
    ALLOWED_ACTIONS,
    DEFAULT_PARAMS,
    GLOBAL_ACTION_SEQUENCE,
    IP_ACTION_SEQUENCE,
    REQUIRED_PARAMS,
    SENSITIVE_KEYS,
)
from pam_deploy_graph.logging_utils import json_for_log, redact_for_log
from pam_deploy_graph.models import (
    ActionResult,
    ExecutionStrategy,
    LlmActionAnalysis,
    LlmDeployPlan,
    LlmIntentResult,
    LlmParamResult,
    LlmSingleActionProposal,
)

from .prompts import (
    ACTION_ANALYSIS_PROMPT,
    CHAT_PROMPT,
    INTENT_PROMPT,
    LOG_ANALYSIS_PROMPT,
    PARAM_PROMPT,
    PLAN_PROMPT,
    SINGLE_ACTION_PROMPT,
    SYSTEM_PROMPT,
)
from .text_filter import filter_thinking_chunks, strip_thinking_text

# (url, headers, payload, timeout_sec) -> decoded JSON response body.
JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]]
# (url, headers, payload, timeout_sec) -> iterable of streamed text chunks.
StreamTransport = Callable[[str, dict[str, str], dict[str, Any], float], Iterable[str]]

logger = logging.getLogger(__name__)


class OpenAICompatibleLlmClient:
    """Obtain structured LLM output through an OpenAI-compatible HTTP API.

    HTTP I/O is delegated to the ``transport`` / ``stream_transport``
    callables, so callers (e.g. tests) can substitute their own; the
    defaults are the ``urllib``-based module-level transports.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        action_analysis_prompt: str | None = None,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
        stream_transport: StreamTransport | None = None,
    ) -> None:
        """Store connection settings, model settings and the replaceable transports.

        Args:
            base_url: Service base URL; trailing slashes are stripped. It may
                already end with ``/chat/completions``.
            api_key: Bearer token; when empty no ``Authorization`` header is sent.
            model: Model name placed in every request.
            action_analysis_prompt: Optional instruction override for
                ``analyze_action_result``; falls back to ``ACTION_ANALYSIS_PROMPT``
                when falsy.
            timeout_sec: Timeout handed to the transport for each request.
            temperature: Sampling temperature placed in every request.
            transport: Non-streaming transport; defaults to ``_default_transport``.
            stream_transport: Streaming transport; defaults to
                ``_default_stream_transport``.

        Raises:
            ValueError: If ``base_url`` or ``model`` is empty.
        """
        if not base_url:
            raise ValueError("必须配置 LLM base_url")
        if not model:
            raise ValueError("必须配置 LLM model")
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.action_analysis_prompt = action_analysis_prompt or ACTION_ANALYSIS_PROMPT
        self.timeout_sec = timeout_sec
        self.temperature = temperature
        self.transport = transport or _default_transport
        self.stream_transport = stream_transport or _default_stream_transport
        logger.info(
            "OpenAI-compatible LLM client 初始化 base_url=%s endpoint=%s model=%s has_api_key=%s timeout=%s temperature=%s custom_transport=%s custom_stream_transport=%s",
            self.base_url,
            _chat_completions_url(self.base_url),
            self.model,
            bool(self.api_key),
            self.timeout_sec,
            self.temperature,
            transport is not None,
            stream_transport is not None,
        )

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the LLM to classify the user's intent.

        Missing or malformed fields fall back to defaults (intent ``"deploy"``,
        preferences ``"未指定"``, confidence ``0.0``).
        """
        payload = self._complete_json("understand_request", INTENT_PROMPT, {"user_text": text})
        return LlmIntentResult(
            intent=_string(payload, "intent", "deploy"),  # type: ignore[arg-type]
            mode_preference=_string(payload, "mode_preference", "未指定"),  # type: ignore[arg-type]
            strategy_preference=_string(payload, "strategy_preference", "未指定"),  # type: ignore[arg-type]
            confidence=_float(payload, "confidence", 0.0),
            reasons=_string_list(payload.get("reasons")),
            needs_clarification=_bool(payload, "needs_clarification", False),
            clarification_questions=_string_list(payload.get("clarification_questions")),
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Ask the LLM to extract deployment parameters from free text.

        Sensitive values in ``base_params`` are replaced with ``"***"`` before
        they are put in the prompt. If the LLM echoes ``"***"`` back for a
        sensitive key, that placeholder is ignored so the original value
        survives the merge.

        Args:
            text: The user's natural-language request.
            base_params: Already-known parameters; LLM output is merged on top.

        Returns:
            The merged parameters, plus the required params that are still
            missing, the fields the LLM flagged as ambiguous, and which
            credential fields are present.
        """
        original_base = dict(base_params or {})
        safe_base = _redact_sensitive(original_base)
        payload = self._complete_json(
            "extract_params",
            PARAM_PROMPT,
            {
                "user_text": text,
                "base_params": safe_base,
                "required_params": list(REQUIRED_PARAMS),
                "default_params": DEFAULT_PARAMS,
            },
        )
        extracted = _dict(payload.get("extracted_params"))
        merged = original_base.copy()
        for key, value in extracted.items():
            # Never let a redaction placeholder overwrite the real secret.
            if key in SENSITIVE_KEYS and value == "***":
                continue
            merged[key] = value
        control = _dict(payload.get("extracted_control"))
        missing = [key for key in REQUIRED_PARAMS if not merged.get(key)]
        sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)]
        llm_ambiguous = _string_list(payload.get("ambiguous_fields"))
        return LlmParamResult(
            extracted_params=merged,
            extracted_control=control,
            missing_required_params=missing,
            ambiguous_fields=llm_ambiguous,
            sensitive_fields_present=sensitive,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
    ) -> LlmDeployPlan:
        """Ask the LLM for a deployment plan.

        ``params`` are redacted before being sent. When the LLM returns no
        ``planned_actions``, ``GLOBAL_ACTION_SEQUENCE`` is used. When the
        ``requires_confirmation`` key is absent, it defaults to ``True``.
        """
        payload = self._complete_json(
            "generate_plan",
            PLAN_PROMPT,
            {
                "params": _redact_sensitive(params),
                "intent": intent,
                "execution_strategy": strategy,
                "allowed_actions": list(ALLOWED_ACTIONS),
                "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
                "ip_action_sequence": list(IP_ACTION_SEQUENCE),
            },
        )
        planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE)
        return LlmDeployPlan(
            summary=_string(payload, "summary", "PAM 部署计划"),
            risk_notes=_string_list(payload.get("risk_notes")),
            planned_actions=planned_actions,
            requires_confirmation=_bool(payload, "requires_confirmation", True),
            execution_strategy=_string(payload, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def analyze_action_result(
        self,
        *,
        action: str,
        result: ActionResult,
    ) -> LlmActionAnalysis:
        """Ask the LLM to review an action result and return structured diagnostics.

        Only a compact, redacted view of ``result`` is sent. Diagnostic log
        text is included only when ``_needs_diagnostic_log`` says so.
        """
        payload = self._complete_json(
            "analyze_action_result",
            self.action_analysis_prompt,
            {
                "action": action,
                "result": _action_review_result_payload(action, result),
            },
        )
        return LlmActionAnalysis(
            action=_string(payload, "action", action),
            has_anomaly=_bool(payload, "has_anomaly", False),
            severity=_string(payload, "severity", "info"),  # type: ignore[arg-type]
            possible_reason=_string(payload, "possible_reason", ""),
            suggested_action=_string(payload, "suggested_action", ""),
            requires_confirmation=_bool(payload, "requires_confirmation", False),
            should_continue=_bool(payload, "should_continue", True),
            progress_complete=_optional_bool(payload.get("progress_complete")),
            notes=_string_list(payload.get("notes")),
        )

    def chat(self, text: str, context: dict[str, Any] | None = None) -> str:
        """Run a plain (non-JSON) chat turn; ``context`` is redacted before sending."""
        return self._complete_text(
            "chat",
            CHAT_PROMPT,
            {
                "user_text": text,
                "context": _redact_sensitive(context or {}),
            },
        )

    def chat_stream(self, text: str, context: dict[str, Any] | None = None) -> Iterable[str]:
        """Run a plain streaming chat turn and return an iterable of text chunks.

        The request is sent lazily, when the caller starts iterating.
        """
        return self._complete_text_stream(
            "chat",
            CHAT_PROMPT,
            {
                "user_text": text,
                "context": _redact_sensitive(context or {}),
            },
        )

    def analyze_log(self, log_text: str, question: str | None = None, source_path: str = "") -> str:
        """Ask the LLM to analyse a log excerpt.

        The excerpt is passed through ``redact_for_log`` with a 64000-char limit.
        """
        return self._complete_text(
            "analyze_log",
            LOG_ANALYSIS_PROMPT,
            {
                "source_path": source_path,
                "question": question or "请分析日志中的异常、可能原因和下一步建议。",
                "log_tail": redact_for_log(log_text, max_text_len=64000),
            },
        )

    def propose_action(
        self,
        text: str,
        allowed_actions: list[str],
        params: dict[str, Any],
        state_summary: dict[str, Any] | None = None,
    ) -> LlmSingleActionProposal:
        """Translate natural language into a single-action proposal.

        An action not in ``allowed_actions`` is replaced by ``""``. The
        proposal always has ``requires_confirmation=True``, regardless of
        what the LLM returns.
        """
        payload = self._complete_json(
            "propose_action",
            SINGLE_ACTION_PROMPT,
            {
                "user_text": text,
                "allowed_actions": allowed_actions,
                "params": _redact_sensitive(params),
                "state_summary": _redact_sensitive(state_summary or {}),
            },
        )
        action = _string(payload, "action", "")
        if action not in allowed_actions:
            action = ""
        return LlmSingleActionProposal(
            action=action,
            ip=_string(payload, "ip", ""),
            kwargs=_dict(payload.get("kwargs")),
            reason=_string(payload, "reason", ""),
            risk_level=_risk_level(payload.get("risk_level")),
            requires_confirmation=True,
        )

    def _request_headers(self) -> dict[str, str]:
        """Build JSON request headers, adding a Bearer token only when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _complete_json(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send a chat/completions request and parse the reply as a JSON object.

        Args:
            operation: Label used only in log messages.
            instruction: Task prompt, prepended to the serialized input.
            input_payload: Data serialized into the user message.

        Returns:
            The parsed JSON object.

        Raises:
            ValueError: If the response is malformed or is not a JSON object
                (``json.JSONDecodeError`` is a subclass). Transport errors
                propagate unchanged; every failure is logged first.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": instruction
                    + "\n\n输入 JSON:\n"
                    + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = self._request_headers()
        logger.info(
            "LLM 请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            response = self.transport(endpoint, headers, request_payload, self.timeout_sec)
            content = _message_content(response)
            logger.info(
                "LLM 原始响应 operation=%s duration_ms=%s content=%s",
                operation,
                _elapsed_ms(started_at),
                redact_for_log(content, max_text_len=1600),
            )
            parsed = _loads_json_object(content)
            if not isinstance(parsed, dict):
                raise ValueError("LLM 响应必须是 JSON object")
        except Exception:
            logger.exception(
                "LLM 请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                _elapsed_ms(started_at),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 请求完成 operation=%s duration_ms=%s response_keys=%s response=%s",
            operation,
            _elapsed_ms(started_at),
            sorted(parsed.keys()),
            json_for_log(parsed, max_text_len=1600),
        )
        return parsed

    def _complete_text(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> str:
        """Send a chat/completions request and return the reply as plain text.

        ``instruction`` becomes the system message. Thinking segments are
        removed with ``strip_thinking_text``. A ``null`` message content
        yields ``""``, not the literal text ``"None"``.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "messages": [
                {"role": "system", "content": instruction},
                {
                    "role": "user",
                    "content": "输入 JSON:\n" + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = self._request_headers()
        logger.info(
            "LLM 文本请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            response = self.transport(endpoint, headers, request_payload, self.timeout_sec)
            content = strip_thinking_text(_content_parts_to_text(_message_content(response)))
        except Exception:
            logger.exception(
                "LLM 文本请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                _elapsed_ms(started_at),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 文本请求完成 operation=%s duration_ms=%s content=%s",
            operation,
            _elapsed_ms(started_at),
            redact_for_log(content, max_text_len=1600),
        )
        return content

    def _complete_text_stream(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> Iterable[str]:
        """Send a streaming chat/completions request and yield filtered text chunks.

        This is a generator, so nothing (not even the start log) happens until
        the caller begins iterating. Chunks pass through
        ``filter_thinking_chunks`` and empty chunks are dropped.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "stream": True,
            "messages": [
                {"role": "system", "content": instruction},
                {
                    "role": "user",
                    "content": "输入 JSON:\n" + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = self._request_headers()
        logger.info(
            "LLM 流式文本请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            raw_chunks = self.stream_transport(endpoint, headers, request_payload, self.timeout_sec)
            for chunk in filter_thinking_chunks(raw_chunks):
                if chunk:
                    yield chunk
        except Exception:
            logger.exception(
                "LLM 流式文本请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                _elapsed_ms(started_at),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 流式文本请求完成 operation=%s duration_ms=%s",
            operation,
            _elapsed_ms(started_at),
        )


def _default_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON using ``urllib`` and return the decoded JSON object.

    Raises:
        ValueError: If the response body is valid JSON but not an object.
    """
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout_sec) as response:
        raw = response.read().decode("utf-8")
    decoded = json.loads(raw)
    if not isinstance(decoded, dict):
        raise ValueError("LLM HTTP 响应必须是 JSON object")
    return decoded


def _default_stream_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> Iterator[str]:
    """POST ``payload`` and yield text chunks from an OpenAI-compatible SSE stream.

    Processing rules:
    - Comment lines (``:``) are skipped.
    - ``event:``, ``id:`` and ``retry:`` fields are skipped.
    - ``data: [DONE]`` ends the stream.
    - Undecodable ``data:`` payloads are logged at debug level and skipped.

    Raises:
        ValueError: If a non-empty line is not a recognised SSE field, or if a
            chunk carries a non-empty ``error`` object. Without this check,
            mid-stream provider errors would silently end the stream.
    """
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout_sec) as response:
        for raw_line in response:
            line = raw_line.decode("utf-8", errors="replace").strip()
            if not line or line.startswith(":"):
                continue
            if line.startswith(("event:", "id:", "retry:")):
                continue
            if not line.startswith("data:"):
                raise ValueError("LLM 流式响应不是 SSE data 格式")
            data = line[len("data:") :].strip()
            if data == "[DONE]":
                break
            try:
                decoded = json.loads(data)
            except json.JSONDecodeError:
                logger.debug("忽略无法解析的 LLM stream data: %s", redact_for_log(data, max_text_len=300))
                continue
            if isinstance(decoded, dict) and decoded.get("error"):
                raise ValueError(f"LLM 流式响应返回错误: {json_for_log(decoded['error'], max_text_len=500)}")
            chunk = _stream_delta_content(decoded)
            if chunk:
                yield chunk


def load_prompt_text(path: str | None) -> str:
    """Read a custom prompt file (UTF-8, stripped).

    Falls back to ``ACTION_ANALYSIS_PROMPT`` when ``path`` is empty or the
    file's stripped content is empty.
    """
    if not path:
        return ACTION_ANALYSIS_PROMPT
    prompt_path = Path(path)
    return prompt_path.read_text(encoding="utf-8").strip() or ACTION_ANALYSIS_PROMPT


def _chat_completions_url(base_url: str) -> str:
    """Normalise ``base_url`` into the chat/completions endpoint URL (idempotent)."""
    clean = base_url.rstrip("/")
    if clean.endswith("/chat/completions"):
        return clean
    return f"{clean}/chat/completions"


def _message_content(response: dict[str, Any]) -> Any:
    """Return ``choices[0].message.content`` from an OpenAI-compatible response.

    A list of content parts is flattened to text. Any other value, including
    a dict or ``None``, is returned unchanged.

    Raises:
        ValueError: If the content path is missing. When the body carries an
            ``error`` object, that error is reported in the message instead of
            the generic "missing content" text.
    """
    try:
        content = response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        error = response.get("error") if isinstance(response, dict) else None
        if error:
            raise ValueError(f"LLM 响应返回错误: {json_for_log(error, max_text_len=500)}") from exc
        raise ValueError("LLM 响应缺少 choices[0].message.content") from exc
    if isinstance(content, list):
        return _content_parts_to_text(content)
    return content


def _stream_delta_content(response: dict[str, Any]) -> str:
    """Extract text from one OpenAI-compatible stream chunk.

    Sources are tried in this order: ``delta.content``, then
    ``message.content``, then legacy ``text``. Returns ``""`` when none is
    present.
    """
    try:
        choice = response["choices"][0]
    except (KeyError, IndexError, TypeError):
        return ""
    if not isinstance(choice, dict):
        return ""
    delta = choice.get("delta")
    if isinstance(delta, dict) and "content" in delta:
        return _content_parts_to_text(delta.get("content"))
    message = choice.get("message")
    if isinstance(message, dict) and "content" in message:
        return _content_parts_to_text(message.get("content"))
    text = choice.get("text")
    return str(text) if text is not None else ""


def _content_parts_to_text(content: Any) -> str:
    """Convert OpenAI content parts (or a plain value) to text; ``None`` becomes ``""``."""
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                parts.append(str(item.get("text", "")))
            elif isinstance(item, str):
                parts.append(item)
        return "".join(parts)
    return "" if content is None else str(content)


def _strip_json_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (e.g. a ```` ```json ```` block) if present."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    newline = stripped.find("\n")
    if newline == -1:
        return stripped
    # Drop the opening fence line, which may carry a language tag.
    body = stripped[newline + 1 :].rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def _loads_json_object(content: Any) -> Any:
    """Parse ``message.content`` as JSON; a dict is returned unchanged.

    If the raw text is not valid JSON, one retry is made after removing
    thinking segments (``strip_thinking_text``) and a surrounding Markdown
    code fence. Many OpenAI-compatible servers ignore
    ``response_format=json_object`` and wrap the JSON this way.

    Raises:
        ValueError: If ``content`` is neither a dict nor a string.
        json.JSONDecodeError: If the text is still not valid JSON after cleanup.
    """
    if isinstance(content, dict):
        return content
    if not isinstance(content, str):
        raise ValueError("LLM message content 必须是 JSON 文本")
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        cleaned = _strip_json_fence(strip_thinking_text(content))
        if cleaned == content:
            raise
        return json.loads(cleaned)


def _redact_sensitive(value: Any) -> Any:
    """Return a copy of ``value`` with ``SENSITIVE_KEYS`` entries replaced by ``"***"``.

    Recurses into dicts and lists; dict keys are converted to ``str``.
    """
    if isinstance(value, dict):
        redacted: dict[str, Any] = {}
        for key, item in value.items():
            if str(key) in SENSITIVE_KEYS:
                redacted[str(key)] = "***"
            else:
                redacted[str(key)] = _redact_sensitive(item)
        return redacted
    if isinstance(value, list):
        return [_redact_sensitive(item) for item in value]
    return value


def _action_review_result_payload(action: str, result: ActionResult) -> dict[str, Any]:
    """Build the review input for ``action``.

    A diagnostic log excerpt is attached only when ``_needs_diagnostic_log``
    says so, so that normal script output is not presented to the LLM as an
    error.
    """
    payload: dict[str, Any] = {
        "backend": result.backend,
        "ok": result.ok,
        "exit_code": result.exit_code,
        "tool_name": result.tool_name,
        "values": _redact_sensitive(result.values),
        "error_summary": result.error_summary,
    }
    if _needs_diagnostic_log(action, result):
        diagnostic = _diagnostic_log_text(result)
        if diagnostic:
            payload["diagnostic_log"] = diagnostic
    return payload


def _needs_diagnostic_log(action: str, result: ActionResult) -> bool:
    """Return True when the result looks like a failure or business-level anomaly.

    That is the case when the result is not ok, has an error summary, is
    pending agent confirmation, or is a ``verify-ip`` result whose ``SUCCESS``
    value is not truthy.
    """
    if not result.ok or result.error_summary or result.values.get("PENDING_AGENT_CONFIRMATION"):
        return True
    if action == "verify-ip":
        success = result.values.get("SUCCESS")
        if success is not None and str(success).lower() not in ("true", "1", "yes"):
            return True
    return False


def _diagnostic_log_text(result: ActionResult) -> str:
    """Pick diagnostic context for the LLM.

    Prefers the (truncated) error summary. Otherwise uses the tail of the
    first non-blank stream among stderr, stdout and raw_output.
    """
    if result.error_summary:
        return _truncate_text(result.error_summary)
    for text in (result.stderr, result.stdout, result.raw_output):
        stripped = text.strip()
        if stripped:
            return _tail_text(stripped)
    return ""


def _truncate_text(value: str, limit: int = 1000) -> str:
    """Keep the first ``limit`` chars of ``value`` and append a truncation marker."""
    if len(value) <= limit:
        return value
    return value[:limit] + "...[已截断]"


def _tail_text(value: str, limit: int = 1000) -> str:
    """Keep the tail of a long diagnostic log; error causes usually sit near the end.

    The result, marker included, is at most ``max(limit, len(marker))``
    characters long.
    """
    if len(value) <= limit:
        return value
    marker = "[已截断]..."
    keep = max(limit - len(marker), 0)
    # Guard keep == 0: value[-0:] would return the whole string.
    return marker + (value[-keep:] if keep else "")


def _elapsed_ms(started_at: float) -> int:
    """Milliseconds elapsed since a ``time.perf_counter()`` reading."""
    return int((time.perf_counter() - started_at) * 1000)


def _string(payload: dict[str, Any], key: str, default: str) -> str:
    """Read ``payload[key]`` as a string; missing or ``None`` yields ``default``."""
    value = payload.get(key, default)
    return str(value) if value is not None else default


def _float(payload: dict[str, Any], key: str, default: float) -> float:
    """Read ``payload[key]`` as a float; unconvertible or missing yields ``default``."""
    try:
        return float(payload.get(key, default))
    except (TypeError, ValueError):
        return default


def _bool(payload: dict[str, Any], key: str, default: bool) -> bool:
    """Read ``payload[key]`` as a boolean.

    A missing key yields ``default``. Strings such as ``"false"``, ``"no"``
    and ``"0"`` are parsed as False, whereas plain ``bool()`` would treat any
    non-empty string as True. Values ``_optional_bool`` maps to ``None``
    (``null``, ``""``, ``"null"``, ``"none"``) are False, matching the
    previous ``bool(value)`` result for ``null`` and ``""``.
    """
    if key not in payload:
        return default
    parsed = _optional_bool(payload[key])
    return False if parsed is None else parsed


def _optional_bool(value: Any) -> bool | None:
    """Parse an optional boolean, keeping ``None`` for absent or null-like values."""
    if value is None:
        return None
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("", "null", "none"):
            return None
        if lowered in ("true", "1", "yes", "y"):
            return True
        if lowered in ("false", "0", "no", "n"):
            return False
    return bool(value)


def _risk_level(value: Any) -> str:
    """Normalise a risk level to low/medium/high; unknown values become ``"medium"``."""
    text = str(value or "").strip().lower()
    if text in ("low", "medium", "high"):
        return text
    return "medium"


def _dict(value: Any) -> dict[str, Any]:
    """Return ``value`` if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _string_list(value: Any) -> list[str]:
    """Coerce ``value`` into a list of strings.

    A list has each item stringified. ``None`` and ``""`` give an empty list.
    Any other scalar is wrapped as a one-item list.
    """
    if isinstance(value, list):
        return [str(item) for item in value]
    if value in (None, ""):
        return []
    return [str(value)]