402 lines
14 KiB
Python
402 lines
14 KiB
Python
"""
|
||
LangGraph 第 5 步:真实 LLM 智能体
|
||
ReAct 模式:思考(Reason) - 行动(Act) - 观察(Observe)
|
||
配置见 config.py
|
||
|
||
【Java 开发者 Python 速查】
|
||
- 没有分号 ; 结尾
|
||
- 用缩进(4空格)表示代码块,不用 {}
|
||
- def 定义函数(类似 Java 的方法)
|
||
- class 定义类
|
||
- import 导入模块(类似 Java 的 import)
|
||
- 变量不用声明类型(但可以用 :str 做类型提示)
|
||
- True/False/None(类似 Java 的 true/false/null)
|
||
- list 类似 ArrayList,dict 类似 HashMap
|
||
- for item in list: 遍历(类似 Java 的 for-each)
|
||
"""
|
||
|
||
# ============================================================
|
||
# 导入模块
|
||
# ============================================================
|
||
# Java: import java.util.List;
|
||
# Python: import 模块名
|
||
import sys # 命令行参数(类似 Java 的 main 的 String[] args)
|
||
import requests # HTTP 请求库(类似 Java 的 OkHttp/RestTemplate)
|
||
from langgraph.graph import StateGraph, START, END # 从模块导入特定类
|
||
from typing import TypedDict # 类型提示工具
|
||
|
||
# ============================================================
|
||
# 1. 加载配置
|
||
# ============================================================
|
||
# Java: 从配置文件读取
|
||
# Python: 直接 import 另一个 .py 文件,像导入类一样
|
||
from config import API_KEY, BASE_URL, MODEL, MAX_ITERATIONS, TEMPERATURE
|
||
|
||
# Conditional statement.
#   Java:   if (condition) { ... }
#   Python: if condition:
# The indented lines (4 spaces) form the body of the block.
#
# "sk-..." is the value this check treats as "API key not configured yet".
if API_KEY == "sk-...":
    print("请先在 config.py 中配置 API_KEY")
    # Terminate the program (like Java's System.exit(1)). sys.exit() is used
    # instead of the bare exit() helper: exit() is injected by the `site`
    # module for interactive sessions and is not guaranteed to be defined.
    sys.exit(1)
|
||
|
||
# ============================================================
|
||
# 2. 定义状态(State)
|
||
# ============================================================
|
||
"""
|
||
TypedDict 是 Python 的类型提示工具。
|
||
|
||
Java 对比:
|
||
// Java
|
||
class AgentState {
|
||
String question;
|
||
List<String> thoughts;
|
||
int iteration;
|
||
}
|
||
|
||
// Python
|
||
class AgentState(TypedDict):
|
||
question: str # String
|
||
thoughts: list # List<String>
|
||
iteration: int # int
|
||
|
||
注意: Python 的类型提示只是给编辑器看的,运行时不强制检查。
|
||
TypedDict 本质上还是 dict(字典),只是加了类型说明。
|
||
"""
|
||
class AgentState(TypedDict):
    """State dictionary shared by the graph nodes.

    Java analogy: a plain data class with one field per key. At runtime this
    is still an ordinary dict; the annotations are type hints only.
    """

    question: str  # String - the user's question
    thoughts: list  # List<String> - history of thoughts
    current_thought: str  # String - the thought produced in the current round
    action: str  # String - the action (tool name) to take
    action_param: str  # String - the parameter for the action
    observation: str  # String - what was observed after the action
    final_answer: str  # String - the final answer
    iteration: int  # int - current iteration round
    max_iterations: int  # int - maximum number of iterations
|
||
|
||
# ============================================================
|
||
# 3. 定义工具(Tools)
|
||
# ============================================================
|
||
"""
|
||
工具就是 Python 函数,智能体可以调用它们。
|
||
|
||
Java 对比:
|
||
// Java: 需要定义接口和实现
|
||
public interface Tool {
|
||
String execute(String param);
|
||
}
|
||
|
||
// Python: 直接就是函数
|
||
def calculator(expression: str) -> str:
|
||
...
|
||
|
||
Python 函数定义:
|
||
def 函数名(参数: 类型) -> 返回类型:
|
||
函数体(缩进)
|
||
return 返回值
|
||
"""
|
||
|
||
def calculator(expression: str) -> str:
    """Calculator tool: evaluate an arithmetic expression and report the result.

    The expression is produced by the LLM, so it is treated as untrusted
    input. It is NOT passed to eval(): eval() with an empty ``__builtins__``
    is not a sandbox, because attribute access on literals (for example
    ``().__class__``) still reaches arbitrary classes. Only numeric
    arithmetic is evaluated instead.

    Syntax notes for Java developers:
        expression: str  - parameter type hint (String)
        -> str           - return type hint (String)

    Args:
        expression: Arithmetic expression such as "2+3*4". Supported:
            numbers, parentheses, + - * / // % ** and unary + / -.

    Returns:
        "计算结果: <expression> = <value>" on success, otherwise
        "计算错误: <reason>". This function never raises.
    """
    try:
        result = _evaluate_arithmetic(expression)
    except Exception as e:
        # try/except syntax.
        #   Java:   catch (Exception e) { ... }
        #   Python: except Exception as e:
        # Deliberately broad: a tool must not crash the agent loop, so every
        # failure is reported back to the LLM as text.
        return f"计算错误: {e}"

    # f-string formatting.
    #   Java:   String.format("结果: %s = %s", expression, result)
    #   Python: f"结果: {expression} = {result}"
    return f"计算结果: {expression} = {result}"


def _evaluate_arithmetic(expression: str):
    """Evaluate a purely numeric expression by walking its syntax tree.

    Raises:
        SyntaxError: if the text is not a valid Python expression.
        ValueError: if the expression uses anything other than numbers and
            the supported arithmetic operators, or an oversized exponent.
        ZeroDivisionError, OverflowError: from the arithmetic itself.
    """
    import ast
    import operator

    binary_operators = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_operators = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Cap exponents so an input like 9**9**9 cannot stall the process.
    max_exponent = 10_000

    def evaluate(node):
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; reject it so only numbers pass.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value
        elif isinstance(node, ast.BinOp) and type(node.op) in binary_operators:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > max_exponent:
                raise ValueError(f"指数过大: {right}")
            return binary_operators[type(node.op)](left, right)
        elif isinstance(node, ast.UnaryOp) and type(node.op) in unary_operators:
            return unary_operators[type(node.op)](evaluate(node.operand))
        # Names, calls, attribute access, containers, strings, ... are refused.
        raise ValueError(f"不支持的表达式: {expression}")

    # strip(): leading whitespace would otherwise be parsed as an indent.
    tree = ast.parse(expression.strip(), mode="eval")
    return evaluate(tree.body)
|
||
|
||
def search_knowledge(query: str) -> str:
    """Knowledge-base search tool (simulated with an in-memory dict).

    Python dict compared with a Java Map:
        # Java:   Map<String, String> map = new HashMap<>();
        #         map.put("key", "value");
        # Python: d = {"key": "value"}

    Args:
        query: Free-text search query.

    Returns:
        "搜索到: <entry>" for the first topic (in declaration order) whose key
        occurs in the lower-cased query, otherwise a "not found" message that
        echoes the query.
    """
    knowledge = {
        "langgraph": "LangGraph 是 LangChain 团队开发的框架...",
        "python": "Python 是一种高级编程语言...",
        "langchain": "LangChain 是构建 LLM 应用的框架...",
        "ai": "人工智能 (AI) 是计算机科学的一个分支...",
    }

    # .lower() is the counterpart of Java's toLowerCase().
    needle = query.lower()

    # Iterating key/value pairs.
    #   Java:   for (Map.Entry<String, String> e : knowledge.entrySet())
    #   Python: for key, value in knowledge.items()
    # next() returns the first match, or None when no topic matches.
    hit = next(
        (entry for topic, entry in knowledge.items() if topic in needle),
        None,
    )
    if hit is None:
        return f"未找到关于 '{query}' 的精确信息"
    return f"搜索到: {hit}"
|
||
|
||
# Tool registry: maps the tool name used in the LLM's reply to the function
# that implements it.
#   Java:   Map<String, Function<String, String>> tools = new HashMap<>();
#   Python: a dict whose values are the functions themselves
tools = dict(
    calculator=calculator,
    search=search_knowledge,
)
|
||
|
||
# ============================================================
|
||
# 4. LLM 调用函数
|
||
# ============================================================
|
||
def call_llm(messages, max_tokens=300):
    """Call the chat-completions endpoint and return the reply text.

    ``max_tokens=300`` is a default argument.
        Java:   needs an overloaded method
        Python: the default is written directly in the parameter list

    Args:
        messages: List of {"role": ..., "content": ...} dicts sent as the
            conversation.
        max_tokens: Upper bound on generated tokens sent in the request body.

    Returns:
        The content of the first choice's message. An empty string is
        returned when the response carries a null content, so callers can
        always treat the result as text.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            4xx/5xx status (via raise_for_status()).
        KeyError, IndexError: if the response JSON lacks the expected
            choices[0].message.content structure.
    """
    # rstrip('/') keeps a BASE_URL configured with a trailing slash from
    # producing ".../v1//chat/completions".
    url = f"{BASE_URL.rstrip('/')}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    body = {
        "model": MODEL,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": TEMPERATURE,
    }

    # requests.post() sends the HTTP POST request.
    #   Java:   several lines with OkHttp / RestTemplate
    #   Python: a single call; json= serializes the body for us
    response = requests.post(url, headers=headers, json=body, timeout=30)
    response.raise_for_status()  # raise on 4xx/5xx instead of parsing an error body
    data = response.json()  # parse JSON (like Jackson/Gson in Java)

    # Nested JSON access.
    #   Java:   data.get("choices").get(0).get("message").get("content")
    #   Python: data["choices"][0]["message"]["content"]
    content = data["choices"][0]["message"]["content"]

    # The callers call .split() / "\n".join() on the result, so a JSON null
    # must not leak out as None.
    return content if content is not None else ""
|
||
|
||
# ============================================================
|
||
# 5. 定义节点(Nodes)
|
||
# ============================================================
|
||
"""
|
||
LangGraph 节点:接收 State,返回更新后的 State。
|
||
|
||
为什么每个节点都要 state = state.copy()?
|
||
- Python 字典是可变对象(mutable)
|
||
- 直接修改会影响原始数据
|
||
- .copy() 创建浅拷贝(类似 Java 的 new HashMap<>(original))
|
||
"""
|
||
|
||
def think_node(state: AgentState):
    """Think node: ask the LLM for the next step and record its decision.

    Builds the prompt (system instructions, the question and, when a tool has
    already run, the previous reply plus its observation), calls the LLM,
    appends the reply to the thought history, prints it, and stores the
    parsed decision in the state.

    Args:
        state: Current agent state. Must contain 'iteration',
            'max_iterations' and 'question'.

    Returns:
        A shallow copy of the state with 'iteration' incremented,
        'current_thought' / 'thoughts' updated, and exactly one of:
        'action' + 'action_param' set, 'final_answer' set (action cleared),
        or 'action' cleared when the reply contained neither.
    """
    state = state.copy()  # shallow copy (like new HashMap<>(original) in Java)
    state['iteration'] += 1  # like state.iteration++ in Java

    # Remember the previous reply before it is overwritten further down; it
    # is the message that requested the tool call being observed.
    previous_reply = state.get('current_thought')

    # Multi-line string (triple quotes).
    #   Java:   "line1\n" + "line2\n"
    #   Python: """line1\nline2"""
    system_prompt = f"""你是一个智能助手,可以使用以下工具:
1. calculator - 数学计算,参数是数学表达式如 "2+3*4"
2. search - 搜索知识,参数是搜索关键词

当前是第 {state['iteration']}/{state['max_iterations']} 轮。

请严格按照以下格式回复:
[思考] 你的思考过程
[行动] 工具名称|参数
例如:
[思考] 我需要计算这个数学题
[行动] calculator|2+3*4

如果可以直接回答,请这样回复:
[思考] 我已经知道答案了
[回答] 你的最终答案"""

    # Build the message list.
    #   Java:   List<Map<String, String>> messages = new ArrayList<>();
    #   Python: a list of dict literals
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": state['question']},
    ]

    # .get() returns None instead of raising when the key is missing.
    if state.get('observation'):
        # The observation comes from the tool, not from the model, so it is
        # sent as a user turn that follows the assistant turn which asked for
        # it. Sending it as an assistant message made the model read the tool
        # output as its own words, with no record of the requested action.
        if previous_reply:
            messages.append({"role": "assistant", "content": previous_reply})
        messages.append({"role": "user", "content": f"[观察] {state['observation']}"})

    thought_text = call_llm(messages)

    # Save the reply; .get('thoughts', []) falls back to an empty list.
    state['current_thought'] = thought_text
    state['thoughts'] = state.get('thoughts', []) + [thought_text]

    # Debug output. '=' * 50 repeats the string (like String.repeat(50)).
    print(f"\n{'='*50}")
    print(f"[思考] 第 {state['iteration']} 轮:")
    print(thought_text)

    state.update(_parse_llm_reply(thought_text))
    return state


def _parse_llm_reply(reply):
    """Turn the LLM's reply text into the state keys that record its decision.

    Lines are scanned in order; the first usable [行动] or [回答] line wins.

    Returns:
        {'action', 'action_param'} for a tool call,
        {'final_answer', 'action', 'action_param'} for a final answer, or
        {'action': ''} when the reply contains neither.
    """
    lines = reply.split('\n')
    for index, line in enumerate(lines):
        if '[行动]' in line:  # 'in' tests for a substring (like contains())
            # partition() splits at the FIRST '|' only, so a parameter that
            # itself contains '|' is kept intact instead of being rejected.
            name, separator, param = (
                line.replace('[行动]', '').strip().partition('|')
            )
            if separator:
                return {'action': name.strip(), 'action_param': param.strip()}
        if '[回答]' in line:
            # The answer may span several lines: keep the rest of this line
            # and everything after it.
            remainder = [line.replace('[回答]', '')] + lines[index + 1:]
            return {
                'final_answer': '\n'.join(remainder).strip(),
                # Clear the previous round's action so the router cannot
                # re-run a stale tool call if the answer text is empty.
                'action': '',
                'action_param': '',
            }
    return {'action': ''}
|
||
|
||
def act_node(state: AgentState):
    """Act node: run the tool named in the state and record what it returned.

    Args:
        state: Current agent state; 'action' and 'action_param' are read and
            default to '' when absent.

    Returns:
        A shallow copy of the state whose 'observation' is the tool's return
        value, or an "unknown tool" message when the name is not registered.
    """
    updated = state.copy()

    # .get() with a default.
    #   Java:   state.getOrDefault("action", "")
    #   Python: state.get('action', '')
    tool_name = updated.get('action', '')
    tool_arg = updated.get('action_param', '')

    print(f"\n[行动] 执行 {tool_name}({tool_arg})")

    # Registry lookup; .get() yields None for an unregistered name
    # (Java would test tools.containsKey(action) first).
    handler = tools.get(tool_name)
    if handler is None:
        outcome = f"未知工具: {tool_name}"
    else:
        outcome = handler(tool_arg)  # call the tool function

    updated['observation'] = outcome
    print(f"[观察] {outcome}")
    return updated
|
||
|
||
def answer_node(state: AgentState):
    """Answer node: print the final answer, asking the LLM for one if needed.

    Args:
        state: Current agent state.

    Returns:
        A shallow copy of the state. When 'final_answer' was empty it is
        filled with the LLM's summary of the question and thought history.
    """
    result = state.copy()

    if not result.get('final_answer'):
        # Joining strings.
        #   Java:   String.join("\n", thoughts)
        #   Python: "\n".join(thoughts)
        prompt_head = f"问题: {result['question']}\n\n思考过程:\n"
        history = "\n".join(result.get('thoughts', []))
        messages = [
            {"role": "system", "content": "请根据以下信息给出简洁的最终答案"},
            {"role": "user", "content": prompt_head + history},
        ]
        result['final_answer'] = call_llm(messages, max_tokens=200)

    print(f"\n[回答] {result['final_answer']}")
    return result
|
||
|
||
# ============================================================
|
||
# 6. 路由函数(Router)
|
||
# ============================================================
|
||
"""
|
||
路由函数:接收 State,返回字符串(下一步节点名)。
|
||
|
||
Java 对比:
|
||
// Java: String route(AgentState state) { ... return "act"; }
|
||
// Python: def route(state: AgentState) -> str: ... return "act"
|
||
"""
|
||
def route(state: AgentState):
    """Pick the node that runs after the think node.

    Returns:
        "act" only when there is no final answer yet, the iteration budget is
        not used up, and an action is pending; "answer" in every other case.
    """
    # The checks short-circuit left to right: answer first, then budget,
    # then pending action.
    should_act = (
        not state.get('final_answer')
        and state['iteration'] < state['max_iterations']
        and state.get('action')
    )
    return "act" if should_act else "answer"
|
||
|
||
# ============================================================
|
||
# 7. 构建图(Build Graph)
|
||
# ============================================================
|
||
# Create a graph builder that uses AgentState as its state schema.
graph = StateGraph(AgentState)

# Register the three node functions under the names the edges refer to.
graph.add_node("think", think_node)
graph.add_node("act", act_node)
graph.add_node("answer", answer_node)

# Wiring:
#   START -> think
#   think -> route() returns "act" or "answer", mapped to the node of that name
#   act   -> think   (the ReAct loop)
#   answer -> END
graph.add_edge(START, "think")
graph.add_conditional_edges("think", route, {"act": "act", "answer": "answer"})
graph.add_edge("act", "think")
graph.add_edge("answer", END)

# Compile the definition into the object that is run with app.invoke().
app = graph.compile()
|
||
|
||
# ============================================================
|
||
# 8. 运行(Run)
|
||
# ============================================================
|
||
def main():
    """Run the demo: print the banner, pick the questions, answer each one.

    The question is taken from the command-line arguments when any are given;
    otherwise two built-in sample questions are used.
    """
    print("=" * 50)
    print("LangGraph ReAct 智能体")
    print("=" * 50)
    print(f"API: {BASE_URL}")
    print(f"模型: {MODEL}")
    print(f"最大迭代: {MAX_ITERATIONS}")
    print("\n图结构:")
    print(" START -> think -> [有行动?] -> act -> think (循环)")
    print(" |")
    print(" +-> [无行动/达到限制] -> answer -> END")

    # Command-line arguments.
    #   Java:   public static void main(String[] args)
    #   Python: sys.argv[0] is the script name, sys.argv[1:] are the arguments
    if len(sys.argv) > 1:
        questions = [" ".join(sys.argv[1:])]  # join all arguments into one string
    else:
        questions = [
            "计算一下 123 * 456",
            "什么是 LangGraph?",
        ]

    # for loop.
    #   Java:   for (String q : questions) { ... }
    #   Python: for q in questions:
    for q in questions:
        print(f"\n{'#'*50}")
        print(f"问题: {q}")
        print(f"{'#'*50}")

        # app.invoke() runs the graph from this initial state. Every key
        # declared on AgentState is given a starting value.
        result = app.invoke({
            "question": q,
            "thoughts": [],
            "current_thought": "",
            "iteration": 0,
            "max_iterations": MAX_ITERATIONS,
            "action": "",
            "observation": "",
            "final_answer": "",
            "action_param": "",
        })

        print(f"\n{'='*50}")
        print(f"最终答案: {result['final_answer']}")
        print(f"{'='*50}")


# Run the demo only when this file is executed as a script, so that importing
# the module (for reuse or testing) does not start LLM calls.
if __name__ == "__main__":
    main()