|
 from unittest.mock import MagicMock
 
 import pytest
-from langchain_core.messages import HumanMessage
 from pytest import MonkeyPatch
 
+from langchain_core.messages import AIMessage, HumanMessage
 from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
 
 
|
@@ -575,6 +575,165 @@ def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
     assert response["parsed"].conditions == "Sunny"
 
 
+@pytest.mark.requires("oci")
+def test_ai_message_tool_calls_direct_field(monkeypatch: MonkeyPatch) -> None:
+    """Test AIMessage with tool_calls in the direct tool_calls field."""
+
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
+
+    # Set by the mock when the outgoing request carries tool_calls, so the
+    # test can assert that the conversion branch actually ran.
+    tool_calls_processed = False
+
+    def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
+        nonlocal tool_calls_processed
+        # Inspect the outbound chat request for tool_calls on any message.
+        request = args[0]
+        chat_request = getattr(request, "chat_request", None)
+        for msg in getattr(chat_request, "messages", None) or []:
+            if getattr(msg, "tool_calls", None):
+                tool_calls_processed = True
+                break
+        return MockResponseDict(
+            {
+                "status": 200,
+                "data": MockResponseDict(
+                    {
+                        "chat_response": MockResponseDict(
+                            {
+                                "api_format": "GENERIC",
+                                "choices": [
+                                    MockResponseDict(
+                                        {
+                                            "message": MockResponseDict(
+                                                {
+                                                    "role": "ASSISTANT",
+                                                    "name": None,
+                                                    "content": [
+                                                        MockResponseDict(
+                                                            {
+                                                                "text": (
+                                                                    "I'll help you."
+                                                                ),
+                                                                "type": "TEXT",
+                                                            }
+                                                        )
+                                                    ],
+                                                    "tool_calls": [],
+                                                }
+                                            ),
+                                            "finish_reason": "completed",
+                                        }
+                                    )
+                                ],
+                                "time_created": "2025-08-14T10:00:01.100000+00:00",
+                            }
+                        ),
+                        "model_id": "meta.llama-3.3-70b-instruct",
+                        "model_version": "1.0.0",
+                    }
+                ),
+                "request_id": "1234567890",
+                "headers": MockResponseDict({"content-length": "123"}),
+            }
+        )
+
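+    # Route llm.client.chat to the local mock so no real OCI call is made.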
+    monkeypatch.setattr(llm.client, "chat", mocked_response)
+
+    # Create AIMessage with tool_calls in the direct tool_calls field
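+    # (the {id, name, args} dicts below follow LangChain's ToolCall shape)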
+    ai_message = AIMessage(
+        content="I need to call a function",
+        tool_calls=[
+            {
+                "id": "call_123",
+                "name": "get_weather",
+                "args": {"location": "San Francisco"},
+            }
+        ],
+    )
+
+    messages = [ai_message]
+
+    # This should not raise an error and should process the tool_calls correctly
+    response = llm.invoke(messages)
+    assert response.content == "I'll help you."
+    # The mock flips this flag only if tool_calls made it onto the request.
+    assert tool_calls_processed
+
+
+@pytest.mark.requires("oci")
+def test_ai_message_tool_calls_additional_kwargs(monkeypatch: MonkeyPatch) -> None:
+    """Test AIMessage with tool_calls in the additional_kwargs field."""
+
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
+
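+    # Same canned assistant response as in the direct-field test above.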
+    def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
+        return MockResponseDict(
+            {
+                "status": 200,
+                "data": MockResponseDict(
+                    {
+                        "chat_response": MockResponseDict(
+                            {
+                                "api_format": "GENERIC",
+                                "choices": [
+                                    MockResponseDict(
+                                        {
+                                            "message": MockResponseDict(
+                                                {
+                                                    "role": "ASSISTANT",
+                                                    "name": None,
+                                                    "content": [
+                                                        MockResponseDict(
+                                                            {
+                                                                "text": (
+                                                                    "I'll help you."
+                                                                ),
+                                                                "type": "TEXT",
+                                                            }
+                                                        )
+                                                    ],
+                                                    "tool_calls": [],
+                                                }
+                                            ),
+                                            "finish_reason": "completed",
+                                        }
+                                    )
+                                ],
+                                "time_created": "2025-08-14T10:00:01.100000+00:00",
+                            }
+                        ),
+                        "model_id": "meta.llama-3.3-70b-instruct",
+                        "model_version": "1.0.0",
+                    }
+                ),
+                "request_id": "1234567890",
+                "headers": MockResponseDict({"content-length": "123"}),
+            }
+        )
+
+    monkeypatch.setattr(llm.client, "chat", mocked_response)
+
+    # Create AIMessage with tool_calls in additional_kwargs
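+    # Note: additional_kwargs is a free-form dict; this test reuses the same
+    # {id, name, args} shape rather than an OpenAI-style function payload.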
+    ai_message = AIMessage(
+        content="I need to call a function",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "id": "call_456",
+                    "name": "get_weather",
+                    "args": {"location": "New York"},
+                }
+            ]
+        },
+    )
+
+    messages = [ai_message]
+
+    # This should not raise an error and should process the tool_calls correctly
+    response = llm.invoke(messages)
+    assert response.content == "I'll help you."
+
+
 def test_get_provider():
     """Test determining the provider based on the model_id."""
     model_provider_map = {
|