2026-05-25 21:33:13 +08:00

105 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LangGraph 条件边 - 让图学会做决定
"""
from langgraph.graph import StateGraph, START, END
from typing import TypedDict
# 1⃣ 定义状态
class GraphState(TypedDict):
message: str
user_input: str
# 2⃣ 定义节点
def receive_node(state: GraphState):
"""接收节点 - 分类用户输入"""
text = state['user_input'].strip().lower()
print(f"[接收] 用户说: {state['user_input']}")
# 简单分类
if any(word in text for word in ['你好', 'hello', 'hi', 'hey']):
state['message'] = "问候"
elif '?' in text or '' in text or '怎么' in text or '什么' in text:
state['message'] = "问题"
else:
state['message'] = "闲聊"
print(f"[接收] 分类为: {state['message']}")
return state
def respond_greeting(state: GraphState):
"""回应问候"""
state['message'] = "你好呀!有什么可以帮你的吗?"
print(f"[问候回应] {state['message']}")
return state
def answer_question(state: GraphState):
"""回答问题"""
state['message'] = "这是一个好问题!让我想想... (这里可以接入 LLM)"
print(f"[问题回答] {state['message']}")
return state
def chat_casual(state: GraphState):
"""闲聊"""
state['message'] = "哈哈,聊得开心!"
print(f"[闲聊] {state['message']}")
return state
# 3⃣ 条件路由函数 - 这是关键!
def route_input(state: GraphState):
"""根据分类决定下一步走哪个节点"""
category = state['message']
print(f"[路由] 决定走向: {category}")
if category == "问候":
return "respond_greeting"
elif category == "问题":
return "answer_question"
else:
return "chat_casual"
# 4⃣ 构建图
graph = StateGraph(GraphState)
# 添加所有节点
graph.add_node("receive", receive_node)
graph.add_node("respond_greeting", respond_greeting)
graph.add_node("answer_question", answer_question)
graph.add_node("chat_casual", chat_casual)
# 添加边
graph.add_edge(START, "receive")
# 条件边!根据路由函数返回值决定走向
graph.add_conditional_edges(
"receive", # 从接收节点出发
route_input, # 用这个函数做判断
{
"respond_greeting": "respond_greeting",
"answer_question": "answer_question",
"chat_casual": "chat_casual",
}
)
# 所有分支最后都汇聚到 END
graph.add_edge("respond_greeting", END)
graph.add_edge("answer_question", END)
graph.add_edge("chat_casual", END)
# 编译
app = graph.compile()
# 5⃣ 测试三种情况
test_cases = [
"你好!",
"什么是 LangGraph",
"今天天气不错",
]
print("=" * 50)
print("条件边演示 - 图会根据输入走不同路径")
print("=" * 50)
for i, text in enumerate(test_cases, 1):
print(f"\n--- 测试 {i}: {text} ---")
result = app.invoke({"user_input": text, "message": ""})
print(f" 最终输出: {result['message']}")