287 lines
12 KiB
Python
287 lines
12 KiB
Python
"""LLM 结构化输出的确定性规则 fallback。
|
||
|
||
该类不是对真实模型的替代,只用于本地开发和测试时提供稳定输出。
|
||
真实 LLM client 需要实现相同方法。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
from typing import Any
|
||
|
||
from pam_deploy_graph.constants import GLOBAL_ACTION_SEQUENCE, IP_ACTION_SEQUENCE, REQUIRED_PARAMS
|
||
from pam_deploy_graph.models import (
|
||
ActionResult,
|
||
ExecutionStrategy,
|
||
LlmActionAnalysis,
|
||
LlmDeployPlan,
|
||
LlmIntentResult,
|
||
LlmModeDecision,
|
||
LlmParamResult,
|
||
)
|
||
|
||
# Maps every accepted spelling of a parameter key to its canonical upper-case
# name. Each canonical name is reachable both through its alias and through
# itself; the alias entry is inserted first for each pair.
KEY_ALIASES = {
    spelling: canonical
    for canonical, alias in (
        ("HOME_BASE_URL", "home_base_url"),
        ("CLIENT_ID", "client_id"),
        ("CLIENT_SECRET", "client_secret"),
        ("AIRPORT_CODE", "airportCode"),
        ("APP_NAME", "applicationName"),
        ("MODULE_NAME", "moduleName"),
        ("VERSION_NUMBER", "versionNumber"),
        ("ZIP_FILE_PATH", "zipFilePath"),
        ("ACTION_TYPE", "actionType"),
        ("TIMEOUT", "timeOut"),
        ("LOG_NAME", "logName"),
    )
    for spelling in (alias, canonical)
}
|
||
|
||
|
||
class RuleBasedLlmClient:
    """Lightweight rule-based fallback that mimics an LLM client.

    Every method is deterministic: decisions come from keyword checks and
    regular expressions over the request text, never from a model call.
    """

    def understand_request(self, text: str) -> LlmIntentResult:
        """Classify the user's intent and execution preferences by keyword.

        The intent checks form an ``if``/``elif`` chain, so the first matching
        keyword group wins (usage, preview, node-IP query, rollback) and
        ``"deploy"`` is the default. The mode/strategy checks are independent
        ``if`` statements, so a script/offline keyword overrides an MCP
        keyword when both appear.

        Args:
            text: Raw user request; matching is done on ``text.lower()``.

        Returns:
            An ``LlmIntentResult`` carrying the intent, the mode and strategy
            preferences, a fixed confidence (0.6 for the default ``deploy``
            intent, 0.72 otherwise) and the reasons collected on the way.
        """
        lowered = text.lower()
        reasons: list[str] = []
        intent = "deploy"

        if any(word in lowered for word in ("用法", "怎么用", "生成脚本", "给我脚本", "usage")):
            intent = "show_usage"
            reasons.append("用户在询问脚本用法或脚本生成")
        elif any(word in lowered for word in ("预演", "计划", "不执行", "不要动环境", "dry-run", "preview")):
            intent = "preview"
            reasons.append("用户要求只预演或不触碰环境")
        elif any(word in lowered for word in ("在线ip", "在线 ip", "查询ip", "查询 ip", "node", "工作站")):
            intent = "query_node_ips"
            reasons.append("用户要求查询 Node 或在线工作站")
        elif any(word in lowered for word in ("回滚", "rollback")):
            intent = "rollback"
            reasons.append("用户要求回滚")
        else:
            reasons.append("默认识别为部署请求")

        mode_preference = "未指定"
        strategy_preference = "未指定"
        if any(word in lowered for word in ("mcp", "在线执行", "直接在线")):
            mode_preference = "MCP"
            strategy_preference = "hybrid_node_mcp"
            reasons.append("用户倾向 MCP;PAM_HOME 仍需脚本 action")
        if any(word in lowered for word in ("脚本", "离线", "script", "shell", "powershell")):
            mode_preference = "API脚本"
            strategy_preference = "script_only"
            reasons.append("用户倾向脚本或离线执行")
        if intent == "preview":
            # A preview without an explicit strategy falls back to the hybrid one.
            strategy_preference = strategy_preference if strategy_preference != "未指定" else "hybrid_node_mcp"

        return LlmIntentResult(
            intent=intent,  # type: ignore[arg-type]
            mode_preference=mode_preference,  # type: ignore[arg-type]
            strategy_preference=strategy_preference,  # type: ignore[arg-type]
            confidence=0.72 if intent != "deploy" else 0.6,
            reasons=reasons,
        )

    def extract_params(self, text: str, base_params: dict[str, Any] | None = None) -> LlmParamResult:
        """Extract deployment parameters and target IPs from free text.

        Sources are layered from weakest to strongest: ``base_params``, then
        heuristic Chinese phrase matches, then explicit ``KEY=VALUE`` pairs.
        Explicit pairs are applied last so that a phrase heuristic can never
        overwrite a value the user spelled out.

        Args:
            text: Raw user request.
            base_params: Optional starting parameters; copied, never mutated.

        Returns:
            An ``LlmParamResult`` with the merged parameters, the control
            dict (``user_specified_ips`` when any IP was found), the required
            keys that are still missing or empty, and which credential fields
            are present.
        """
        params = dict(base_params or {})
        params.update(self._extract_chinese_patterns(text))
        params.update(self._extract_key_values(text))

        control: dict[str, Any] = {}
        ips = self._extract_ips(text, version=params.get("VERSION_NUMBER"))
        if ips:
            control["user_specified_ips"] = ips

        missing = [key for key in REQUIRED_PARAMS if not params.get(key)]
        sensitive = [key for key in ("CLIENT_SECRET", "CLIENT_ID") if params.get(key)]
        return LlmParamResult(
            extracted_params=params,
            extracted_control=control,
            missing_required_params=missing,
            sensitive_fields_present=sensitive,
        )

    def generate_plan(
        self,
        *,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
        skill_policy: dict[str, Any],
        tool_summaries: list[dict[str, Any]],
    ) -> LlmDeployPlan:
        """Build a deterministic deployment plan with risk notes.

        Args:
            params: Extracted parameters; only used to render the summary.
            intent: Recognised intent; selects the planned action list.
            strategy: Execution strategy; selects the strategy wording.
            skill_policy: Accepted for interface parity; not read here.
            tool_summaries: Accepted for interface parity; not read here.

        Returns:
            An ``LlmDeployPlan``. Confirmation is required for the ``deploy``,
            ``query_node_ips`` and ``rollback`` intents.
        """
        if strategy == "hybrid_node_mcp":
            strategy_text = "PAM_HOME 使用脚本 action,PAM_NODE 使用 MCP"
        elif strategy == "script_only":
            strategy_text = "全部 action 使用脚本 action"
        else:
            strategy_text = "全部 action 使用 fake runner"

        summary = (
            f"计划处理 {params.get('AIRPORT_CODE', '-')}/"
            f"{params.get('APP_NAME', '-')}/"
            f"{params.get('MODULE_NAME', '-')}/"
            f"{params.get('VERSION_NUMBER', '-')},执行策略为 {strategy_text}。"
        )
        risk_notes = [
            "真实部署前必须确认参数。",
            "发布版本、创建下载任务、升级和回滚属于高风险动作。",
            "回滚只能在用户确认后执行。",
        ]
        if strategy == "hybrid_node_mcp":
            risk_notes.append("PAM_HOME 当前没有 MCP 能力,HOME 阶段仍会调用脚本 action。")

        if intent == "query_node_ips":
            planned_actions = ["get-token", "get-node-url", "get-online-ips"]
        elif intent == "rollback":
            planned_actions = ["rollback-ip", "verify-ip", "download-log"]
        elif intent == "deploy":
            planned_actions = [*GLOBAL_ACTION_SEQUENCE, *IP_ACTION_SEQUENCE]
        else:
            # Any other intent (e.g. preview / usage) plans the global steps only.
            planned_actions = list(GLOBAL_ACTION_SEQUENCE)

        return LlmDeployPlan(
            summary=summary,
            risk_notes=risk_notes,
            planned_actions=planned_actions,
            requires_confirmation=intent in ("deploy", "query_node_ips", "rollback"),
            execution_strategy=strategy,
        )

    def decide_execution_mode(
        self,
        *,
        text: str,
        params: dict[str, Any],
        intent: str,
        strategy: ExecutionStrategy,
        allowed_modes: list[str],
        tool_summaries: list[dict[str, Any]],
    ) -> LlmModeDecision:
        """Choose between the fixed runtime and the agentic skill by keyword.

        The default is ``fixed_runtime``. An explicit request for autonomous
        orchestration or a diagnostic-style request switches to
        ``agentic_skill`` with risk level ``medium``. A mode that is not in
        ``allowed_modes`` is replaced by the first allowed mode, or by
        ``fixed_runtime`` when the list is empty.

        Args:
            text: Raw user request; matching is done on ``text.lower()``.
            params: Accepted for interface parity; not read here.
            intent: Recognised intent; ``deploy``/``rollback`` count as high risk.
            strategy: Accepted for interface parity; not read here.
            allowed_modes: Modes the caller permits.
            tool_summaries: Accepted for interface parity; not read here.

        Returns:
            An ``LlmModeDecision``; ``requires_confirmation`` is always True.
        """
        lowered = text.lower()
        requested_agentic = any(
            word in lowered for word in ("自主编排", "按 skill", "自动选择工具", "自动决策", "toolcall", "agentic")
        )
        diagnostic_intent = any(word in lowered for word in ("诊断", "排查", "分析异常", "帮我看看", "explore"))
        high_risk_intent = intent in ("deploy", "rollback") or any(word in lowered for word in ("批量", "升级", "回滚"))

        mode = "fixed_runtime"
        reason = "标准部署和高风险动作默认走固定 runtime。"
        risk_level = "high" if high_risk_intent else "medium"
        requires_confirmation = True

        if requested_agentic or diagnostic_intent:
            mode = "agentic_skill"
            reason = "用户明确要求按 skill 自主编排,或任务更偏探索/诊断。"
            risk_level = "medium"

        if mode not in allowed_modes:
            mode = allowed_modes[0] if allowed_modes else "fixed_runtime"
            reason = "原始模式不在 skill 允许集合内,已回退到允许模式。"

        return LlmModeDecision(
            mode=mode,  # type: ignore[arg-type]
            reason=reason,
            risk_level=risk_level,  # type: ignore[arg-type]
            requires_confirmation=requires_confirmation,
        )

    def analyze_action_result(
        self,
        *,
        action: str,
        result: ActionResult,
        state_summary: dict[str, Any],
    ) -> LlmActionAnalysis:
        """Analyse an action result with local rules.

        The rules run in order and later rules overwrite the fields set by
        earlier ones: generic failure (``result.ok`` is falsy), a non-success
        ``SUCCESS`` value on ``verify-ip``, a failed ``rollback-ip``, and
        finally a ``PENDING_AGENT_CONFIRMATION`` marker in ``result.values``.

        Args:
            action: Name of the action that produced ``result``.
            result: The action outcome; ``ok``, ``error_summary`` and
                ``values`` are read.
            state_summary: Accepted for interface parity; not read here.

        Returns:
            An ``LlmActionAnalysis`` describing anomaly, severity, a possible
            reason, a suggested next step and whether the flow may continue.
        """
        notes: list[str] = []
        has_anomaly = not result.ok
        severity = "info"
        possible_reason = ""
        suggested_action = "继续观察。"
        requires_confirmation = False
        should_continue = True

        if not result.ok:
            severity = "medium"
            possible_reason = result.error_summary or "action 返回失败状态。"
            suggested_action = "查看 action stderr/raw_output,确认参数、网络和目标服务状态。"
            notes.append("硬规则检测到 action 执行失败。")
            should_continue = False

        if action == "verify-ip":
            success = result.values.get("SUCCESS")
            # A missing SUCCESS key is not treated as a failed health check.
            if success is not None and str(success).lower() not in ("true", "1", "yes"):
                has_anomaly = True
                severity = "high"
                possible_reason = result.values.get("MESSAGE", "") or "工作站健康检查未通过。"
                suggested_action = "先下载日志并人工确认是否执行回滚。"
                requires_confirmation = True
                notes.append("verify-ip SUCCESS 非成功值。")
                should_continue = False

        if action == "rollback-ip" and not result.ok:
            severity = "high"
            suggested_action = "保持待确认状态,人工排查回滚失败原因后重试或转人工处理。"
            requires_confirmation = True
            notes.append("rollback-ip 失败需要人工处理。")
            should_continue = False

        if result.values.get("PENDING_AGENT_CONFIRMATION"):
            has_anomaly = True
            severity = "high"
            possible_reason = str(result.values["PENDING_AGENT_CONFIRMATION"])
            suggested_action = "暂停自动流程,等待人工确认。"
            requires_confirmation = True
            notes.append("action 返回待人工确认标记。")
            should_continue = False

        return LlmActionAnalysis(
            action=action,
            has_anomaly=has_anomaly,
            severity=severity,  # type: ignore[arg-type]
            possible_reason=possible_reason,
            suggested_action=suggested_action,
            requires_confirmation=requires_confirmation,
            should_continue=should_continue,
            notes=notes,
        )

    @staticmethod
    def _extract_ips(text: str, version: Any = None) -> list[str]:
        """Return the distinct, valid IPv4 addresses in *text*, in order.

        Args:
            text: Raw user request.
            version: The extracted version value, if any. Dotted-quad tokens
                that occur inside it are skipped so that a four-part version
                such as ``1.2.3.4`` is not reported as a target IP.

        Returns:
            IPv4-looking tokens whose octets are all <= 255, without
            duplicates, in first-seen order.
        """
        pattern = r"\b(?:\d{1,3}\.){3}\d{1,3}\b"
        # re.ASCII makes \b treat CJK characters as non-word characters, so an
        # address written directly next to Chinese text is still found, and
        # restricts \d to ASCII digits.
        version_tokens = set(re.findall(pattern, str(version), flags=re.ASCII)) if version else set()

        ips: list[str] = []
        for candidate in re.findall(pattern, text, flags=re.ASCII):
            if candidate in ips or candidate in version_tokens:
                continue
            if any(int(octet) > 255 for octet in candidate.split(".")):
                continue
            ips.append(candidate)
        return ips

    def _extract_key_values(self, text: str) -> dict[str, str]:
        """Extract ``KEY=VALUE`` pairs whose key is a known alias.

        A value ends at whitespace or at an ASCII or full-width comma or
        semicolon, an ideographic comma, or an ideographic full stop, so that
        pairs separated by Chinese punctuation are not merged into one value.

        Returns:
            Canonical key -> value; a later pair with the same canonical key
            overwrites an earlier one. Keys not in ``KEY_ALIASES`` are ignored.
        """
        params: dict[str, str] = {}
        pair_pattern = r"([A-Za-z_][A-Za-z0-9_]*)\s*=\s*([^\s,;\uff0c\uff1b\u3001\u3002]+)"
        for match in re.finditer(pair_pattern, text):
            raw_key, value = match.groups()
            key = KEY_ALIASES.get(raw_key)
            if key:
                params[key] = value
        return params

    def _extract_chinese_patterns(self, text: str) -> dict[str, str]:
        """Extract deployment parameters from common Chinese phrasings.

        Each pattern is a Chinese label, an optional ASCII or full-width
        colon, and a value. A token that is immediately followed by ``=`` is
        the key of a ``KEY=VALUE`` pair rather than a value and is not
        captured; the airport code must not be the prefix of a longer
        identifier.

        Returns:
            Canonical key -> first matching value for each pattern that matched.
        """
        # Rejects a capture that is cut short or that is followed by "=".
        not_a_key = r"(?![A-Za-z0-9_.-]|\s*=)"
        # Characters that terminate a path: whitespace and ASCII/CJK delimiters.
        path_body = r"[^\s,;\uff0c\uff1b\u3001\u3002]+"
        patterns = {
            "AIRPORT_CODE": r"(?:机场|三字码)\s*[:\uff1a]?\s*([A-Z]{3})(?![A-Za-z_])",
            "APP_NAME": r"(?:应用|应用名)\s*[:\uff1a]?\s*([A-Za-z0-9_.-]+)" + not_a_key,
            "MODULE_NAME": r"(?:模块|模块名)\s*[:\uff1a]?\s*([A-Za-z0-9_.-]+)" + not_a_key,
            "VERSION_NUMBER": r"(?:版本|版本号)\s*[:\uff1a]?\s*([A-Za-z0-9_.-]+)" + not_a_key,
            "ZIP_FILE_PATH": r"(?:包|软件包|zip)\s*[:\uff1a]?\s*([A-Za-z]:[\\/]" + path_body + r"|/" + path_body + r")",
        }
        params: dict[str, str] = {}
        for key, pattern in patterns.items():
            match = re.search(pattern, text)
            if match:
                params[key] = match.group(1)
        return params
|