Skip to content

Commit afcdb5f

Browse files
jmaillefaudvbarda
andauthored
support ToolNode as input to create_supervisor (#135)
`create_react_agent` in `langgraph-prebuilt` allows to pass a `ToolNode` for more precise management of the ToolExecutions (for example to raise an error instead of returning a `ToolMessage`). It would be nice if `langgraph-supervisor` could also accept a `ToolNode`. --------- Co-authored-by: vbarda <[email protected]>
1 parent 6367beb commit afcdb5f

File tree

1 file changed

+78
-33
lines changed

1 file changed

+78
-33
lines changed

langgraph_supervisor/supervisor.py

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import inspect
2-
from typing import Any, Callable, Literal, Optional, Type, Union, cast, get_args
2+
from typing import Any, Callable, Literal, Optional, Sequence, Type, Union, cast, get_args
33

44
from langchain_core.language_models import BaseChatModel, LanguageModelLike
55
from langchain_core.runnables import RunnableConfig
66
from langchain_core.tools import BaseTool
77
from langgraph.graph import END, START, StateGraph
8+
from langgraph.prebuilt import ToolNode
89
from langgraph.prebuilt.chat_agent_executor import (
910
AgentState,
1011
Prompt,
1112
StateSchemaType,
1213
StructuredResponseSchema,
14+
_should_bind_tools,
1315
create_react_agent,
1416
)
1517
from langgraph.pregel import Pregel
@@ -93,7 +95,7 @@ async def acall_agent(state: dict, config: RunnableConfig) -> dict:
9395
return RunnableCallable(call_agent, acall_agent)
9496

9597

96-
def _get_handoff_destinations(tools: list[BaseTool | Callable]) -> list[str]:
98+
def _get_handoff_destinations(tools: Sequence[BaseTool | Callable]) -> list[str]:
9799
"""Extract handoff destinations from provided tools.
98100
Args:
99101
tools: List of tools to inspect.
@@ -109,11 +111,70 @@ def _get_handoff_destinations(tools: list[BaseTool | Callable]) -> list[str]:
109111
]
110112

111113

114+
def _prepare_tool_node(
115+
tools: list[BaseTool | Callable] | ToolNode | None,
116+
handoff_tool_prefix: Optional[str],
117+
add_handoff_messages: bool,
118+
agent_names: set[str],
119+
) -> ToolNode:
120+
"""Prepare the ToolNode to use in supervisor agent."""
121+
if isinstance(tools, ToolNode):
122+
input_tool_node = tools
123+
tool_classes = list(tools.tools_by_name.values())
124+
elif tools:
125+
input_tool_node = ToolNode(tools)
126+
# get the tool functions wrapped in a tool class from the ToolNode
127+
tool_classes = list(input_tool_node.tools_by_name.values())
128+
else:
129+
input_tool_node = None
130+
tool_classes = []
131+
132+
handoff_destinations = _get_handoff_destinations(tool_classes)
133+
if handoff_destinations:
134+
if missing_handoff_destinations := set(agent_names) - set(handoff_destinations):
135+
raise ValueError(
136+
"When providing custom handoff tools, you must provide them for all subagents. "
137+
f"Missing handoff tools for agents '{missing_handoff_destinations}'."
138+
)
139+
140+
# Handoff tools should be already provided here
141+
tool_node = cast(ToolNode, input_tool_node)
142+
else:
143+
handoff_tools = [
144+
create_handoff_tool(
145+
agent_name=agent_name,
146+
name=(
147+
None
148+
if handoff_tool_prefix is None
149+
else f"{handoff_tool_prefix}{_normalize_agent_name(agent_name)}"
150+
),
151+
add_handoff_messages=add_handoff_messages,
152+
)
153+
for agent_name in agent_names
154+
]
155+
all_tools = tool_classes + list(handoff_tools)
156+
157+
# re-wrap the combined tools in a ToolNode
158+
# if the original input was a ToolNode, apply the same params
159+
if input_tool_node is not None:
160+
tool_node = ToolNode(
161+
all_tools,
162+
name=input_tool_node.name,
163+
tags=list(input_tool_node.tags) if input_tool_node.tags else None,
164+
handle_tool_errors=input_tool_node.handle_tool_errors,
165+
messages_key=input_tool_node.messages_key,
166+
)
167+
else:
168+
tool_node = ToolNode(all_tools)
169+
170+
return tool_node
171+
172+
112173
def create_supervisor(
113174
agents: list[Pregel],
114175
*,
115176
model: LanguageModelLike,
116-
tools: list[BaseTool | Callable] | None = None,
177+
tools: list[BaseTool | Callable] | ToolNode | None = None,
117178
prompt: Prompt | None = None,
118179
response_format: Optional[
119180
Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]]
@@ -256,45 +317,29 @@ def web_search(query: str) -> str:
256317

257318
agent_names.add(agent.name)
258319

259-
handoff_destinations = _get_handoff_destinations(tools or [])
260-
if handoff_destinations:
261-
if missing_handoff_destinations := set(agent_names) - set(handoff_destinations):
262-
raise ValueError(
263-
"When providing custom handoff tools, you must provide them for all subagents. "
264-
f"Missing handoff tools for agents '{missing_handoff_destinations}'."
265-
)
320+
tool_node = _prepare_tool_node(
321+
tools,
322+
handoff_tool_prefix,
323+
add_handoff_messages,
324+
agent_names,
325+
)
326+
all_tools = list(tool_node.tools_by_name.values())
266327

267-
# Handoff tools should be already provided here
268-
all_tools = tools or []
269-
else:
270-
handoff_tools = [
271-
create_handoff_tool(
272-
agent_name=agent.name,
273-
name=(
274-
None
275-
if handoff_tool_prefix is None
276-
else f"{handoff_tool_prefix}{_normalize_agent_name(agent.name)}"
277-
),
278-
add_handoff_messages=add_handoff_messages,
328+
if _should_bind_tools(model, all_tools):
329+
if _supports_disable_parallel_tool_calls(model):
330+
model = cast(BaseChatModel, model).bind_tools(
331+
all_tools, parallel_tool_calls=parallel_tool_calls
279332
)
280-
for agent in agents
281-
]
282-
all_tools = (tools or []) + list(handoff_tools)
283-
284-
if _supports_disable_parallel_tool_calls(model):
285-
model = cast(BaseChatModel, model).bind_tools(
286-
all_tools, parallel_tool_calls=parallel_tool_calls
287-
)
288-
else:
289-
model = cast(BaseChatModel, model).bind_tools(all_tools)
333+
else:
334+
model = cast(BaseChatModel, model).bind_tools(all_tools)
290335

291336
if include_agent_name:
292337
model = with_agent_name(model, include_agent_name)
293338

294339
supervisor_agent = create_react_agent(
295340
name=supervisor_name,
296341
model=model,
297-
tools=all_tools,
342+
tools=tool_node,
298343
prompt=prompt,
299344
state_schema=state_schema,
300345
response_format=response_format,

0 commit comments

Comments
 (0)