11import inspect
2- from typing import Any , Callable , Literal , Optional , Type , Union , cast , get_args
2+ from typing import Any , Callable , Literal , Optional , Sequence , Type , Union , cast , get_args
33
44from langchain_core .language_models import BaseChatModel , LanguageModelLike
55from langchain_core .runnables import RunnableConfig
66from langchain_core .tools import BaseTool
77from langgraph .graph import END , START , StateGraph
8+ from langgraph .prebuilt import ToolNode
89from langgraph .prebuilt .chat_agent_executor import (
910 AgentState ,
1011 Prompt ,
1112 StateSchemaType ,
1213 StructuredResponseSchema ,
14+ _should_bind_tools ,
1315 create_react_agent ,
1416)
1517from langgraph .pregel import Pregel
@@ -93,7 +95,7 @@ async def acall_agent(state: dict, config: RunnableConfig) -> dict:
9395 return RunnableCallable (call_agent , acall_agent )
9496
9597
96- def _get_handoff_destinations (tools : list [BaseTool | Callable ]) -> list [str ]:
98+ def _get_handoff_destinations (tools : Sequence [BaseTool | Callable ]) -> list [str ]:
9799 """Extract handoff destinations from provided tools.
98100 Args:
99101 tools: List of tools to inspect.
@@ -109,11 +111,70 @@ def _get_handoff_destinations(tools: list[BaseTool | Callable]) -> list[str]:
109111 ]
110112
111113
114+ def _prepare_tool_node (
115+ tools : list [BaseTool | Callable ] | ToolNode | None ,
116+ handoff_tool_prefix : Optional [str ],
117+ add_handoff_messages : bool ,
118+ agent_names : set [str ],
119+ ) -> ToolNode :
120+ """Prepare the ToolNode to use in supervisor agent."""
121+ if isinstance (tools , ToolNode ):
122+ input_tool_node = tools
123+ tool_classes = list (tools .tools_by_name .values ())
124+ elif tools :
125+ input_tool_node = ToolNode (tools )
126+ # get the tool functions wrapped in a tool class from the ToolNode
127+ tool_classes = list (input_tool_node .tools_by_name .values ())
128+ else :
129+ input_tool_node = None
130+ tool_classes = []
131+
132+ handoff_destinations = _get_handoff_destinations (tool_classes )
133+ if handoff_destinations :
134+ if missing_handoff_destinations := set (agent_names ) - set (handoff_destinations ):
135+ raise ValueError (
136+ "When providing custom handoff tools, you must provide them for all subagents. "
137+ f"Missing handoff tools for agents '{ missing_handoff_destinations } '."
138+ )
139+
140+ # Handoff tools should be already provided here
141+ tool_node = cast (ToolNode , input_tool_node )
142+ else :
143+ handoff_tools = [
144+ create_handoff_tool (
145+ agent_name = agent_name ,
146+ name = (
147+ None
148+ if handoff_tool_prefix is None
149+ else f"{ handoff_tool_prefix } { _normalize_agent_name (agent_name )} "
150+ ),
151+ add_handoff_messages = add_handoff_messages ,
152+ )
153+ for agent_name in agent_names
154+ ]
155+ all_tools = tool_classes + list (handoff_tools )
156+
157+ # re-wrap the combined tools in a ToolNode
158+ # if the original input was a ToolNode, apply the same params
159+ if input_tool_node is not None :
160+ tool_node = ToolNode (
161+ all_tools ,
162+ name = input_tool_node .name ,
163+ tags = list (input_tool_node .tags ) if input_tool_node .tags else None ,
164+ handle_tool_errors = input_tool_node .handle_tool_errors ,
165+ messages_key = input_tool_node .messages_key ,
166+ )
167+ else :
168+ tool_node = ToolNode (all_tools )
169+
170+ return tool_node
171+
172+
112173def create_supervisor (
113174 agents : list [Pregel ],
114175 * ,
115176 model : LanguageModelLike ,
116- tools : list [BaseTool | Callable ] | None = None ,
177+ tools : list [BaseTool | Callable ] | ToolNode | None = None ,
117178 prompt : Prompt | None = None ,
118179 response_format : Optional [
119180 Union [StructuredResponseSchema , tuple [str , StructuredResponseSchema ]]
@@ -256,45 +317,29 @@ def web_search(query: str) -> str:
256317
257318 agent_names .add (agent .name )
258319
259- handoff_destinations = _get_handoff_destinations ( tools or [])
260- if handoff_destinations :
261- if missing_handoff_destinations := set ( agent_names ) - set ( handoff_destinations ):
262- raise ValueError (
263- "When providing custom handoff tools, you must provide them for all subagents. "
264- f"Missing handoff tools for agents ' { missing_handoff_destinations } '."
265- )
320+ tool_node = _prepare_tool_node (
321+ tools ,
322+ handoff_tool_prefix ,
323+ add_handoff_messages ,
324+ agent_names ,
325+ )
326+ all_tools = list ( tool_node . tools_by_name . values () )
266327
267- # Handoff tools should be already provided here
268- all_tools = tools or []
269- else :
270- handoff_tools = [
271- create_handoff_tool (
272- agent_name = agent .name ,
273- name = (
274- None
275- if handoff_tool_prefix is None
276- else f"{ handoff_tool_prefix } { _normalize_agent_name (agent .name )} "
277- ),
278- add_handoff_messages = add_handoff_messages ,
328+ if _should_bind_tools (model , all_tools ):
329+ if _supports_disable_parallel_tool_calls (model ):
330+ model = cast (BaseChatModel , model ).bind_tools (
331+ all_tools , parallel_tool_calls = parallel_tool_calls
279332 )
280- for agent in agents
281- ]
282- all_tools = (tools or []) + list (handoff_tools )
283-
284- if _supports_disable_parallel_tool_calls (model ):
285- model = cast (BaseChatModel , model ).bind_tools (
286- all_tools , parallel_tool_calls = parallel_tool_calls
287- )
288- else :
289- model = cast (BaseChatModel , model ).bind_tools (all_tools )
333+ else :
334+ model = cast (BaseChatModel , model ).bind_tools (all_tools )
290335
291336 if include_agent_name :
292337 model = with_agent_name (model , include_agent_name )
293338
294339 supervisor_agent = create_react_agent (
295340 name = supervisor_name ,
296341 model = model ,
297- tools = all_tools ,
342+ tools = tool_node ,
298343 prompt = prompt ,
299344 state_schema = state_schema ,
300345 response_format = response_format ,
0 commit comments