diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 33c20ed2e68c6..2b86c36f151bc 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -428,9 +428,35 @@ def invoke( config: Optional[RunnableConfig] = None, *, stop: Optional[list[str]] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> AIMessage: + """Invoke the model. + + Args: + input: The model input. See ``LanguageModelInput`` for valid options. + config: The ``RunnableConfig`` to use for this model run. + stop: Stop word(s) to use during generation. + output_version: Override the model's ``output_version`` for this invocation. + If None, uses the called model's configured ``output_version``. + **kwargs: Additional keyword arguments. + + Returns: + The model's response message. + + """ config = ensure_config(config) + + effective_output_version = ( + output_version if output_version is not None else self.output_version + ) + kwargs["_output_version"] = effective_output_version or "v0" + + # Whether the user explicitly set an output_version for either model or call + kwargs["_output_version_explicit"] = ( + output_version is not None or self.output_version is not None + ) + return cast( "AIMessage", cast( @@ -455,9 +481,35 @@ async def ainvoke( config: Optional[RunnableConfig] = None, *, stop: Optional[list[str]] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> AIMessage: + """Asynchronously invoke the model. + + Args: + input: The model input. See ``LanguageModelInput`` for valid options. + config: The ``RunnableConfig`` to use for this model run. + stop: Stop word(s) to use during generation. + output_version: Override the model's ``output_version`` for this invocation. + If None, uses the called model's configured ``output_version``. + **kwargs: Additional keyword arguments. + + Returns: + The model's response message. + + """ config = ensure_config(config) + + effective_output_version = ( + output_version if output_version is not None else self.output_version + ) + kwargs["_output_version"] = effective_output_version or "v0" + + # Whether the user explicitly set an output_version for either model or call + kwargs["_output_version_explicit"] = ( + output_version is not None or self.output_version is not None + ) + llm_result = await self.agenerate_prompt( [self._convert_input(input)], stop=stop, @@ -514,13 +566,44 @@ def stream( config: Optional[RunnableConfig] = None, *, stop: Optional[list[str]] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> Iterator[AIMessageChunk]: + """Stream responses from the chat model. + + Args: + input: The model input. See ``LanguageModelInput`` for valid options. + config: The ``RunnableConfig`` to use for this model run. + stop: Stop word(s) to use during generation. + output_version: Override the model's ``output_version`` for this invocation. + If None, uses the called model's configured ``output_version``. + **kwargs: Additional keyword arguments. + + Returns: + Iterator of message chunks. 
+ + """ + effective_output_version = ( + output_version if output_version is not None else self.output_version + ) + kwargs["_output_version"] = effective_output_version or "v0" + + # Whether the user explicitly set an output_version for either model or call + kwargs["_output_version_explicit"] = ( + output_version is not None or self.output_version is not None + ) + if not self._should_stream(async_api=False, **{**kwargs, "stream": True}): # model doesn't implement streaming, so use default implementation yield cast( "AIMessageChunk", - self.invoke(input, config=config, stop=stop, **kwargs), + self.invoke( + input, + config=config, + stop=stop, + output_version=effective_output_version, + **kwargs, + ), ) else: config = ensure_config(config) @@ -566,11 +649,27 @@ def stream( input_messages = _normalize_messages(messages) run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id))) yielded = False - for chunk in self._stream(input_messages, stop=stop, **kwargs): + + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ("_output_version", "_output_version_explicit") + } + for chunk in self._stream( + input_messages, + stop=stop, + output_version=kwargs["_output_version"], + **filtered_kwargs, + ): if chunk.message.id is None: chunk.message.id = run_id - chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) - if self.output_version == "v1": + response_metadata = _gen_info_and_msg_metadata(chunk) + output_version = kwargs["_output_version"] + # Add output_version to response_metadata only if was explicitly set + if kwargs.get("_output_version_explicit", False): + response_metadata["output_version"] = output_version + chunk.message.response_metadata = response_metadata + if output_version == "v1": # Overwrite .content with .content_blocks chunk.message = _update_message_content_to_blocks( chunk.message, "v1" @@ -630,13 +729,44 @@ async def astream( config: Optional[RunnableConfig] = None, *, stop: Optional[list[str]] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> AsyncIterator[AIMessageChunk]: + """Asynchronously stream responses from the model. + + Args: + input: The model input. See ``LanguageModelInput`` for valid options. + config: The ``RunnableConfig`` to use for this model run. + stop: Stop word(s) to use during generation. + output_version: Override the model's ``output_version`` for this invocation. + If None, uses the called model's configured ``output_version``. + **kwargs: Additional keyword arguments. + + Returns: + Async Iterator of message chunks. 
+ + """ + effective_output_version = ( + output_version if output_version is not None else self.output_version + ) + kwargs["_output_version"] = effective_output_version or "v0" + + # Whether the user explicitly set an output_version for either model or call + kwargs["_output_version_explicit"] = ( + output_version is not None or self.output_version is not None + ) + if not self._should_stream(async_api=True, **{**kwargs, "stream": True}): # No async or sync stream is implemented, so fall back to ainvoke yield cast( "AIMessageChunk", - await self.ainvoke(input, config=config, stop=stop, **kwargs), + await self.ainvoke( + input, + config=config, + stop=stop, + output_version=effective_output_version, + **kwargs, + ), ) return @@ -684,15 +814,27 @@ async def astream( input_messages = _normalize_messages(messages) run_id = "-".join((LC_ID_PREFIX, str(run_manager.run_id))) yielded = False + + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ("_output_version", "_output_version_explicit") + } async for chunk in self._astream( input_messages, stop=stop, - **kwargs, + output_version=kwargs["_output_version"], + **filtered_kwargs, ): if chunk.message.id is None: chunk.message.id = run_id - chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) - if self.output_version == "v1": + response_metadata = _gen_info_and_msg_metadata(chunk) + output_version = kwargs["_output_version"] + # Add output_version to response_metadata only if was explicitly set + if kwargs.get("_output_version_explicit", False): + response_metadata["output_version"] = output_version + chunk.message.response_metadata = response_metadata + if output_version == "v1": # Overwrite .content with .content_blocks chunk.message = _update_message_content_to_blocks( chunk.message, "v1" @@ -724,7 +866,10 @@ async def astream( generations_with_error_metadata = _generate_response_from_error(e) chat_generation_chunk = merge_chat_generation_chunks(chunks) if chat_generation_chunk: - generations = [[chat_generation_chunk], generations_with_error_metadata] + generations = [ + [chat_generation_chunk], + generations_with_error_metadata, + ] else: generations = [generations_with_error_metadata] await run_manager.on_llm_error( @@ -1164,6 +1309,9 @@ def _generate_with_cache( if self.rate_limiter: self.rate_limiter.acquire(blocking=True) + output_version = kwargs.pop("_output_version", self.output_version) + output_version_explicit = kwargs.pop("_output_version_explicit", False) + # If stream is not explicitly set, check if implicitly requested by # astream_events() or astream_log(). 
Bail out if _stream not implemented if self._should_stream( @@ -1176,16 +1324,28 @@ def _generate_with_cache( f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None ) yielded = False - for chunk in self._stream(messages, stop=stop, **kwargs): - chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) - if self.output_version == "v1": - # Overwrite .content with .content_blocks - chunk.message = _update_message_content_to_blocks( - chunk.message, "v1" - ) + + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ("_output_version", "_output_version_explicit") + } + for chunk in self._stream( + messages, stop=stop, output_version=output_version, **filtered_kwargs + ): + response_metadata = _gen_info_and_msg_metadata(chunk) + # Add output_version to response_metadata only if it was explicitly set + if output_version_explicit: + response_metadata["output_version"] = output_version + chunk.message.response_metadata = response_metadata if run_manager: if chunk.message.id is None: chunk.message.id = run_id + if output_version == "v1": + # Overwrite .content with .content_blocks + chunk.message = _update_message_content_to_blocks( + chunk.message, "v1" + ) run_manager.on_llm_new_token( cast("str", chunk.message.content), chunk=chunk ) @@ -1210,12 +1370,32 @@ def _generate_with_cache( run_manager.on_llm_new_token("", chunk=chunk) chunks.append(chunk) result = generate_from_stream(iter(chunks)) - elif inspect.signature(self._generate).parameters.get("run_manager"): - result = self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) else: - result = self._generate(messages, stop=stop, **kwargs) + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ("_output_version", "_output_version_explicit") + } + if inspect.signature(self._generate).parameters.get("run_manager"): + result = self._generate( + messages, + stop=stop, + run_manager=run_manager, + **filtered_kwargs, + ) + else: + result = self._generate( + messages, + stop=stop, + **filtered_kwargs, + ) + + if output_version == "v1": + # Overwrite .content with .content_blocks + for generation in result.generations: + generation.message = _update_message_content_to_blocks( + generation.message, "v1" + ) if self.output_version == "v1": # Overwrite .content with .content_blocks @@ -1228,9 +1408,11 @@ def _generate_with_cache( for idx, generation in enumerate(result.generations): if run_manager and generation.message.id is None: generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}" - generation.message.response_metadata = _gen_info_and_msg_metadata( - generation - ) + response_metadata = _gen_info_and_msg_metadata(generation) + # Add output_version to response_metadata only if it was explicitly set + if output_version_explicit: + response_metadata["output_version"] = output_version + generation.message.response_metadata = response_metadata if len(result.generations) == 1 and result.llm_output is not None: result.generations[0].message.response_metadata = { **result.llm_output, @@ -1272,6 +1454,9 @@ async def _agenerate_with_cache( if self.rate_limiter: await self.rate_limiter.aacquire(blocking=True) + output_version = kwargs.pop("_output_version", self.output_version) + output_version_explicit = kwargs.pop("_output_version_explicit", False) + # If stream is not explicitly set, check if implicitly requested by # astream_events() or astream_log(). 
Bail out if _astream not implemented if self._should_stream( @@ -1284,16 +1469,28 @@ async def _agenerate_with_cache( f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None ) yielded = False - async for chunk in self._astream(messages, stop=stop, **kwargs): - chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) - if self.output_version == "v1": - # Overwrite .content with .content_blocks - chunk.message = _update_message_content_to_blocks( - chunk.message, "v1" - ) + + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ("_output_version", "_output_version_explicit") + } + async for chunk in self._astream( + messages, stop=stop, output_version=output_version, **filtered_kwargs + ): + response_metadata = _gen_info_and_msg_metadata(chunk) + # Add output_version to response_metadata only if it was explicitly set + if output_version_explicit: + response_metadata["output_version"] = output_version + chunk.message.response_metadata = response_metadata if run_manager: if chunk.message.id is None: chunk.message.id = run_id + if output_version == "v1": + # Overwrite .content with .content_blocks + chunk.message = _update_message_content_to_blocks( + chunk.message, "v1" + ) await run_manager.on_llm_new_token( cast("str", chunk.message.content), chunk=chunk ) @@ -1319,11 +1516,32 @@ async def _agenerate_with_cache( chunks.append(chunk) result = generate_from_stream(iter(chunks)) elif inspect.signature(self._agenerate).parameters.get("run_manager"): + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ("_output_version", "_output_version_explicit") + } result = await self._agenerate( - messages, stop=stop, run_manager=run_manager, **kwargs + messages, + stop=stop, + run_manager=run_manager, + **filtered_kwargs, ) else: - result = await self._agenerate(messages, stop=stop, **kwargs) + # Filter out internal parameters before passing to implementation + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ("_output_version", "_output_version_explicit") + } + result = await self._agenerate(messages, stop=stop, **filtered_kwargs) + + if output_version == "v1": + # Overwrite .content with .content_blocks + for generation in result.generations: + generation.message = _update_message_content_to_blocks( + generation.message, "v1" + ) if self.output_version == "v1": # Overwrite .content with .content_blocks @@ -1336,9 +1554,11 @@ async def _agenerate_with_cache( for idx, generation in enumerate(result.generations): if run_manager and generation.message.id is None: generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}" - generation.message.response_metadata = _gen_info_and_msg_metadata( - generation - ) + response_metadata = _gen_info_and_msg_metadata(generation) + # Add output_version to response_metadata only if it was explicitly set + if output_version_explicit: + response_metadata["output_version"] = output_version + generation.message.response_metadata = response_metadata if len(result.generations) == 1 and result.llm_output is not None: result.generations[0].message.response_metadata = { **result.llm_output, @@ -1354,6 +1574,8 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: """Generate the result. @@ -1367,12 +1589,15 @@ def _generate( Returns: The chat result. 
""" + # Concrete implementations should override this method and use the same params async def _agenerate( self, messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: """Generate the result. @@ -1392,6 +1617,7 @@ async def _agenerate( messages, stop, run_manager.get_sync() if run_manager else None, + output_version=output_version, **kwargs, ) @@ -1400,6 +1626,8 @@ def _stream( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Stream the output of the model. @@ -1420,6 +1648,8 @@ async def _astream( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: """Stream the output of the model. @@ -1439,6 +1669,7 @@ async def _astream( messages, stop, run_manager.get_sync() if run_manager else None, + output_version=output_version, **kwargs, ) done = object() @@ -1808,6 +2039,9 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + # For backward compatibility + output_version: str = "v0", # noqa: ARG002 **kwargs: Any, ) -> ChatResult: output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) @@ -1830,6 +2064,8 @@ async def _agenerate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: return await run_in_executor( @@ -1838,6 +2074,7 @@ async def _agenerate( messages, stop=stop, run_manager=run_manager.get_sync() if run_manager else None, + output_version=output_version, **kwargs, ) diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 89aa8fb3b5433..1bdbb3e4ae32e 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -12,6 +12,9 @@ AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models._utils import ( + _update_message_content_to_blocks, +) from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult @@ -248,10 +251,32 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: message = next(self.messages) message_ = AIMessage(content=message) if isinstance(message, str) else message + + if output_version == "v1": + message_ = _update_message_content_to_blocks(message_, "v1") + + # Only set in response metadata if output_version is explicitly provided + # (If output_version is "v0" and self.output_version is None, it's the default) + output_version_explicit = not ( + output_version == "v0" and getattr(self, "output_version", None) is None + ) + if output_version_explicit: + if hasattr(message_, "response_metadata"): + message_.response_metadata = {"output_version": 
output_version} + else: + message_ = AIMessage( + content=message_.content, + additional_kwargs=message_.additional_kwargs, + response_metadata={"output_version": output_version}, + id=message_.id, + ) + generation = ChatGeneration(message=message_) return ChatResult(generations=[generation]) @@ -260,10 +285,16 @@ def _stream( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: chat_result = self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs + messages, + stop=stop, + run_manager=run_manager, + output_version="v0", # Always call with v0 to get original string content + **kwargs, ) if not isinstance(chat_result, ChatResult): msg = ( @@ -302,6 +333,21 @@ def _stream( and not message.additional_kwargs ): chunk.message.chunk_position = "last" + + if output_version == "v1": + chunk.message = _update_message_content_to_blocks( + chunk.message, "v1" + ) + + output_version_explicit = not ( + output_version == "v0" + and getattr(self, "output_version", None) is None + ) + if output_version_explicit: + chunk.message.response_metadata = {"output_version": output_version} + else: + chunk.message.response_metadata = {} + if run_manager: run_manager.on_llm_new_token(token, chunk=chunk) yield chunk @@ -321,7 +367,7 @@ def _stream( id=message.id, content="", additional_kwargs={ - "function_call": {fkey: fvalue_chunk} + "function_call": {fkey: fvalue_chunk}, }, ) ) @@ -336,7 +382,9 @@ def _stream( message=AIMessageChunk( id=message.id, content="", - additional_kwargs={"function_call": {fkey: fvalue}}, + additional_kwargs={ + "function_call": {fkey: fvalue}, + }, ) ) if run_manager: @@ -348,7 +396,9 @@ def _stream( else: chunk = ChatGenerationChunk( message=AIMessageChunk( - id=message.id, content="", additional_kwargs={key: value} + id=message.id, + content="", + additional_kwargs={key: value}, ) ) if run_manager: @@ -358,6 +408,14 @@ def _stream( ) yield chunk + # Add a final chunk with chunk_position="last" after all additional_kwargs + final_chunk = ChatGenerationChunk( + message=AIMessageChunk(id=message.id, content="", chunk_position="last") + ) + if run_manager: + run_manager.on_llm_new_token("", chunk=final_chunk) + yield final_chunk + @property def _llm_type(self) -> str: return "generic-fake-chat-model" diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index 7b8465e8d0256..93b99d24b0992 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -44,6 +44,16 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any] ) raise TypeError(msg) elif isinstance(merged[right_k], str): + if right_k == "output_version": + if merged[right_k] == right_v: + continue + msg = ( + "Unable to merge. Two different values seen for " + f"'output_version': {merged[right_k]} and {right_v}. " + "'output_version' should have the same value across " + "all chunks in a generation." + ) + raise ValueError(msg) # TODO: Add below special handling for 'type' key in 0.3 and remove # merge_lists 'type' logic. # @@ -58,8 +68,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any] # "all dicts." 
# ) if (right_k == "index" and merged[right_k].startswith("lc_")) or ( - right_k in ("id", "output_version", "model_provider") - and merged[right_k] == right_v + right_k in ("id", "model_provider") and merged[right_k] == right_v ): continue merged[right_k] += right_v diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 8ef536f726011..ac300d5f5a9b9 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -7,7 +7,10 @@ import pytest from typing_extensions import override -from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain_core.language_models import ( BaseChatModel, FakeListChatModel, @@ -185,6 +188,8 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -218,6 +223,8 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -229,6 +236,8 @@ def _stream( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Stream the output of the model.""" @@ -244,19 +253,21 @@ def _llm_type(self) -> str: model = ModelWithSyncStream() chunks = list(model.stream("anything")) assert chunks == [ + _any_id_ai_message_chunk(content="a"), _any_id_ai_message_chunk( - content="a", + content="b", + chunk_position="last", ), - _any_id_ai_message_chunk(content="b", chunk_position="last"), ] assert len({chunk.id for chunk in chunks}) == 1 assert type(model)._astream == BaseChatModel._astream astream_chunks = [chunk async for chunk in model.astream("anything")] assert astream_chunks == [ + _any_id_ai_message_chunk(content="a"), _any_id_ai_message_chunk( - content="a", + content="b", + chunk_position="last", ), - _any_id_ai_message_chunk(content="b", chunk_position="last"), ] assert len({chunk.id for chunk in astream_chunks}) == 1 @@ -270,6 +281,8 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -280,7 +293,9 @@ async def _astream( self, messages: list[BaseMessage], stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, # type: ignore[override] + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + output_version: Optional[str] = "v0", **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: """Stream the output of the model.""" @@ -296,10 +311,11 @@ def _llm_type(self) -> str: model = ModelWithAsyncStream() chunks = [chunk async for chunk in model.astream("anything")] assert chunks == [ + _any_id_ai_message_chunk(content="a"), _any_id_ai_message_chunk( - content="a", + content="b", + chunk_position="last", ), - _any_id_ai_message_chunk(content="b", chunk_position="last"), ] assert len({chunk.id for chunk in chunks}) == 1 @@ -351,6 +367,8 @@ def _generate( messages: 
list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> ChatResult: return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))]) @@ -367,6 +385,8 @@ def _stream( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: yield ChatGenerationChunk(message=AIMessageChunk(content="stream")) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_output_version.py b/libs/core/tests/unit_tests/language_models/chat_models/test_output_version.py new file mode 100644 index 0000000000000..6d234600e7202 --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_output_version.py @@ -0,0 +1,366 @@ +"""Test output_version functionality in BaseChatModel.""" + +from collections.abc import AsyncIterator, Iterator +from typing import Any, Optional +from unittest.mock import patch + +import pytest +from pydantic import ConfigDict +from typing_extensions import override + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult + + +class OutputVersionTrackingChatModel(GenericFakeChatModel): + """Chat model that tracks output_version parameter for testing.""" + + model_config = ConfigDict(extra="allow") + last_output_version: Optional[str] = None + + @override + def _generate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", + **kwargs: Any, + ) -> ChatResult: + """Store the output_version that was passed.""" + self.last_output_version = output_version + message = AIMessage(content="test response") + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + @override + def _stream( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Store the output_version that was passed.""" + self.last_output_version = output_version + yield ChatGenerationChunk(message=AIMessageChunk(content="test")) + yield ChatGenerationChunk(message=AIMessageChunk(content=" stream")) + + @override + async def _astream( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, # type: ignore[override] + *, + output_version: str = "v0", + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + """Store the output_version that was passed.""" + self.last_output_version = output_version + yield ChatGenerationChunk(message=AIMessageChunk(content="async")) + yield ChatGenerationChunk(message=AIMessageChunk(content=" stream")) + + +@pytest.fixture +def messages() -> list[BaseMessage]: + return [HumanMessage("Hello")] + + +class TestOutputVersionPassing: + """Test that output_version parameter is correctly passed to model methods.""" + + @pytest.mark.parametrize( + ("method_name", "default_version", "provided_version", 
"expected_version"), + [ + # Test invoke - output_version is no longer passed to _generate methods + ("invoke", "v1", None, "v0"), # Always defaults to v0 in _generate + ("invoke", "v0", "v1", "v0"), # Always defaults to v0 in _generate + # Test stream - output_version is still passed to _stream methods + ("stream", "v1", None, "v1"), # Uses default when not provided + ("stream", "v1", "v2", "v2"), # Uses provided version + ], + ) + def test_sync_methods_output_version( + self, + messages: list[BaseMessage], + method_name: str, + default_version: str, + provided_version: Optional[str], + expected_version: str, + ) -> None: + """Test sync methods handle output_version correctly.""" + model = OutputVersionTrackingChatModel( + messages=iter(["test response"]), output_version=default_version + ) + method = getattr(model, method_name) + + if provided_version is not None: + if method_name == "stream": + list(method(messages, output_version=provided_version)) + else: + method(messages, output_version=provided_version) + elif method_name == "stream": + list(method(messages)) + else: + method(messages) + + assert model.last_output_version == expected_version + + @pytest.mark.parametrize( + ("method_name", "default_version", "provided_version", "expected_version"), + [ + # Test ainvoke - output_version is no longer passed to _generate methods + ("ainvoke", "v1", None, "v0"), # Always defaults to v0 in _generate + ("ainvoke", "v0", "v1", "v0"), # Always defaults to v0 in _generate + # Test astream - output_version is still passed to _stream methods + ("astream", "v1", None, "v1"), # Uses default when not provided + ("astream", "v1", "v0", "v0"), # Uses provided version + ], + ) + async def test_async_methods_output_version( + self, + messages: list[BaseMessage], + method_name: str, + default_version: str, + provided_version: Optional[str], + expected_version: str, + ) -> None: + """Test async methods handle output_version correctly.""" + model = OutputVersionTrackingChatModel( + messages=iter(["test response"]), output_version=default_version + ) + method = getattr(model, method_name) + + if provided_version is not None: + if method_name == "astream": + async for _ in method(messages, output_version=provided_version): + pass + else: + await method(messages, output_version=provided_version) + elif method_name == "astream": + async for _ in method(messages): + pass + else: + await method(messages) + + assert model.last_output_version == expected_version + + +class TestStreamFallback: + """Test stream fallback behavior with output_version.""" + + def test_stream_fallback_to_invoke_passes_output_version( + self, + messages: list[BaseMessage], + ) -> None: + """Test `stream()` fallback passes `output_version` correctly.""" + + class NoStreamModel(BaseChatModel): + model_config = ConfigDict(extra="allow") + last_output_version: Optional[str] = None + + @override + def _generate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", + **kwargs: Any, + ) -> ChatResult: + self.last_output_version = output_version + message = AIMessage(content="test response") + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + return "no-stream-model" + + model = NoStreamModel(output_version="v1") + # Stream should fallback to invoke but output_version is no longer + # passed to _generate + list(model.stream(messages, 
output_version="v2")) + assert model.last_output_version == "v0" # _generate always gets v0 default + + async def test_astream_fallback_to_ainvoke_passes_output_version( + self, + messages: list[BaseMessage], + ) -> None: + """Test `astream()` fallback passes `output_version` correctly.""" + + class NoStreamModel(BaseChatModel): + model_config = ConfigDict(extra="allow") + last_output_version: Optional[str] = None + + @override + def _generate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: str = "v0", + **kwargs: Any, + ) -> ChatResult: + self.last_output_version = output_version + message = AIMessage(content="test response") + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + return "no-stream-model" + + model = NoStreamModel(output_version="v1") + # astream should fallback to ainvoke but output_version is no longer + # passed to _generate + async for _ in model.astream(messages, output_version="v2"): + pass + assert model.last_output_version == "v0" # _generate always gets v0 default + + +class TestOutputVersionInMessages: + """Test output_version is added to message response_metadata.""" + + def test_output_version_added_to_message_response_metadata( + self, + messages: list[BaseMessage], + ) -> None: + """Test that output_version is added to message response_metadata.""" + model = OutputVersionTrackingChatModel( + messages=iter(["test response"]), output_version="v1" + ) + result = model.invoke(messages, output_version="v2") + assert result.response_metadata["output_version"] == "v2" + + def test_output_version_added_to_stream_message_response_metadata( + self, + messages: list[BaseMessage], + ) -> None: + """Test that output_version is added to streamed message response_metadata.""" + model = OutputVersionTrackingChatModel( + messages=iter(["test response"]), output_version="v1" + ) + chunks = list(model.stream(messages, output_version="v2")) + + # Check that content chunks (not the "last" chunk) have the output_version + content_chunks = [chunk for chunk in chunks if chunk.content] + assert len(content_chunks) >= 1 # Should have at least one content chunk + + for chunk in content_chunks: + assert "output_version" in chunk.response_metadata + assert chunk.response_metadata["output_version"] == "v2" + + async def test_output_version_added_to_astream_message_response_metadata( + self, + messages: list[BaseMessage], + ) -> None: + """Test output_version added to async streamed response_metadata.""" + model = OutputVersionTrackingChatModel( + messages=iter(["test response"]), output_version="v1" + ) + chunks = [chunk async for chunk in model.astream(messages, output_version="v2")] + + # Check that content chunks (not the "last" chunk) have the output_version + content_chunks = [chunk for chunk in chunks if chunk.content] + assert len(content_chunks) >= 1 # Should have at least one content chunk + + for chunk in content_chunks: + assert "output_version" in chunk.response_metadata + assert chunk.response_metadata["output_version"] == "v2" + + def test_no_output_version_anywhere_no_metadata( + self, + messages: list[BaseMessage], + ) -> None: + """Test that when no output_version is set, no metadata is added.""" + from itertools import cycle + + # Test invoke + model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello")])) + result = model.invoke(messages) + assert 
result.response_metadata == {} + assert isinstance(result.content, str) # Should be v0 behavior + + # Test stream + model_stream = GenericFakeChatModel( + messages=cycle([AIMessage(content="hello")]) + ) + chunks = list(model_stream.stream(messages)) + content_chunks = [chunk for chunk in chunks if chunk.content] + for chunk in content_chunks: + assert chunk.response_metadata == {} + + +class TestOutputVersionMerging: + """Test output_version handling in merge operations.""" + + def test_output_version_consistency_in_merge(self) -> None: + """Test that merge_dicts raises error for inconsistent output_version.""" + from langchain_core.utils._merge import merge_dicts + + left_dict = {"output_version": "v1"} + right_dict = {"output_version": "v2"} + + with pytest.raises(ValueError, match="Unable to merge.*output_version"): + merge_dicts(left_dict, right_dict) + + def test_output_version_merge_same_value(self) -> None: + """Test that merge_dicts works fine when output_version values are same.""" + from langchain_core.utils._merge import merge_dicts + + left_dict = {"output_version": "v1", "other": "data1"} + right_dict = {"output_version": "v1", "more": "data2"} + + result = merge_dicts(left_dict, right_dict) + assert result["output_version"] == "v1" + assert result["other"] == "data1" + assert result["more"] == "data2" + + +class TestBackwardsCompatibility: + """Test backwards compatibility features.""" + + def test_backwards_compatibility_with_v0_default( + self, + messages: list[BaseMessage], + ) -> None: + """Test that models default to v0 for backward compatibility.""" + model = OutputVersionTrackingChatModel( + messages=iter(["test response"]) + ) # Don't specify output_version + model.invoke(messages) + # The default should be v0 for backward compatibility + assert model.last_output_version == "v0" + + def test_output_version_preserved_through_chain_calls( + self, + messages: list[BaseMessage], + ) -> None: + """Test that output_version is preserved through internal method calls.""" + model = OutputVersionTrackingChatModel( + messages=iter(["test response"]), output_version="v1" + ) + + # Test both with explicit and implicit (None) output_version + with patch.object( + model, "_generate_with_cache", wraps=model._generate_with_cache + ) as mock_cache: + model.invoke(messages, output_version="v2") + # Verify the internal call received the right output_version + mock_cache.assert_called_once() + call_kwargs = mock_cache.call_args[1] + assert call_kwargs.get("_output_version") == "v2" + + # Verify the model implementation received the default output_version + # (not the explicit one) + assert model.last_output_version == "v0" diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py index 580e6f5837e6b..347f8c7763729 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py @@ -217,7 +217,8 @@ def test_rate_limit_skips_cache() -> None: '[{"lc": 1, "type": "constructor", "id": ["langchain", "schema", ' '"messages", "HumanMessage"], "kwargs": {"content": "foo", ' '"type": "human"}}]', - "[('_type', 'generic-fake-chat-model'), ('stop', None)]", + "[('_output_version', 'v0'), ('_output_version_explicit', False), ('_type'," + " 'generic-fake-chat-model'), ('stop', None)]", ) ] diff --git a/libs/core/tests/unit_tests/stubs.py b/libs/core/tests/unit_tests/stubs.py index 
5cd45afb41f48..bcdf60a6d3ca1 100644 --- a/libs/core/tests/unit_tests/stubs.py +++ b/libs/core/tests/unit_tests/stubs.py @@ -29,6 +29,7 @@ def _any_id_document(**kwargs: Any) -> Document: def _any_id_ai_message(**kwargs: Any) -> AIMessage: """Create ai message with an any id field.""" + # Don't automatically add output_version - it should only be present when explicit message = AIMessage(**kwargs) message.id = AnyStr() return message @@ -36,6 +37,7 @@ def _any_id_ai_message(**kwargs: Any) -> AIMessage: def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk: """Create ai message with an any id field.""" + # Don't automatically add output_version - it should only be present when explicit message = AIMessageChunk(**kwargs) message.id = AnyStr() return message diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index ee1b0d6289685..4115116a0ac1a 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1599,8 +1599,11 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + # Note: output_version accepted for interface consistency; format conversion + # handled by core if stream_usage is None: stream_usage = self.stream_usage kwargs["stream"] = True @@ -1635,8 +1638,11 @@ async def _astream( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: + # Note: output_version accepted for interface consistency; format conversion + # handled by core if stream_usage is None: stream_usage = self.stream_usage kwargs["stream"] = True @@ -1719,13 +1725,18 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: Optional[str] = None, **kwargs: Any, ) -> ChatResult: + # Note: output_version accepted for interface consistency; format conversion + # handled by core if self.streaming: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, + output_version=output_version, **kwargs, ) return generate_from_stream(stream_iter) @@ -1741,13 +1752,18 @@ async def _agenerate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + output_version: Optional[str] = None, **kwargs: Any, ) -> ChatResult: + # Note: output_version accepted for interface consistency; format conversion + # handled by core if self.streaming: stream_iter = self._astream( messages, stop=stop, run_manager=run_manager, + output_version=output_version, **kwargs, ) return await agenerate_from_stream(stream_iter) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 8116630830b6c..9e1a489f28ea4 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -964,18 +964,26 @@ def _stream_responses( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: Optional[str] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: kwargs["stream"] = True + 
effective_output_version = ( + output_version + if output_version is not None + else (self.output_version or "v0") + ) payload = self._get_request_payload(messages, stop=stop, **kwargs) + api_payload = self._prepare_api_payload(payload) if self.include_response_headers: raw_context_manager = self.root_client.with_raw_response.responses.create( - **payload + **api_payload ) context_manager = raw_context_manager.parse() headers = {"headers": dict(raw_context_manager.headers)} else: - context_manager = self.root_client.responses.create(**payload) + context_manager = self.root_client.responses.create(**api_payload) headers = {} original_schema_obj = kwargs.get("response_format") @@ -1000,7 +1008,7 @@ def _stream_responses( schema=original_schema_obj, metadata=metadata, has_reasoning=has_reasoning, - output_version=self.output_version, + output_version=effective_output_version, ) if generation_chunk: if run_manager: @@ -1017,20 +1025,30 @@ async def _astream_responses( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + output_version: Optional[str] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: kwargs["stream"] = True + effective_output_version = ( + output_version + if output_version is not None + else (self.output_version or "v0") + ) payload = self._get_request_payload(messages, stop=stop, **kwargs) + api_payload = self._prepare_api_payload(payload) if self.include_response_headers: raw_context_manager = ( await self.root_async_client.with_raw_response.responses.create( - **payload + **api_payload ) ) context_manager = raw_context_manager.parse() headers = {"headers": dict(raw_context_manager.headers)} else: - context_manager = await self.root_async_client.responses.create(**payload) + context_manager = await self.root_async_client.responses.create( + **api_payload + ) headers = {} original_schema_obj = kwargs.get("response_format") @@ -1055,7 +1073,7 @@ async def _astream_responses( schema=original_schema_obj, metadata=metadata, has_reasoning=has_reasoning, - output_version=self.output_version, + output_version=effective_output_version, ) if generation_chunk: if run_manager: @@ -1093,9 +1111,12 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: kwargs["stream"] = True + # Note: output_version accepted for interface consistency; format conversion + # handled by core stream_usage = self._should_stream_usage(stream_usage, **kwargs) if stream_usage: kwargs["stream_options"] = {"include_usage": stream_usage} @@ -1109,16 +1130,20 @@ def _stream( "Cannot currently include response headers when response_format is " "specified." 
) - payload.pop("stream") - response_stream = self.root_client.beta.chat.completions.stream(**payload) + api_payload = self._prepare_api_payload(payload) + api_payload.pop("stream") + response_stream = self.root_client.beta.chat.completions.stream( + **api_payload + ) context_manager = response_stream else: + api_payload = self._prepare_api_payload(payload) if self.include_response_headers: - raw_response = self.client.with_raw_response.create(**payload) + raw_response = self.client.with_raw_response.create(**api_payload) response = raw_response.parse() base_generation_info = {"headers": dict(raw_response.headers)} else: - response = self.client.create(**payload) + response = self.client.create(**api_payload) context_manager = response try: with context_manager as response: @@ -1161,11 +1186,23 @@ def _generate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + output_version: Optional[str] = None, **kwargs: Any, ) -> ChatResult: + effective_output_version = ( + output_version + if output_version is not None + else (self.output_version or "v0") + ) + if self.streaming: stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs + messages, + stop=stop, + run_manager=run_manager, + output_version=effective_output_version, + **kwargs, ) return generate_from_stream(stream_iter) payload = self._get_request_payload(messages, stop=stop, **kwargs) @@ -1184,14 +1221,15 @@ def _generate( except openai.BadRequestError as e: _handle_openai_bad_request(e) elif self._use_responses_api(payload): + api_payload = self._prepare_api_payload(payload) original_schema_obj = kwargs.get("response_format") if original_schema_obj and _is_pydantic_class(original_schema_obj): raw_response = self.root_client.responses.with_raw_response.parse( - **payload + **api_payload ) else: raw_response = self.root_client.responses.with_raw_response.create( - **payload + **api_payload ) response = raw_response.parse() if self.include_response_headers: @@ -1200,10 +1238,11 @@ def _generate( response, schema=original_schema_obj, metadata=generation_info, - output_version=self.output_version, + output_version=effective_output_version, ) else: - raw_response = self.client.with_raw_response.create(**payload) + api_payload = self._prepare_api_payload(payload) + raw_response = self.client.with_raw_response.create(**api_payload) response = raw_response.parse() except Exception as e: if raw_response is not None and hasattr(raw_response, "http_response"): @@ -1264,6 +1303,12 @@ def _get_request_payload( ] return payload + def _prepare_api_payload(self, payload: dict) -> dict: + """Remove LangChain-specific parameters before making OpenAI API calls.""" + api_payload = payload.copy() + api_payload.pop("output_version", None) + return api_payload + def _create_chat_result( self, response: Union[dict, openai.BaseModel], @@ -1337,9 +1382,12 @@ async def _astream( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, + output_version: Optional[str] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: kwargs["stream"] = True + # Note: output_version accepted for interface consistency; format conversion + # handled by core stream_usage = self._should_stream_usage(stream_usage, **kwargs) if stream_usage: kwargs["stream_options"] = {"include_usage": stream_usage} @@ -1353,20 +1401,22 @@ async def _astream( "Cannot currently include response headers when response_format is " "specified." 
) - payload.pop("stream") + api_payload = self._prepare_api_payload(payload) + api_payload.pop("stream") response_stream = self.root_async_client.beta.chat.completions.stream( - **payload + **api_payload ) context_manager = response_stream else: + api_payload = self._prepare_api_payload(payload) if self.include_response_headers: raw_response = await self.async_client.with_raw_response.create( - **payload + **api_payload ) response = raw_response.parse() base_generation_info = {"headers": dict(raw_response.headers)} else: - response = await self.async_client.create(**payload) + response = await self.async_client.create(**api_payload) context_manager = response try: async with context_manager as response: @@ -1409,11 +1459,23 @@ async def _agenerate( messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + output_version: Optional[str] = None, **kwargs: Any, ) -> ChatResult: + effective_output_version = ( + output_version + if output_version is not None + else (self.output_version or "v0") + ) + if self.streaming: stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs + messages, + stop=stop, + run_manager=run_manager, + output_version=effective_output_version, + **kwargs, ) return await agenerate_from_stream(stream_iter) payload = self._get_request_payload(messages, stop=stop, **kwargs) @@ -1421,26 +1483,28 @@ async def _agenerate( raw_response = None try: if "response_format" in payload: - payload.pop("stream") + api_payload = self._prepare_api_payload(payload) + api_payload.pop("stream") try: raw_response = await self.root_async_client.chat.completions.with_raw_response.parse( # noqa: E501 - **payload + **api_payload ) response = raw_response.parse() except openai.BadRequestError as e: _handle_openai_bad_request(e) elif self._use_responses_api(payload): + api_payload = self._prepare_api_payload(payload) original_schema_obj = kwargs.get("response_format") if original_schema_obj and _is_pydantic_class(original_schema_obj): raw_response = ( await self.root_async_client.responses.with_raw_response.parse( - **payload + **api_payload ) ) else: raw_response = ( await self.root_async_client.responses.with_raw_response.create( - **payload + **api_payload ) ) response = raw_response.parse() @@ -1450,11 +1514,12 @@ async def _agenerate( response, schema=original_schema_obj, metadata=generation_info, - output_version=self.output_version, + output_version=effective_output_version, ) else: + api_payload = self._prepare_api_payload(payload) raw_response = await self.async_client.with_raw_response.create( - **payload + **api_payload ) response = raw_response.parse() except Exception as e: @@ -4052,6 +4117,9 @@ def _construct_lc_result_from_responses_api( ) if output_version == "v0": message = _convert_to_v03_ai_message(message) + elif output_version == "v1": + # Use content_blocks property which handles v1 conversion via block_translators + message = message.model_copy(update={"content": message.content_blocks}) return ChatResult(generations=[ChatGeneration(message=message)]) @@ -4289,6 +4357,12 @@ def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None: AIMessageChunk, _convert_to_v03_ai_message(message, has_reasoning=has_reasoning), ) + elif output_version == "v1": + # Use content_blocks property which handles v1 conversion via block_translators + message = cast( + AIMessageChunk, + message.model_copy(update={"content": message.content_blocks}), + ) return ( current_index,
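Illustrative usage (not part of the patch): a minimal sketch of how the new per-call output_version override interacts with the model-level setting, based on the behavior exercised by the new unit tests above. It reuses GenericFakeChatModel from langchain_core; the assertions mirror test_output_version_added_to_message_response_metadata and test_no_output_version_anywhere_no_metadata in this diff and are a sketch of expected behavior, not a definitive API reference.

from itertools import cycle

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage

messages = [HumanMessage("Hello")]

# Call-level output_version wins over the model-level setting and, because it
# was set explicitly, is recorded in response_metadata.
model = GenericFakeChatModel(
    messages=cycle([AIMessage(content="hi there")]), output_version="v0"
)
result = model.invoke(messages, output_version="v1")
assert result.response_metadata["output_version"] == "v1"

# With no output_version set anywhere, nothing is added to response_metadata
# and the legacy v0 string content is preserved.
plain = GenericFakeChatModel(messages=cycle([AIMessage(content="hi there")]))
plain_result = plain.invoke(messages)
assert "output_version" not in plain_result.response_metadata
assert isinstance(plain_result.content, str)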