agent_deply/pam_deploy_graph/llm/openai_compatible.py
dark 87c48a74a5 新增 <think>...</think> 过滤器,支持完整标签、跨流式 chunk 标签、未闭合 <think>。
OpenAICompatibleLlmClient 新增 chat_stream(),使用 OpenAI-compatible /chat/completions 的 stream: true。
chat 普通对话现在优先流式分段输出;流式不可用或服务端不返回 SSE 时,会提示并自动 fallback 到非流式 chat()。
普通 chat 和 log analyze 都会过滤 think 内容,并且日志只记录过滤后的摘要。
更新了 chat/log 分析提示词,明确禁止输出 think/内部思考。
同步 README、打包 README、run.sh --help。
补充了过滤器、OpenAI 流式、CLI fallback、日志分析过滤等测试。
2026-06-05 12:32:58 +08:00

661 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OpenAI-compatible HTTP LLM client。
该 client 面向暴露 `/chat/completions` 的模型服务,并使用 OpenAI 风格的
请求/响应结构。实现只依赖 Python 标准库,便于控制运行时依赖体积。
"""
from __future__ import annotations
import json
import logging
import time
from pathlib import Path
import urllib.request
from collections.abc import Callable, Iterable, Iterator
from typing import Any
from pam_deploy_graph.constants import (
ALLOWED_ACTIONS,
DEFAULT_PARAMS,
GLOBAL_ACTION_SEQUENCE,
IP_ACTION_SEQUENCE,
REQUIRED_PARAMS,
SENSITIVE_KEYS,
)
from pam_deploy_graph.logging_utils import json_for_log, redact_for_log
from pam_deploy_graph.models import ExecutionStrategy, LlmDeployPlan, LlmIntentResult, LlmParamResult, LlmSingleActionProposal
from pam_deploy_graph.models import ActionResult, LlmActionAnalysis
from .prompts import (
ACTION_ANALYSIS_PROMPT,
CHAT_PROMPT,
INTENT_PROMPT,
LOG_ANALYSIS_PROMPT,
PARAM_PROMPT,
PLAN_PROMPT,
SINGLE_ACTION_PROMPT,
SYSTEM_PROMPT,
)
from .text_filter import filter_thinking_chunks, strip_thinking_text
# Non-streaming transport signature: (url, headers, json_payload, timeout_sec) -> decoded JSON response.
JsonTransport = Callable[
    [str, dict[str, str], dict[str, Any], float],
    dict[str, Any],
]
# Streaming transport signature: same arguments, returns an iterable of text chunks.
StreamTransport = Callable[
    [str, dict[str, str], dict[str, Any], float],
    Iterable[str],
]
logger = logging.getLogger(__name__)
class OpenAICompatibleLlmClient:
    """LLM client for services exposing an OpenAI-compatible ``/chat/completions`` API.

    Structured operations (intent, params, plan, action analysis, single-action proposal)
    use JSON mode and are parsed into model objects. ``chat``, ``chat_stream`` and
    ``analyze_log`` return free text with ``<think>...</think>`` reasoning removed.
    Sensitive parameter values are masked with ``_redact_sensitive`` before they are
    placed into prompts.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        model: str,
        action_analysis_prompt: str | None = None,
        timeout_sec: float = 30,
        temperature: float = 0,
        transport: JsonTransport | None = None,
        stream_transport: StreamTransport | None = None,
    ) -> None:
        """Store connection settings, model settings and the pluggable HTTP transports.

        Args:
            base_url: Service base URL, or the full ``/chat/completions`` endpoint.
                Trailing slashes are removed.
            api_key: Bearer token. When empty, no ``Authorization`` header is sent.
            model: Model name sent with every request.
            action_analysis_prompt: Optional override for the action-analysis
                instruction. Defaults to ``ACTION_ANALYSIS_PROMPT``.
            timeout_sec: Timeout passed to the transports for every request.
            temperature: Sampling temperature sent with every request.
            transport: Non-streaming transport. Defaults to ``_default_transport``.
            stream_transport: Streaming transport. Defaults to
                ``_default_stream_transport``.

        Raises:
            ValueError: If ``base_url`` or ``model`` is empty.
        """
        if not base_url:
            raise ValueError("必须配置 LLM base_url")
        if not model:
            raise ValueError("必须配置 LLM model")
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.action_analysis_prompt = action_analysis_prompt or ACTION_ANALYSIS_PROMPT
        self.timeout_sec = timeout_sec
        self.temperature = temperature
        self.transport = transport or _default_transport
        self.stream_transport = stream_transport or _default_stream_transport
        logger.info(
            "OpenAI-compatible LLM client 初始化 base_url=%s endpoint=%s model=%s has_api_key=%s timeout=%s temperature=%s custom_transport=%s custom_stream_transport=%s",
            self.base_url,
            _chat_completions_url(self.base_url),
            self.model,
            bool(self.api_key),
            self.timeout_sec,
            self.temperature,
            transport is not None,
            stream_transport is not None,
        )

    def understand_request(self, text: str) -> LlmIntentResult:
        """Ask the LLM to classify the intent of ``text``.

        Missing fields fall back to intent ``"deploy"``, preference ``"未指定"`` and
        confidence ``0.0``.
        """
        payload = self._complete_json("understand_request", INTENT_PROMPT, {"user_text": text})
        return LlmIntentResult(
            intent=_string(payload, "intent", "deploy"),  # type: ignore[arg-type]
            mode_preference=_string(payload, "mode_preference", "未指定"),  # type: ignore[arg-type]
            strategy_preference=_string(payload, "strategy_preference", "未指定"),  # type: ignore[arg-type]
            confidence=_float(payload, "confidence", 0.0),
            reasons=_string_list(payload.get("reasons")),
            needs_clarification=bool(payload.get("needs_clarification", False)),
            clarification_questions=_string_list(payload.get("clarification_questions")),
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Ask the LLM to extract deployment parameters from ``text``.

        Sensitive values in ``base_params`` are sent to the LLM as ``"***"``. If the
        LLM echoes that mask back for a sensitive key, the echo is ignored, so the
        real value from ``base_params`` is kept in the merged result.
        """
        original_base = dict(base_params or {})
        safe_base = _redact_sensitive(original_base)
        payload = self._complete_json(
            "extract_params",
            PARAM_PROMPT,
            {
                "user_text": text,
                "base_params": safe_base,
                "required_params": list(REQUIRED_PARAMS),
                "default_params": DEFAULT_PARAMS,
            },
        )
        extracted = _dict(payload.get("extracted_params"))
        merged = original_base.copy()
        for key, value in extracted.items():
            if key in SENSITIVE_KEYS and value == "***":
                continue
            merged[key] = value
        control = _dict(payload.get("extracted_control"))
        missing = [key for key in REQUIRED_PARAMS if not merged.get(key)]
        sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if merged.get(key)]
        llm_ambiguous = _string_list(payload.get("ambiguous_fields"))
        return LlmParamResult(
            extracted_params=merged,
            extracted_control=control,
            missing_required_params=missing,
            ambiguous_fields=llm_ambiguous,
            sensitive_fields_present=sensitive,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
    ) -> LlmDeployPlan:
        """Ask the LLM for a deployment plan.

        If no planned actions are returned, the plan falls back to
        ``GLOBAL_ACTION_SEQUENCE``.
        """
        payload = self._complete_json(
            "generate_plan",
            PLAN_PROMPT,
            {
                "params": _redact_sensitive(params),
                "intent": intent,
                "execution_strategy": strategy,
                "allowed_actions": list(ALLOWED_ACTIONS),
                "global_action_sequence": list(GLOBAL_ACTION_SEQUENCE),
                "ip_action_sequence": list(IP_ACTION_SEQUENCE),
            },
        )
        planned_actions = _string_list(payload.get("planned_actions")) or list(GLOBAL_ACTION_SEQUENCE)
        return LlmDeployPlan(
            summary=_string(payload, "summary", "PAM 部署计划"),
            risk_notes=_string_list(payload.get("risk_notes")),
            planned_actions=planned_actions,
            requires_confirmation=bool(payload.get("requires_confirmation", True)),
            execution_strategy=_string(payload, "execution_strategy", strategy),  # type: ignore[arg-type]
        )

    def analyze_action_result(
        self,
        *,
        action: str,
        result: ActionResult,
    ) -> LlmActionAnalysis:
        """Ask the LLM to review an action result and return a structured diagnosis."""
        payload = self._complete_json(
            "analyze_action_result",
            self.action_analysis_prompt,
            {
                "action": action,
                "result": _action_review_result_payload(action, result),
            },
        )
        return LlmActionAnalysis(
            action=_string(payload, "action", action),
            has_anomaly=bool(payload.get("has_anomaly", False)),
            severity=_string(payload, "severity", "info"),  # type: ignore[arg-type]
            possible_reason=_string(payload, "possible_reason", ""),
            suggested_action=_string(payload, "suggested_action", ""),
            requires_confirmation=bool(payload.get("requires_confirmation", False)),
            should_continue=bool(payload.get("should_continue", True)),
            progress_complete=_optional_bool(payload.get("progress_complete")),
            notes=_string_list(payload.get("notes")),
        )

    def chat(self, text: str, context: dict[str, Any] | None = None) -> str:
        """Run a free-form (non-JSON) chat turn and return the filtered reply."""
        return self._complete_text(
            "chat",
            CHAT_PROMPT,
            {
                "user_text": text,
                "context": _redact_sensitive(context or {}),
            },
        )

    def chat_stream(self, text: str, context: dict[str, Any] | None = None) -> Iterable[str]:
        """Run a free-form chat turn and stream the filtered reply as text chunks."""
        return self._complete_text_stream(
            "chat",
            CHAT_PROMPT,
            {
                "user_text": text,
                "context": _redact_sensitive(context or {}),
            },
        )

    def analyze_log(self, log_text: str, question: str | None = None, source_path: str = "") -> str:
        """Ask the LLM to analyze a log excerpt.

        The log text is passed through ``redact_for_log`` with a 64000-character cap
        before it is sent.
        """
        return self._complete_text(
            "analyze_log",
            LOG_ANALYSIS_PROMPT,
            {
                "source_path": source_path,
                "question": question or "请分析日志中的异常、可能原因和下一步建议。",
                "log_tail": redact_for_log(log_text, max_text_len=64000),
            },
        )

    def propose_action(
        self,
        text: str,
        allowed_actions: list[str],
        params: dict[str, Any],
        state_summary: dict[str, Any] | None = None,
    ) -> LlmSingleActionProposal:
        """Ask the LLM to turn natural language into a single-action call proposal.

        An action outside ``allowed_actions`` is replaced by ``""``. The proposal
        always requires confirmation.
        """
        payload = self._complete_json(
            "propose_action",
            SINGLE_ACTION_PROMPT,
            {
                "user_text": text,
                "allowed_actions": allowed_actions,
                "params": _redact_sensitive(params),
                "state_summary": _redact_sensitive(state_summary or {}),
            },
        )
        action = _string(payload, "action", "")
        if action not in allowed_actions:
            action = ""
        return LlmSingleActionProposal(
            action=action,
            ip=_string(payload, "ip", ""),
            kwargs=_dict(payload.get("kwargs")),
            reason=_string(payload, "reason", ""),
            risk_level=_risk_level(payload.get("risk_level")),
            requires_confirmation=True,
        )

    def _headers(self) -> dict[str, str]:
        """Build request headers. ``Authorization`` is only added when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _text_request_payload(
        self,
        instruction: str,
        input_payload: dict[str, Any],
        *,
        stream: bool = False,
    ) -> dict[str, Any]:
        """Build a free-text chat/completions body.

        The instruction goes in the system message. The input goes in the user
        message as sorted JSON.
        """
        request_payload: dict[str, Any] = {
            "model": self.model,
            "temperature": self.temperature,
        }
        if stream:
            request_payload["stream"] = True
        request_payload["messages"] = [
            {"role": "system", "content": instruction},
            {
                "role": "user",
                "content": "输入 JSON:\n" + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
            },
        ]
        return request_payload

    def _complete_json(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> dict[str, Any]:
        """Send a JSON-mode chat/completions request and parse the reply as a JSON object.

        Raises:
            ValueError: If the response lacks ``choices[0].message.content`` or the
                content is not a JSON object.
            Exception: Transport errors are logged and re-raised unchanged.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = {
            "model": self.model,
            "temperature": self.temperature,
            "response_format": {"type": "json_object"},
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": instruction
                    + "\n\n输入 JSON:\n"
                    + json.dumps(input_payload, ensure_ascii=False, sort_keys=True),
                },
            ],
        }
        headers = self._headers()
        logger.info(
            "LLM 请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            response = self.transport(
                endpoint,
                headers,
                request_payload,
                self.timeout_sec,
            )
            content = _message_content(response)
            if isinstance(content, str):
                # Reasoning models may prepend <think>...</think> even in JSON mode. Drop it
                # so it is neither logged nor fed to the JSON parser.
                content = strip_thinking_text(content)
            logger.info(
                "LLM 原始响应 operation=%s duration_ms=%s content=%s",
                operation,
                int((time.perf_counter() - started_at) * 1000),
                redact_for_log(content, max_text_len=1600),
            )
            parsed = _loads_json_object(content)
            if not isinstance(parsed, dict):
                raise ValueError("LLM 响应必须是 JSON object")
        except Exception:
            logger.exception(
                "LLM 请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                int((time.perf_counter() - started_at) * 1000),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 请求完成 operation=%s duration_ms=%s response_keys=%s response=%s",
            operation,
            int((time.perf_counter() - started_at) * 1000),
            sorted(parsed.keys()),
            json_for_log(parsed, max_text_len=1600),
        )
        return parsed

    def _complete_text(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> str:
        """Send a chat/completions request and return the reply as filtered plain text.

        Raises:
            ValueError: If the response lacks ``choices[0].message.content``.
            Exception: Transport errors are logged and re-raised unchanged.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = self._text_request_payload(instruction, input_payload)
        headers = self._headers()
        logger.info(
            "LLM 文本请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        try:
            response = self.transport(endpoint, headers, request_payload, self.timeout_sec)
            # _content_parts_to_text maps a null content to "" instead of str(None) == "None",
            # which would otherwise be shown to the user as the reply.
            content = strip_thinking_text(_content_parts_to_text(_message_content(response)))
        except Exception:
            logger.exception(
                "LLM 文本请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                int((time.perf_counter() - started_at) * 1000),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        logger.info(
            "LLM 文本请求完成 operation=%s duration_ms=%s content=%s",
            operation,
            int((time.perf_counter() - started_at) * 1000),
            redact_for_log(content, max_text_len=1600),
        )
        return content

    def _complete_text_stream(self, operation: str, instruction: str, input_payload: dict[str, Any]) -> Iterable[str]:
        """Send a streaming chat/completions request and yield filtered, non-empty text chunks.

        This is a generator. Nothing is sent, and the start log is not written, until
        the caller begins iterating.

        Raises:
            Exception: Transport or filtering errors are logged and re-raised.
        """
        started_at = time.perf_counter()
        endpoint = _chat_completions_url(self.base_url)
        request_payload = self._text_request_payload(instruction, input_payload, stream=True)
        headers = self._headers()
        logger.info(
            "LLM 流式文本请求开始 operation=%s endpoint=%s model=%s timeout=%s has_api_key=%s input=%s",
            operation,
            endpoint,
            self.model,
            self.timeout_sec,
            bool(self.api_key),
            json_for_log(input_payload, max_text_len=1600),
        )
        raw_chunks: Iterable[str] | None = None
        emitted: list[str] = []
        try:
            raw_chunks = self.stream_transport(endpoint, headers, request_payload, self.timeout_sec)
            for chunk in filter_thinking_chunks(raw_chunks):
                if chunk:
                    emitted.append(chunk)
                    yield chunk
        except Exception:
            logger.exception(
                "LLM 流式文本请求失败 operation=%s endpoint=%s duration_ms=%s input=%s",
                operation,
                endpoint,
                int((time.perf_counter() - started_at) * 1000),
                json_for_log(input_payload, max_text_len=1600),
            )
            raise
        finally:
            # Release the underlying HTTP stream promptly, including when the caller stops
            # iterating early, instead of relying on garbage collection.
            closer = getattr(raw_chunks, "close", None)
            if callable(closer):
                closer()
        logger.info(
            "LLM 流式文本请求完成 operation=%s duration_ms=%s chunks=%s content=%s",
            operation,
            int((time.perf_counter() - started_at) * 1000),
            len(emitted),
            redact_for_log("".join(emitted), max_text_len=1600),
        )
def _default_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON with urllib and return the decoded JSON-object response.

    Errors raised by ``urlopen`` or ``json.loads`` propagate unchanged.

    Raises:
        ValueError: If the body is valid JSON but not a JSON object.
    """
    body = json.dumps(payload).encode("utf-8")
    http_request = urllib.request.Request(url, data=body, headers=headers, method="POST")
    with urllib.request.urlopen(http_request, timeout=timeout_sec) as http_response:
        text = http_response.read().decode("utf-8")
    decoded = json.loads(text)
    if isinstance(decoded, dict):
        return decoded
    raise ValueError("LLM HTTP 响应必须是 JSON object")
def _default_stream_transport(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout_sec: float,
) -> Iterator[str]:
    """Send an OpenAI-compatible SSE streaming request with urllib and yield text deltas.

    Blank lines, comment lines (starting with ``:``) and the non-data SSE fields
    ``event``, ``id`` and ``retry`` are skipped. Reading stops at ``data: [DONE]``.
    ``data`` payloads that are not valid JSON are logged at debug level and skipped.
    Only non-empty text extracted by ``_stream_delta_content`` is yielded.

    Raises:
        ValueError: If a line is not an SSE field (e.g. the server answered with a plain
            JSON body instead of a stream), or if a ``data`` payload carries an
            ``error`` object.
    """
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout_sec) as response:
        for raw_line in response:
            line = raw_line.decode("utf-8", errors="replace").strip()
            if not line or line.startswith(":"):
                continue
            # "retry:" is a standard SSE field as well; treating it as a non-SSE line
            # would abort an otherwise valid stream.
            if line.startswith(("event:", "id:", "retry:")):
                continue
            if not line.startswith("data:"):
                raise ValueError("LLM 流式响应不是 SSE data 格式")
            data = line[len("data:") :].strip()
            if data == "[DONE]":
                break
            try:
                decoded = json.loads(data)
            except json.JSONDecodeError:
                # Guarded so the redaction work is skipped when debug logging is off.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug("忽略无法解析的 LLM stream data: %s", redact_for_log(data, max_text_len=300))
                continue
            if isinstance(decoded, dict) and decoded.get("error"):
                # Mid-stream failures arrive as a data event with an "error" object and no
                # choices. Raise so the caller can react instead of ending with empty output.
                raise ValueError(
                    "LLM 流式响应返回错误: "
                    + redact_for_log(json.dumps(decoded["error"], ensure_ascii=False), max_text_len=300)
                )
            chunk = _stream_delta_content(decoded)
            if chunk:
                yield chunk
def load_prompt_text(path: str | None) -> str:
    """Read a custom action-analysis prompt file.

    Returns the file's UTF-8 text with surrounding whitespace stripped. Falls back to
    ``ACTION_ANALYSIS_PROMPT`` when ``path`` is empty/None or the file contains only
    whitespace. Errors from reading the file (e.g. a missing path) propagate.
    """
    if path:
        custom_prompt = Path(path).read_text(encoding="utf-8").strip()
        if custom_prompt:
            return custom_prompt
    return ACTION_ANALYSIS_PROMPT
def _chat_completions_url(base_url: str) -> str:
    """Normalize ``base_url`` into the full chat/completions endpoint URL.

    Trailing slashes are removed. A URL that already ends with ``/chat/completions``
    is returned as-is; otherwise that suffix is appended.
    """
    suffix = "/chat/completions"
    trimmed = base_url.rstrip("/")
    return trimmed if trimmed.endswith(suffix) else trimmed + suffix
def _message_content(response: dict[str, Any]) -> Any:
    """Extract ``choices[0].message.content`` from an OpenAI-compatible response.

    A list of content parts is flattened to plain text with ``_content_parts_to_text``,
    the same rule used for streaming deltas, so the two paths cannot drift apart. Any
    other value (str, dict, None, ...) is returned unchanged.

    Raises:
        ValueError: If the response has no ``choices[0].message.content``.
    """
    try:
        content = response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError("LLM 响应缺少 choices[0].message.content") from exc
    if isinstance(content, list):
        return _content_parts_to_text(content)
    return content
def _stream_delta_content(response: dict[str, Any]) -> str:
    """Extract the text carried by one OpenAI-compatible stream chunk.

    Inspects ``choices[0]`` and returns, in order of preference, the text of
    ``delta.content``, of ``message.content``, or of the ``text`` field. Returns ``""``
    when the chunk has no usable first choice or none of these fields is present.
    """
    try:
        first_choice = response["choices"][0]
    except (KeyError, IndexError, TypeError):
        return ""
    if not isinstance(first_choice, dict):
        return ""
    for field in ("delta", "message"):
        container = first_choice.get(field)
        if isinstance(container, dict) and "content" in container:
            return _content_parts_to_text(container.get("content"))
    legacy_text = first_choice.get("text")
    return "" if legacy_text is None else str(legacy_text)
def _content_parts_to_text(content: Any) -> str:
    """Flatten OpenAI message content into plain text.

    A list is treated as content parts. ``{"type": "text", "text": ...}`` dicts and bare
    strings are concatenated in order, and every other part is ignored. ``None`` becomes
    ``""``, and any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if not isinstance(content, list):
        return str(content)
    pieces = (
        str(part.get("text", "")) if isinstance(part, dict) else part
        for part in content
        if (isinstance(part, dict) and part.get("type") == "text") or isinstance(part, str)
    )
    return "".join(pieces)
def _loads_json_object(content: Any) -> Any:
    """Parse ``message.content`` as JSON.

    A dict is returned as-is. A string is parsed with ``json.loads``. If the whole
    string is wrapped in a markdown code fence (optionally tagged, e.g. ```` ```json ````),
    the fence is removed first: servers that ignore ``response_format`` often reply
    this way.

    Raises:
        ValueError: If ``content`` is neither a dict nor a string, or the text is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    if isinstance(content, dict):
        return content
    if not isinstance(content, str):
        raise ValueError("LLM message content 必须是 JSON 文本")
    fence = "```"
    text = content.strip()
    if len(text) >= 2 * len(fence) and text.startswith(fence) and text.endswith(fence):
        inner = text[len(fence) : -len(fence)]
        head, newline, rest = inner.partition("\n")
        # The first line of a multi-line fence is a language tag (or empty) unless it
        # already starts the JSON document itself.
        if newline and not head.strip().startswith(("{", "[")):
            inner = rest
        text = inner.strip()
    return json.loads(text)
def _redact_sensitive(value: Any) -> Any:
    """Recursively mask sensitive fields in data that will be embedded in a prompt.

    Dict keys are converted to ``str``, and a key listed in ``SENSITIVE_KEYS`` has its
    value replaced by ``"***"``. Lists and tuples are walked element by element, and a
    tuple stays a tuple. Every other value is returned unchanged. New containers are
    built, so the input is not mutated.
    """
    if isinstance(value, dict):
        return {
            str(key): "***" if str(key) in SENSITIVE_KEYS else _redact_sensitive(item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_redact_sensitive(item) for item in value]
    if isinstance(value, tuple):
        # json.dumps serializes tuples as arrays, so secrets nested inside a tuple would
        # otherwise reach the prompt verbatim.
        return tuple(_redact_sensitive(item) for item in value)
    return value
def _action_review_result_payload(action: str, result: ActionResult) -> dict[str, Any]:
    """Build the action-review input sent to the LLM.

    By default only the structured result fields are included, so normal script output
    is not presented to the model as an error. A short ``diagnostic_log`` excerpt is
    attached only when ``_needs_diagnostic_log`` flags the result and the excerpt is
    non-empty.
    """
    review: dict[str, Any] = dict(
        backend=result.backend,
        ok=result.ok,
        exit_code=result.exit_code,
        tool_name=result.tool_name,
        values=_redact_sensitive(result.values),
        error_summary=result.error_summary,
    )
    excerpt = _diagnostic_log_text(result) if _needs_diagnostic_log(action, result) else ""
    if excerpt:
        review["diagnostic_log"] = excerpt
    return review
def _needs_diagnostic_log(action: str, result: ActionResult) -> bool:
    """Decide whether a diagnostic log excerpt should accompany the result sent to the LLM.

    True when the action failed, has an error summary, or is pending agent
    confirmation. For ``verify-ip`` it is also true when ``SUCCESS`` is present but is
    not one of "true"/"1"/"yes" (case-insensitive).
    """
    if (not result.ok) or result.error_summary or result.values.get("PENDING_AGENT_CONFIRMATION"):
        return True
    if action != "verify-ip":
        return False
    success_flag = result.values.get("SUCCESS")
    if success_flag is None:
        return False
    return str(success_flag).lower() not in {"true", "1", "yes"}
def _diagnostic_log_text(result: ActionResult) -> str:
    """Pick a short diagnostic excerpt for the LLM.

    Prefers the error summary, truncated from the head by ``_truncate_text``. Otherwise
    uses the tail (``_tail_text``) of the first non-blank stream among stderr, stdout and
    raw_output, after stripping whitespace. Returns ``""`` if all are blank.
    """
    if result.error_summary:
        return _truncate_text(result.error_summary)
    streams = (result.stderr, result.stdout, result.raw_output)
    first_nonblank = next((cleaned for text in streams if (cleaned := text.strip())), "")
    return _tail_text(first_nonblank) if first_nonblank else ""
def _truncate_text(value: str, limit: int = 1000) -> str:
    """Keep at most the first ``limit`` characters of ``value`` for the LLM.

    When truncation happens, a ``...[已截断]`` marker is appended, so the result may be
    slightly longer than ``limit``.
    """
    return value if len(value) <= limit else f"{value[:limit]}...[已截断]"
def _tail_text(value: str, limit: int = 1000) -> str:
    """Keep the tail of a long diagnostic log; error causes usually sit near the end.

    The result never exceeds ``limit`` characters. When ``value`` is cut, a
    ``[已截断]...`` marker is prepended, as long as there is room for it plus at least
    one character of content. With a smaller ``limit``, a plain tail of ``limit``
    characters is returned (``""`` when ``limit <= 0``).
    """
    if len(value) <= limit:
        return value
    marker = "[已截断]..."
    keep = limit - len(marker)
    if keep <= 0:
        # Without this guard, value[-0:] / value[-(negative):] would return almost the
        # whole text, producing output longer than the input.
        return value[-limit:] if limit > 0 else ""
    return marker + value[-keep:]
def _string(payload: dict[str, Any], key: str, default: str) -> str:
    """Read ``payload[key]`` as a string.

    Returns ``default`` when the key is missing or its value is ``None``. Any other
    value is converted with ``str()``.
    """
    value = payload.get(key, default)
    if value is None:
        return default
    return str(value)
def _float(payload: dict[str, Any], key: str, default: float) -> float:
    """Read ``payload[key]`` as a finite float, falling back to ``default``.

    ``default`` is returned when:
    - the key is missing or the value is ``None``;
    - the value is not numeric;
    - the value is too large to convert (``OverflowError``, e.g. a huge JSON integer);
    - the value converts to NaN or ±inf, which are not usable scores such as
      ``confidence``.
    """
    import math

    try:
        number = float(payload.get(key, default))
    except (TypeError, ValueError, OverflowError):
        return default
    return number if math.isfinite(number) else default
def _optional_bool(value: Any) -> bool | None:
    """Parse an optional boolean, keeping ``None`` for a missing/unknown-null value.

    Booleans pass through. For strings (trimmed, case-insensitive):
    - "", "null" and "none" map to ``None``;
    - "true"/"1"/"yes"/"y" map to ``True``;
    - "false"/"0"/"no"/"n" map to ``False``.
    Any other value, including unrecognised non-empty strings, falls back to
    ``bool(value)``.
    """
    if value is None or isinstance(value, bool):
        return value
    if not isinstance(value, str):
        return bool(value)
    token = value.strip().lower()
    if token in ("", "null", "none"):
        return None
    if token in ("true", "1", "yes", "y"):
        return True
    if token in ("false", "0", "no", "n"):
        return False
    return bool(value)
def _risk_level(value: Any) -> str:
    """Normalise a single-action risk level to "low"/"medium"/"high".

    Matching is trimmed and case-insensitive. Falsy or unrecognised values become
    "medium".
    """
    normalized = str(value or "").strip().lower()
    return normalized if normalized in {"low", "medium", "high"} else "medium"
def _dict(value: Any) -> dict[str, Any]:
    """Return ``value`` itself if it is a dict; otherwise return a new empty dict."""
    if isinstance(value, dict):
        return value
    return {}
def _string_list(value: Any) -> list[str]:
    """Coerce ``value`` into a list of strings.

    Each element of a list is converted with ``str()``. ``None`` and ``""`` give an
    empty list. Any other single value becomes a one-element list.
    """
    if isinstance(value, list):
        return list(map(str, value))
    if value is None or value == "":
        return []
    return [str(value)]