"""OpenAI-compatible HTTP LLM client. The client targets providers exposing a `/chat/completions` endpoint with OpenAI-style request and response shapes. It intentionally uses only the Python standard library so the runtime can stay dependency-light. """ from __future__ import annotations import json import urllib.request from collections.abc import Callable from typing import Any from pam_deploy_graph.constants import ( ALLOWED_ACTIONS, DEFAULT_PARAMS, GLOBAL_ACTION_SEQUENCE, IP_ACTION_SEQUENCE, REQUIRED_PARAMS, SENSITIVE_KEYS, ) from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult from .prompts import INTENT_PROMPT, PARAM_PROMPT, PLAN_PROMPT, SYSTEM_PROMPT JsonTransport = Callable[[str, dict[str, str], dict[str, Any], float], dict[str, Any]] class OpenAICompatibleLlmClient: def __init__( self, *, base_url: str, api_key: str, model: str, timeout_sec: float = 30, temperature: float = 0, transport: JsonTransport | None = None, ) -> None: if not base_url: raise ValueError("LLM base_url is required") if not api_key: raise ValueError("LLM api_key is required") if not model: raise ValueError("LLM model is required") self.base_url = base_url.rstrip("/") self.api_key = api_key self.model = model self.timeout_sec = timeout_sec self.temperature = temperature self.transport = transport or _default_transport def understand_request(self, text: str) -> LlmIntentResult: payload = self._complete_json(INTENT_PROMPT, {"user_text": text}) return LlmIntentResult( intent=_string(payload, "intent", "deploy"), # type: ignore[arg-type] mode_preference=_string(payload, "mode_preference", "未指定"), # type: ignore[arg-type] strategy_preference=_string(payload, "strategy_preference", "未指定"), # type: ignore[arg-type] confidence=_float(payload, "confidence", 0.0), reasons=_string_list(payload.get("reasons")), needs_clarification=bool(payload.get("needs_clarification", False)), clarification_questions=_string_list(payload.get("clarification_questions")), ) def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult: original_base = dict(base_params or {}) safe_base = _redact_sensitive(original_base) payload = self._complete_json( PARAM_PROMPT, { "user_text": text, "base_params": safe_base, "required_params": list(REQUIRED_PARAMS), "default_params": DEFAULT_PARAMS, }, ) extracted = _dict(payload.get("extracted_params")) merged = original_base.copy() for key, value in extracted.items(): if key in SENSITIVE_KEYS and value == "***": continue merged[key] = value control = _dict(payload.get("extracted_control")) missing = [key for key in REQUIRED_PARAMS if not merged.get(key)] sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)] llm_ambiguous = _string_list(payload.get("ambiguous_fields")) return LlmParamResult( extracted_params=merged, extracted_control=control, missing_required_params=missing, ambiguous_fields=llm_ambiguous, sensitive_fields_present=sensitive, ) def generate_plan( self, *, params: dict[str, Any], intent: str, strategy: ExecutionStrategy, ) -> LlmDeployPlan: payload = self._complete_json( PLAN_PROMPT, { "params": _redact_sensitive(params), "intent": intent, "execution_strategy": strategy, "allowed_actions": list(ALLOWED_ACTIONS), "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE), "ip_action_sequence": list(IP_ACTION_SEQUENCE), }, ) planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE) return LlmDeployPlan( summary=_string(payload, "summary", "PAM deployment plan"), risk_notes=_string_list(payload.get("risk_notes")), planned_actions=planned_actions, requires_confirmation=bool(payload.get("requires_confirmation", True)), execution_strategy=_string(payload, "execution_strategy", strategy), # type: ignore[arg-type] ) def _complete_json(self, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]: request_payload = { "model": self.model, "temperature": self.temperature, "response_format": {"type": "json_object"}, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, { "role": "user", "content": instruction + "\n\n输入 JSON:\n" + json.dumps(input_payload, ensure_ascii=False, sort_keys=True), }, ], } response = self.transport( _chat_completions_url(self.base_url), { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", }, request_payload, self.timeout_sec, ) content = _message_content(response) parsed = _loads_json_object(content) if not isinstance(parsed, dict): raise ValueError("LLM response must be a JSON object") return parsed def _default_transport( url: str, headers: dict[str, str], payload: dict[str, Any], timeout_sec: float, ) -> dict[str, Any]: request = urllib.request.Request( url, data=json.dumps(payload).encode("utf-8"), headers=headers, method="POST", ) with urllib.request.urlopen(request, timeout=timeout_sec) as response: raw = response.read().decode("utf-8") decoded = json.loads(raw) if not isinstance(decoded, dict): raise ValueError("LLM HTTP response must be a JSON object") return decoded def _chat_completions_url(base_url: str) -> str: clean = base_url.rstrip("/") if clean.endswith("/chat/completions"): return clean return f"{clean}/chat/completions" def _message_content(response: dict[str, Any]) -> Any: try: content = response["choices"][0]["message"]["content"] except (KeyError, IndexError, TypeError) as exc: raise ValueError("LLM response does not contain choices[0].message.content") from exc if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": parts.append(str(item.get("text", ""))) elif isinstance(item, str): parts.append(item) return "".join(parts) return content def _loads_json_object(content: Any) -> Any: if isinstance(content, dict): return content if not isinstance(content, str): raise ValueError("LLM message content must be JSON text") return json.loads(content) def _redact_sensitive(value: Any) -> Any: if isinstance(value, dict): redacted: dict[str, Any] = {} for key, item in value.items(): if str(key) in SENSITIVE_KEYS: redacted[str(key)] = "***" else: redacted[str(key)] = _redact_sensitive(item) return redacted if isinstance(value, list): return [_redact_sensitive(item) for item in value] return value def _string(payload: dict[str, Any], key: str, default: str) -> str: value = payload.get(key, default) return str(value) if value is not None else default def _float(payload: dict[str, Any], key: str, default: float) -> float: try: return float(payload.get(key, default)) except (TypeError, ValueError): return default def _dict(value: Any) -> dict[str, Any]: return value if isinstance(value, dict) else {} def _string_list(value: Any) -> list[str]: if isinstance(value, list): return [str(item) for item in value] if value in (None, ""): return [] return [str(value)]