105 lines
2.9 KiB
Python
105 lines
2.9 KiB
Python
"""
|
||
LangGraph 条件边 - 让图学会做决定
|
||
"""
|
||
from langgraph.graph import StateGraph, START, END
|
||
from typing import TypedDict
|
||
|
||
# 1️⃣ 定义状态
|
||
class GraphState(TypedDict):
|
||
message: str
|
||
user_input: str
|
||
|
||
# 2️⃣ 定义节点
|
||
def receive_node(state: GraphState):
|
||
"""接收节点 - 分类用户输入"""
|
||
text = state['user_input'].strip().lower()
|
||
print(f"[接收] 用户说: {state['user_input']}")
|
||
|
||
# 简单分类
|
||
if any(word in text for word in ['你好', 'hello', 'hi', 'hey']):
|
||
state['message'] = "问候"
|
||
elif '?' in text or '?' in text or '怎么' in text or '什么' in text:
|
||
state['message'] = "问题"
|
||
else:
|
||
state['message'] = "闲聊"
|
||
|
||
print(f"[接收] 分类为: {state['message']}")
|
||
return state
|
||
|
||
def respond_greeting(state: GraphState):
|
||
"""回应问候"""
|
||
state['message'] = "你好呀!有什么可以帮你的吗?"
|
||
print(f"[问候回应] {state['message']}")
|
||
return state
|
||
|
||
def answer_question(state: GraphState):
|
||
"""回答问题"""
|
||
state['message'] = "这是一个好问题!让我想想... (这里可以接入 LLM)"
|
||
print(f"[问题回答] {state['message']}")
|
||
return state
|
||
|
||
def chat_casual(state: GraphState):
|
||
"""闲聊"""
|
||
state['message'] = "哈哈,聊得开心!"
|
||
print(f"[闲聊] {state['message']}")
|
||
return state
|
||
|
||
# 3️⃣ 条件路由函数 - 这是关键!
|
||
def route_input(state: GraphState):
|
||
"""根据分类决定下一步走哪个节点"""
|
||
category = state['message']
|
||
print(f"[路由] 决定走向: {category}")
|
||
|
||
if category == "问候":
|
||
return "respond_greeting"
|
||
elif category == "问题":
|
||
return "answer_question"
|
||
else:
|
||
return "chat_casual"
|
||
|
||
# 4️⃣ 构建图
|
||
graph = StateGraph(GraphState)
|
||
|
||
# 添加所有节点
|
||
graph.add_node("receive", receive_node)
|
||
graph.add_node("respond_greeting", respond_greeting)
|
||
graph.add_node("answer_question", answer_question)
|
||
graph.add_node("chat_casual", chat_casual)
|
||
|
||
# 添加边
|
||
graph.add_edge(START, "receive")
|
||
|
||
# 条件边!根据路由函数返回值决定走向
|
||
graph.add_conditional_edges(
|
||
"receive", # 从接收节点出发
|
||
route_input, # 用这个函数做判断
|
||
{
|
||
"respond_greeting": "respond_greeting",
|
||
"answer_question": "answer_question",
|
||
"chat_casual": "chat_casual",
|
||
}
|
||
)
|
||
|
||
# 所有分支最后都汇聚到 END
|
||
graph.add_edge("respond_greeting", END)
|
||
graph.add_edge("answer_question", END)
|
||
graph.add_edge("chat_casual", END)
|
||
|
||
# 编译
|
||
app = graph.compile()
|
||
|
||
# 5️⃣ 测试三种情况
|
||
test_cases = [
|
||
"你好!",
|
||
"什么是 LangGraph?",
|
||
"今天天气不错",
|
||
]
|
||
|
||
print("=" * 50)
|
||
print("条件边演示 - 图会根据输入走不同路径")
|
||
print("=" * 50)
|
||
|
||
for i, text in enumerate(test_cases, 1):
|
||
print(f"\n--- 测试 {i}: {text} ---")
|
||
result = app.invoke({"user_input": text, "message": ""})
|
||
print(f" 最终输出: {result['message']}") |