402 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LangGraph 第 5 步:真实 LLM 智能体
ReAct 模式:思考(Reason) - 行动(Act) - 观察(Observe)
配置见 config.py
【Java 开发者 Python 速查】
- 没有分号 ; 结尾
- 用缩进4空格表示代码块不用 {}
- def 定义函数(类似 Java 的方法)
- class 定义类
- import 导入模块(类似 Java 的 import
- 变量不用声明类型(但可以用 :str 做类型提示)
- True/False/None类似 Java 的 true/false/null
- list 类似 ArrayListdict 类似 HashMap
- for item in list: 遍历(类似 Java 的 for-each
"""
# ============================================================
# 导入模块
# ============================================================
# Java: import java.util.List;
# Python: import 模块名
import sys # 命令行参数(类似 Java 的 main 的 String[] args
import requests # HTTP 请求库(类似 Java 的 OkHttp/RestTemplate
from langgraph.graph import StateGraph, START, END # 从模块导入特定类
from typing import TypedDict # 类型提示工具
# ============================================================
# 1. 加载配置
# ============================================================
# Java: 从配置文件读取
# Python: 直接 import 另一个 .py 文件,像导入类一样
from config import API_KEY, BASE_URL, MODEL, MAX_ITERATIONS, TEMPERATURE
# if 条件判断
# Java: if (condition) { ... }
# Python: if condition:
# 缩进表示代码块4个空格
if API_KEY == "sk-...":
print("请先在 config.py 中配置 API_KEY")
exit(1) # 退出程序(类似 Java 的 System.exit(1)
# ============================================================
# 2. 定义状态State
# ============================================================
"""
TypedDict 是 Python 的类型提示工具。
Java 对比:
// Java
class AgentState {
String question;
List<String> thoughts;
int iteration;
}
// Python
class AgentState(TypedDict):
question: str # String
thoughts: list # List<String>
iteration: int # int
注意: Python 的类型提示只是给编辑器看的,运行时不强制检查。
TypedDict 本质上还是 dict字典只是加了类型说明。
"""
class AgentState(TypedDict):
question: str # String - 用户的问题
thoughts: list # List<String> - 思考历史
current_thought: str # String - 当前这轮的思考
action: str # String - 要采取的行动
action_param: str # String - 行动的参数
observation: str # String - 行动后的观察结果
final_answer: str # String - 最终答案
iteration: int # int - 当前迭代轮次
max_iterations: int # int - 最大迭代次数
# ============================================================
# 3. 定义工具Tools
# ============================================================
"""
工具就是 Python 函数,智能体可以调用它们。
Java 对比:
// Java: 需要定义接口和实现
public interface Tool {
String execute(String param);
}
// Python: 直接就是函数
def calculator(expression: str) -> str:
...
Python 函数定义:
def 函数名(参数: 类型) -> 返回类型:
函数体(缩进)
return 返回值
"""
def calculator(expression: str) -> str:
"""
计算器工具
expression: str - 参数类型提示String
-> str - 返回类型提示String
"""
try:
# eval() 把字符串当 Python 表达式执行
# Java: 需要用脚本引擎Python 直接 eval
# {"__builtins__": {}} 是安全限制,防止执行危险代码
result = eval(expression, {"__builtins__": {}}, {})
# f-string: Python 的字符串格式化
# Java: String.format("结果: %s = %s", expression, result)
# Python: f"结果: {expression} = {result}"
return f"计算结果: {expression} = {result}"
except Exception as e:
# try-catch 语法
# Java: catch (Exception e) { ... }
# Python: except Exception as e:
return f"计算错误: {e}"
def search_knowledge(query: str) -> str:
"""
知识库搜索工具(模拟)
Python dict字典对比 Java Map:
# Java: Map<String, String> map = new HashMap<>();
# map.put("key", "value");
# Python: dict = {"key": "value"}
"""
knowledge = {
"langgraph": "LangGraph 是 LangChain 团队开发的框架...",
"python": "Python 是一种高级编程语言...",
"langchain": "LangChain 是构建 LLM 应用的框架...",
"ai": "人工智能 (AI) 是计算机科学的一个分支...",
}
# for 循环遍历字典
# Java: for (Map.Entry<String, String> entry : knowledge.entrySet())
# Python: for key, value in knowledge.items():
for key, value in knowledge.items():
if key in query.lower(): # .lower() 转小写(类似 Java 的 toLowerCase()
return f"搜索到: {value}"
return f"未找到关于 '{query}' 的精确信息"
# 工具注册表dict 映射 工具名 -> 函数
# Java: Map<String, Function<String, String>> tools = new HashMap<>();
# Python: tools = {"name": function}
tools = {
"calculator": calculator,
"search": search_knowledge,
}
# ============================================================
# 4. LLM 调用函数
# ============================================================
def call_llm(messages, max_tokens=300):
"""
调用 LLM API
max_tokens=300 是默认参数
Java: 需要方法重载
Python: 直接在参数列表中给默认值
"""
url = f"{BASE_URL}/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}
body = {
"model": MODEL,
"messages": messages,
"max_tokens": max_tokens,
"temperature": TEMPERATURE,
}
# requests.post() 发送 HTTP POST 请求
# Java: 需要 OkHttp/RestTemplate 几行代码
# Python: 一行搞定
response = requests.post(url, headers=headers, json=body, timeout=30)
response.raise_for_status() # 检查 HTTP 状态码
data = response.json() # 解析 JSON类似 Java 的 Jackson/Gson
# 访问嵌套 JSON
# Java: data.get("choices").get(0).get("message").get("content")
# Python: data["choices"][0]["message"]["content"]
return data["choices"][0]["message"]["content"]
# ============================================================
# 5. 定义节点Nodes
# ============================================================
"""
LangGraph 节点:接收 State返回更新后的 State。
为什么每个节点都要 state = state.copy()
- Python 字典是可变对象mutable
- 直接修改会影响原始数据
- .copy() 创建浅拷贝(类似 Java 的 new HashMap<>(original)
"""
def think_node(state: AgentState):
"""思考节点 - 让 LLM 决定下一步"""
state = state.copy() # 创建副本
state['iteration'] += 1 # 迭代次数 +1类似 Java 的 state.iteration++
# 多行字符串(三引号)
# Java: "line1\n" + "line2\n"
# Python: """line1\nline2"""
system_prompt = f"""你是一个智能助手,可以使用以下工具:
1. calculator - 数学计算,参数是数学表达式如 "2+3*4"
2. search - 搜索知识,参数是搜索关键词
当前是第 {state['iteration']}/{state['max_iterations']} 轮。
请严格按照以下格式回复:
[思考] 你的思考过程
[行动] 工具名称|参数
例如:
[思考] 我需要计算这个数学题
[行动] calculator|2+3*4
如果可以直接回答,请这样回复:
[思考] 我已经知道答案了
[回答] 你的最终答案"""
# 构建消息列表
# Java: List<Map<String, String>> messages = new ArrayList<>();
# messages.add(Map.of("role", "system", "content", prompt));
# Python: messages = [{"role": "system", "content": prompt}]
messages = [{"role": "system", "content": system_prompt}]
messages.append({"role": "user", "content": state['question']})
# .get() 安全访问
# Java: String obs = state.get("observation");
# Python: obs = state.get('observation') # 不存在返回 None
if state.get('observation'):
messages.append({"role": "assistant", "content": f"[观察] {state['observation']}"})
# 调用 LLM
thought_text = call_llm(messages)
# 保存思考历史
state['current_thought'] = thought_text
state['thoughts'] = state.get('thoughts', []) + [thought_text]
# .get('thoughts', []) - 如果不存在返回空列表 []
# 打印调试信息
print(f"\n{'='*50}") # '='*50 生成 50 个等号(类似 Java 的 String.repeat(50)
print(f"[思考] 第 {state['iteration']} 轮:")
print(thought_text)
# 解析 LLM 的输出
# .split('\n') 按换行符分割字符串
# Java: String[] lines = thought_text.split("\n");
# Python: lines = thought_text.split('\n')
for line in thought_text.split('\n'):
if '[行动]' in line: # 'in' 检查子字符串(类似 Java 的 contains()
# .replace() 替换字符串
# .strip() 去前后空格(类似 Java 的 trim()
# .split('|') 按 | 分割
parts = line.replace('[行动]', '').strip().split('|')
if len(parts) == 2: # len() 获取长度(类似 Java 的 .length()
state['action'] = parts[0].strip()
state['action_param'] = parts[1].strip()
return state
if '[回答]' in line:
state['final_answer'] = line.replace('[回答]', '').strip()
return state
state['action'] = ""
return state
def act_node(state: AgentState):
"""行动节点 - 执行工具"""
state = state.copy()
# .get() 带默认值
# Java: String action = state.getOrDefault("action", "");
# Python: action = state.get('action', '')
action = state.get('action', '')
param = state.get('action_param', '')
print(f"\n[行动] 执行 {action}({param})")
# 检查键是否存在
# Java: if (tools.containsKey(action))
# Python: if action in tools
if action in tools:
result = tools[action](param) # 调用函数
state['observation'] = result
print(f"[观察] {result}")
else:
state['observation'] = f"未知工具: {action}"
print(f"[观察] 未知工具: {action}")
return state
def answer_node(state: AgentState):
"""回答节点 - 生成最终答案"""
state = state.copy()
if state.get('final_answer'):
print(f"\n[回答] {state['final_answer']}")
return state
# 列表推导式(类似 Java 的 Stream
# Java: String.join("\n", thoughts)
# Python: "\n".join(thoughts)
messages = [
{"role": "system", "content": "请根据以下信息给出简洁的最终答案"},
{"role": "user", "content": f"问题: {state['question']}\n\n思考过程:\n" + "\n".join(state.get('thoughts', []))}
]
state['final_answer'] = call_llm(messages, max_tokens=200)
print(f"\n[回答] {state['final_answer']}")
return state
# ============================================================
# 6. 路由函数Router
# ============================================================
"""
路由函数:接收 State返回字符串下一步节点名
Java 对比:
// Java: String route(AgentState state) { ... return "act"; }
// Python: def route(state: AgentState) -> str: ... return "act"
"""
def route(state: AgentState):
"""决定下一步去哪里"""
if state.get('final_answer'):
return "answer"
if state['iteration'] >= state['max_iterations']:
return "answer"
if state.get('action'):
return "act"
return "answer"
# ============================================================
# 7. 构建图Build Graph
# ============================================================
graph = StateGraph(AgentState)
graph.add_node("think", think_node)
graph.add_node("act", act_node)
graph.add_node("answer", answer_node)
graph.add_edge(START, "think")
graph.add_conditional_edges("think", route, {"act": "act", "answer": "answer"})
graph.add_edge("act", "think")
graph.add_edge("answer", END)
app = graph.compile()
# ============================================================
# 8. 运行Run
# ============================================================
print("=" * 50)
print("LangGraph ReAct 智能体")
print("=" * 50)
print(f"API: {BASE_URL}")
print(f"模型: {MODEL}")
print(f"最大迭代: {MAX_ITERATIONS}")
print("\n图结构:")
print(" START -> think -> [有行动?] -> act -> think (循环)")
print(" |")
print(" +-> [无行动/达到限制] -> answer -> END")
# 命令行参数
# Java: public static void main(String[] args)
# Python: sys.argv[0] 是脚本名sys.argv[1:] 是参数
if len(sys.argv) > 1:
questions = [" ".join(sys.argv[1:])] # 把所有参数拼成一个字符串
else:
questions = [
"计算一下 123 * 456",
"什么是 LangGraph",
]
# for 循环
# Java: for (String q : questions) { ... }
# Python: for q in questions:
for q in questions:
print(f"\n{'#'*50}")
print(f"问题: {q}")
print(f"{'#'*50}")
# app.invoke() 运行图
result = app.invoke({
"question": q,
"thoughts": [],
"iteration": 0,
"max_iterations": MAX_ITERATIONS,
"action": "",
"observation": "",
"final_answer": "",
"action_param": "",
})
print(f"\n{'='*50}")
print(f"最终答案: {result['final_answer']}")
print(f"{'='*50}")