Skip to content

Commit 5ce1280

Browse files
authored
Fix for OCI Generative AI tool_call_id not found because AIMessage.tool_calls are ignored (#43)
* Fix OCI tool calling with GenericProvider * Update unit test to verify tool call
1 parent 74e8754 commit 5ce1280

File tree

2 files changed

+163
-4
lines changed

2 files changed

+163
-4
lines changed

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -712,9 +712,9 @@ def messages_to_oci_params(
712712
)
713713
else:
714714
oci_message = self.oci_chat_message[role](content=tool_content)
715-
elif isinstance(message, AIMessage) and message.additional_kwargs.get(
716-
"tool_calls"
717-
):
715+
elif isinstance(message, AIMessage) and (
716+
message.tool_calls or
717+
message.additional_kwargs.get("tool_calls")):
718718
# Process content and tool calls for assistant messages
719719
content = self._process_message_content(message.content)
720720
tool_calls = []

libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from unittest.mock import MagicMock
77

88
import pytest
9-
from langchain_core.messages import HumanMessage
109
from pytest import MonkeyPatch
1110

11+
from langchain_core.messages import HumanMessage, AIMessage
1212
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
1313

1414

@@ -575,6 +575,165 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
575575
assert response["parsed"].conditions == "Sunny"
576576

577577

578+
@pytest.mark.requires("oci")
579+
def test_ai_message_tool_calls_direct_field(monkeypatch: MonkeyPatch) -> None:
580+
"""Test AIMessage with tool_calls in the direct tool_calls field."""
581+
582+
oci_gen_ai_client = MagicMock()
583+
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
584+
585+
# Track if the tool_calls processing branch is executed
586+
tool_calls_processed = False
587+
588+
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
589+
nonlocal tool_calls_processed
590+
# Check if the request contains tool_calls in the message
591+
request = args[0]
592+
if hasattr(request, 'chat_request') and hasattr(request.chat_request, 'messages'):
593+
for msg in request.chat_request.messages:
594+
if hasattr(msg, 'tool_calls') and msg.tool_calls:
595+
tool_calls_processed = True
596+
break
597+
return MockResponseDict(
598+
{
599+
"status": 200,
600+
"data": MockResponseDict(
601+
{
602+
"chat_response": MockResponseDict(
603+
{
604+
"api_format": "GENERIC",
605+
"choices": [
606+
MockResponseDict(
607+
{
608+
"message": MockResponseDict(
609+
{
610+
"role": "ASSISTANT",
611+
"name": None,
612+
"content": [
613+
MockResponseDict(
614+
{
615+
"text": (
616+
"I'll help you."
617+
),
618+
"type": "TEXT",
619+
}
620+
)
621+
],
622+
"tool_calls": [],
623+
}
624+
),
625+
"finish_reason": "completed",
626+
}
627+
)
628+
],
629+
"time_created": "2025-08-14T10:00:01.100000+00:00",
630+
}
631+
),
632+
"model_id": "meta.llama-3.3-70b-instruct",
633+
"model_version": "1.0.0",
634+
}
635+
),
636+
"request_id": "1234567890",
637+
"headers": MockResponseDict({"content-length": "123"}),
638+
}
639+
)
640+
641+
monkeypatch.setattr(llm.client, "chat", mocked_response)
642+
643+
# Create AIMessage with tool_calls in the direct tool_calls field
644+
ai_message = AIMessage(
645+
content="I need to call a function",
646+
tool_calls=[
647+
{
648+
"id": "call_123",
649+
"name": "get_weather",
650+
"args": {"location": "San Francisco"},
651+
}
652+
]
653+
)
654+
655+
messages = [ai_message]
656+
657+
# This should not raise an error and should process the tool_calls correctly
658+
response = llm.invoke(messages)
659+
assert response.content == "I'll help you."
660+
661+
662+
@pytest.mark.requires("oci")
663+
def test_ai_message_tool_calls_additional_kwargs(monkeypatch: MonkeyPatch) -> None:
664+
"""Test AIMessage with tool_calls in additional_kwargs field."""
665+
666+
oci_gen_ai_client = MagicMock()
667+
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
668+
669+
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
670+
return MockResponseDict(
671+
{
672+
"status": 200,
673+
"data": MockResponseDict(
674+
{
675+
"chat_response": MockResponseDict(
676+
{
677+
"api_format": "GENERIC",
678+
"choices": [
679+
MockResponseDict(
680+
{
681+
"message": MockResponseDict(
682+
{
683+
"role": "ASSISTANT",
684+
"name": None,
685+
"content": [
686+
MockResponseDict(
687+
{
688+
"text": (
689+
"I'll help you."
690+
),
691+
"type": "TEXT",
692+
}
693+
)
694+
],
695+
"tool_calls": [],
696+
}
697+
),
698+
"finish_reason": "completed",
699+
}
700+
)
701+
],
702+
"time_created": "2025-08-14T10:00:01.100000+00:00",
703+
}
704+
),
705+
"model_id": "meta.llama-3.3-70b-instruct",
706+
"model_version": "1.0.0",
707+
}
708+
),
709+
"request_id": "1234567890",
710+
"headers": MockResponseDict({"content-length": "123"}),
711+
}
712+
)
713+
714+
monkeypatch.setattr(llm.client, "chat", mocked_response)
715+
716+
# Create AIMessage with tool_calls in additional_kwargs
717+
ai_message = AIMessage(
718+
content="I need to call a function",
719+
additional_kwargs={
720+
"tool_calls": [
721+
{
722+
"id": "call_456",
723+
"name": "get_weather",
724+
"args": {"location": "New York"},
725+
}
726+
]
727+
}
728+
)
729+
730+
messages = [ai_message]
731+
732+
# This should not raise an error and should process the tool_calls correctly
733+
response = llm.invoke(messages)
734+
assert response.content == "I'll help you."
735+
736+
578737
def test_get_provider():
579738
"""Test determining the provider based on the model_id."""
580739
model_provider_map = {

0 commit comments

Comments
 (0)