Skip to content

Commit 8b1e254

Browse files
vbardamdrxy
andauthored
fix(core): add on_tool_error to _AstreamEventsCallbackHandler (#30709)
Fixes #30708 --------- Co-authored-by: Mason Daugherty <[email protected]>
1 parent ced9fc2 commit 8b1e254

File tree

2 files changed

+59
-8
lines changed

2 files changed

+59
-8
lines changed

libs/core/langchain_core/runnables/schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ class EventData(TypedDict, total=False):
2323
won't be known until the *END* of the Runnable when it has finished streaming
2424
its inputs.
2525
"""
26+
error: NotRequired[BaseException]
27+
"""The error that occurred during the execution of the Runnable.
28+
29+
This field is only available if the Runnable raised an exception.
30+
31+
.. versionadded:: 1.0.0
32+
"""
2633
output: Any
2734
"""The output of the Runnable that generated the event.
2835

libs/core/langchain_core/tracers/event_stream.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,28 @@ async def on_chain_end(
610610
run_type,
611611
)
612612

613+
def _get_tool_run_info_with_inputs(self, run_id: UUID) -> tuple[RunInfo, Any]:
614+
"""Get run info for a tool and extract inputs, with validation.
615+
616+
Args:
617+
run_id: The run ID of the tool.
618+
619+
Returns:
620+
A tuple of (run_info, inputs).
621+
622+
Raises:
623+
AssertionError: If the run ID is a tool call and does not have inputs.
624+
"""
625+
run_info = self.run_map.pop(run_id)
626+
if "inputs" not in run_info:
627+
msg = (
628+
f"Run ID {run_id} is a tool call and is expected to have "
629+
f"inputs associated with it."
630+
)
631+
raise AssertionError(msg)
632+
inputs = run_info["inputs"]
633+
return run_info, inputs
634+
613635
@override
614636
async def on_tool_start(
615637
self,
@@ -652,21 +674,43 @@ async def on_tool_start(
652674
"tool",
653675
)
654676

677+
@override
678+
async def on_tool_error(
679+
self,
680+
error: BaseException,
681+
*,
682+
run_id: UUID,
683+
parent_run_id: Optional[UUID] = None,
684+
tags: Optional[list[str]] = None,
685+
**kwargs: Any,
686+
) -> None:
687+
"""Run when tool errors."""
688+
run_info, inputs = self._get_tool_run_info_with_inputs(run_id)
689+
690+
self._send(
691+
{
692+
"event": "on_tool_error",
693+
"data": {
694+
"error": error,
695+
"input": inputs,
696+
},
697+
"run_id": str(run_id),
698+
"name": run_info["name"],
699+
"tags": run_info["tags"],
700+
"metadata": run_info["metadata"],
701+
"parent_ids": self._get_parent_ids(run_id),
702+
},
703+
"tool",
704+
)
705+
655706
@override
656707
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
657708
"""End a trace for a tool run.
658709
659710
Raises:
660711
AssertionError: If the run ID is a tool call and does not have inputs
661712
"""
662-
run_info = self.run_map.pop(run_id)
663-
if "inputs" not in run_info:
664-
msg = (
665-
f"Run ID {run_id} is a tool call and is expected to have "
666-
f"inputs associated with it."
667-
)
668-
raise AssertionError(msg)
669-
inputs = run_info["inputs"]
713+
run_info, inputs = self._get_tool_run_info_with_inputs(run_id)
670714

671715
self._send(
672716
{

0 commit comments

Comments
 (0)