diff --git a/docs/features/models/transformers_multimodal.md b/docs/features/models/transformers_multimodal.md
index 24d119de3..3d34a0f78 100644
--- a/docs/features/models/transformers_multimodal.md
+++ b/docs/features/models/transformers_multimodal.md
@@ -10,18 +10,18 @@ The Outlines `TransformersMultiModal` model inherits from `Transformers` and sha
 To load the model, you can use the `from_transformers` function. It takes 2 arguments:
 
-- `model`: a `transformers` model (created with `AutoModelForCausalLM` for instance)
+- `model`: a `transformers` model (created with `AutoModelForImageTextToText` for instance)
 - `tokenizer_or_processor`: a `transformers` processor (created with `AutoProcessor` for instance, it must be an instance of `ProcessorMixin`)
 
 For instance:
 
 ```python
 import outlines
-from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers import AutoModelForImageTextToText, AutoProcessor
 
 # Create the transformers model and processor
-hf_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
-hf_processor = AutoProcessor.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+hf_model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+hf_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
 
 # Create the Outlines model
 model = outlines.from_transformers(hf_model, hf_processor)
@@ -76,11 +76,186 @@ result = model(
 print(result) # '{"specie": "cat", "color": "white", "weight": 4}'
 print(Animal.model_validate_json(result)) # specie=cat, color=white, weight=4
 ```
+!!! Warning
+
+    Make sure your prompt contains the tags expected by your processor to correctly inject the assets in the prompt. For some vision multimodal models for instance, you need to add as many `<image>` tags in your prompt as there are image assets included in your model input. The `Chat` method, by contrast, does not require this step.
+
 
-The `TransformersMultiModal` model supports batch generation. To use it, invoke the `batch` method with a list of lists. You will receive as a result a list of completions.
+### Chat
+The `Chat` interface offers a more convenient way to work with multimodal inputs. You don't need to manually add asset tags such as `<image>`: the model's HF processor handles the chat templating and asset placement for you automatically.
+To do so, call the model with a `Chat` instance using a multimodal chat format. Assets must be wrapped in the `outlines.inputs` asset types (`Image`, `Audio`, `Video`), and only `image`, `video`, and `audio` content types are supported.
+For instance:
+
+```python
+import outlines
+from outlines.inputs import Chat, Image
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from PIL import Image as PILImage
+from io import BytesIO
+from urllib.request import urlopen
+import torch
+
+model_kwargs = {
+    "torch_dtype": torch.bfloat16,
+    "attn_implementation": "flash_attention_2",
+    "device_map": "auto",
+    }
+
+def get_image_from_url(image_url):
+    img_byte_stream = BytesIO(urlopen(image_url).read())
+    image = PILImage.open(img_byte_stream).convert("RGB")
+    image.format = "PNG"
+    return image
+
+# Create the model
+model = outlines.from_transformers(
+    AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs),
+    AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs)
+)
+
+IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
+
+# Create the chat multimodal input
+prompt = Chat([
+    {
+        "role": "user",
+        "content": [
+            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL))},
+            {"type": "text", "text": "Describe the image in a few words."}
+        ],
+    }
+])
+
+# Call the model to generate a response
+response = model(prompt, max_new_tokens=50)
+print(response) # 'A Siamese cat with blue eyes is sitting on a cat tree, looking alert and curious.'
+```
+
+Or using a list containing text and assets:
+
+```python
+import outlines
+from outlines.inputs import Chat, Image
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from PIL import Image as PILImage
+from io import BytesIO
+import requests
+import torch
+
+
+TEST_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
+
+# Function to get an image
+def get_image(url):
+    headers = {
+        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+    }
+    r = requests.get(url, headers=headers)
+    image = PILImage.open(BytesIO(r.content)).convert("RGB")
+    image.format = "PNG"
+    return image
+
+model_kwargs = {
+    "torch_dtype": torch.bfloat16,
+    # "attn_implementation": "flash_attention_2",
+    "device_map": "auto",
+    }
+
+# Create a model
+model = outlines.from_transformers(
+    AutoModelForImageTextToText.from_pretrained(TEST_MODEL, **model_kwargs),
+    AutoProcessor.from_pretrained(TEST_MODEL, **model_kwargs),
+)
+
+# Create the chat input
+prompt = Chat([
+    {"role": "user", "content": "You are a helpful assistant that helps me describe pictures."},
+    {"role": "assistant", "content": "I'd be happy to help you describe pictures! Please go ahead and share an image."},
+    {
+        "role": "user",
+        "content": ["Briefly describe the image", Image(get_image("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"))]
+    },
+])
+
+# Call the model to generate a response
+response = model(prompt, max_new_tokens=50)
+print(response) # 'The image shows a light-colored cat with a white chest...'
+```
+
+
+### Batching
+The `TransformersMultiModal` model supports batching through the `batch` method. To use it, provide a list of prompts (using any of the formats described above); you will receive a list of completions in return.
+
+An example using the Chat format:
+
+```python
+import outlines
+from outlines.inputs import Chat, Image
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from PIL import Image as PILImage
+from io import BytesIO
+from urllib.request import urlopen
+import torch
+from pydantic import BaseModel
+
+model_kwargs = {
+    "torch_dtype": torch.bfloat16,
+    "attn_implementation": "flash_attention_2",
+    "device_map": "auto",
+    }
+
+class Animal(BaseModel):
+    animal: str
+    color: str
+
+def get_image_from_url(image_url):
+    img_byte_stream = BytesIO(urlopen(image_url).read())
+    image = PILImage.open(img_byte_stream).convert("RGB")
+    image.format = "PNG"
+    return image
+
+# Create the model
+model = outlines.from_transformers(
+    AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs),
+    AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs)
+)
+
+IMAGE_URL_1 = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
+IMAGE_URL_2 = "https://upload.wikimedia.org/wikipedia/commons/a/af/Golden_retriever_eating_pigs_foot.jpg"
+
+# Create the chat multimodal messages
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in a few words."},
+            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL_1))},
+        ],
+    },
+]
+
+messages_2 = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in a few words."},
+            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL_2))},
+        ],
+    },
+]
+
+prompts = [Chat(messages), Chat(messages_2)]
+
+# Call the model to generate a response
+responses = model.batch(prompts, output_type=Animal, max_new_tokens=100)
+print(responses) # ['{ "animal": "cat", "color": "white and gray" }', '{ "animal": "dog", "color": "white" }']
+print([Animal.model_validate_json(i) for i in responses]) # [Animal(animal='cat', color='white and gray'), Animal(animal='dog', color='white')]
+```
+
+
+An example using a list of lists with asset tags:
+
 ```python
 from io import BytesIO
 from urllib.request import urlopen
@@ -119,7 +294,3 @@ result = model.batch(
 )
 print(result) # ['The image shows a cat', 'The image shows an astronaut']
 ```
-
-!!! Warning
-
-    Make sure your prompt contains the tags expected by your processor to correctly inject the assets in the prompt. For some vision multimodal models for instance, you need to add as many `<image>` tags in your prompt as there are image assets included in your model input.
diff --git a/outlines/inputs.py b/outlines/inputs.py
index 50a6a6741..54b17dd49 100644
--- a/outlines/inputs.py
+++ b/outlines/inputs.py
@@ -80,8 +80,11 @@ class Chat:
     Each message contained in the messages list must be a dict with 'role'
     and 'content' keys. The role can be 'user', 'assistant', or 'system'. The content
-    can be a string or a list containing a str and assets (images, videos,
-    audios, etc.) in the case of multimodal models.
+    supports either:
+    - a text string,
+    - a list containing text and assets (e.g., ["Describe...", Image(...)]),
+    - for HuggingFace transformers models only, a list of dict items with explicit types (e.g.,
+      [{"type": "text", "text": "Describe..."}, {"type": "image", "image": Image(...)}])
 
     Examples
     --------
@@ -95,7 +98,7 @@ class Chat:
     chat_prompt.add_user_message(["Describe the image below", Image(image)])
 
     # Add as an assistant message the response from the model.
- chat_prompt.add_assistant_message("The is a black cat sitting on a couch.") + chat_prompt.add_assistant_message("There is a black cat sitting on a couch.") ``` Parameters diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 964dbce60..4d6387a40 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -430,61 +430,116 @@ def format_input(self, model_input): + "model or a `Chat` instance." ) - @format_input.register(dict) - def format_dict_input(self, model_input: dict) -> dict: - warnings.warn(""" - Providing the input as a dict is deprecated. Support for this will - be removed in the v1.2.0 release of Outlines. Use a list containing - a text prompt and assets (`Image`, `Audio` or `Video` instances) - instead. - For instance: - ```python - from outlines import Image - model = from_transformers(mymodel, myprocessor) - response = model([ - "A beautiful image of a cat", - Image(my_image), - ]) - ``` - """, - DeprecationWarning, - stacklevel=2, + @format_input.register(Chat) + def format_chat_input(self, model_input: Chat) -> dict: + conversation = [] + assets = [] + + # process each message, convert if needed to standardized multimodal chat template format + # and collect assets for HF processor + for message in model_input.messages: + processed_message, message_assets = self._prepare_message( + message["role"], message["content"] + ) + conversation.append(processed_message) + assets.extend(message_assets) + + formatted_prompt = self.tokenizer.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=True ) - if "text" not in model_input: + # use the formatted prompt and the assets to format the input + return self.format_list_input([formatted_prompt, *assets]) + + def _prepare_message(self, role: str, content: str | list) -> tuple[dict, list]: + """Create a message.""" + if isinstance(content, str): + return {"role": role, "content": content}, [] + + elif isinstance(content, list): + if all(isinstance(item, dict) for item in content): # HF multimodal chat template + return {"role": role, "content": content}, self._extract_assets_from_content(content) + else: # list of string + assets + prompt = content[0] + assets = content[1:] + assets_dict = [self._format_asset_for_template(asset) for asset in assets] + + return {"role": role, "content": [ + {"type": "text", "text": prompt}, + *assets_dict + ]}, assets + else: raise ValueError( - "The input must contain the 'text' key along with the other " - + "keys required by your processor." + f"Invalid content type: {type(content)}. " + + "The content must be a string or a list containing text and assets " + + "or a list of dict items with explicit types." ) - return model_input - @format_input.register(Chat) - def format_chat_input(self, model_input: Chat) -> dict: - # we need to separate the images from the messages - # to apply the chat template to the messages without images - messages = model_input.messages - images = [] - messages_without_images = [] - for message in messages: - if isinstance(message["content"], list): - images.extend(message["content"][1:]) - messages_without_images.append({ - "role": message["role"], - "content": message["content"][0], - }) + def _extract_assets_from_content(self, content: list) -> list: + """Process a list of dict items.""" + assets = [] + + for item in content: + if len(item) > 2: + raise ValueError( + f"Found item with multiple keys: {item}. 
" + + "Each item in the content list must be a dictionary with a 'type' key and a single asset key. " + + "To include multiple assets, use separate dictionary items. " + + "For example: [{{'type': 'image', 'image': image1}}, {{'type': 'image', 'image': image2}}]. " + ) + + if "type" not in item: + raise ValueError( + "Each item in the content list must be a dictionary with a 'type' key. " + + "Valid types are 'text', 'image', 'video', or 'audio'. " + + "For instance {{'type': 'text', 'text': 'your message'}}. " + + f"Found item without 'type' key: {item}" + ) + if item["type"] == "text": + continue + elif item["type"] in ["image", "video", "audio"]: + asset_key = item["type"] + if asset_key not in item: + raise ValueError( + f"Item with type '{asset_key}' must contain a '{asset_key}' key. " + + f"Found item: {item}" + ) + if isinstance(item[asset_key], (Image, Video, Audio)): + assets.append(item[asset_key]) + else: + raise ValueError( + "Assets must be of type `Image`, `Video` or `Audio`. " + + f"Unsupported asset type: {type(item[asset_key])}" + ) else: - messages_without_images.append(message) - formatted_prompt = self.tokenizer.apply_chat_template( - messages_without_images, - tokenize=False - ) - # use the formatted prompt and the images to format the input - return self.format_list_input([formatted_prompt, *images]) + raise ValueError( + "Content must be 'text', 'image', 'video' or 'audio'. " + + f"Unsupported content type: {item['type']}") + return assets + + def _format_asset_for_template(self, asset: Image | Video | Audio) -> dict: + """Process an asset.""" + if isinstance(asset, Image): + return {"type": "image", "image": asset} + elif isinstance(asset, Video): + return {"type": "video", "video": asset} + elif isinstance(asset, Audio): + return {"type": "audio", "audio": asset} + else: + raise ValueError( + "Assets must be of type `Image`, `Video` or `Audio`. 
" + + f"Unsupported asset type: {type(asset)}" + ) @format_input.register(list) def format_list_input(self, model_input: list) -> dict: prompt = model_input[0] assets = model_input[1:] + if not assets: # handle empty assets case + return {"text": prompt} + asset_types = set(type(asset) for asset in assets) if len(asset_types) > 1: raise ValueError( diff --git a/tests/models/test_transformers_multimodal.py b/tests/models/test_transformers_multimodal.py index 41ff058e9..ce968e1f6 100644 --- a/tests/models/test_transformers_multimodal.py +++ b/tests/models/test_transformers_multimodal.py @@ -43,8 +43,6 @@ def model(): LlavaForConditionalGeneration.from_pretrained(TEST_MODEL), AutoProcessor.from_pretrained(TEST_MODEL), ) - chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}' - model.type_adapter.tokenizer.chat_template = chat_template return model @@ -100,7 +98,7 @@ def test_transformers_multimodal_chat(model, image): { "role": "user", "content": [ - "Describe this image in one sentence:", + "Describe this image in one sentence:", Image(image), ], }, @@ -109,6 +107,21 @@ def test_transformers_multimodal_chat(model, image): ) assert isinstance(result, str) + result = model( + Chat(messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image in one sentence:"}, + {"type": "image", "image": Image(image)}, + ], + }, + ]), + max_new_tokens=2, + ) + assert isinstance(result, str) + def test_transformers_inference_kwargs(model, image): result = model( @@ -221,7 +234,7 @@ def test_transformers_multimodal_batch(model, image): { "role": "user", "content": [ - "Describe this image in one sentence:", + "Describe this image in one sentence:", Image(image), ], }, @@ -231,7 +244,7 @@ def test_transformers_multimodal_batch(model, image): { "role": "user", "content": [ - "Describe this image in one sentence:", + "Describe this image in one sentence:", Image(image), ], }, @@ -242,15 +255,30 @@ def test_transformers_multimodal_batch(model, image): assert isinstance(result, list) assert len(result) == 2 - -def test_transformers_multimodal_deprecated_input_type(model, image): - with pytest.warns(DeprecationWarning): - result = model.generate( - { - "text": "Describe this image in one sentence:", - "image": image, - }, - None, - max_new_tokens=2, - ) - assert isinstance(result, str) + result = model.batch( + [ + Chat(messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image in one sentence:"}, + {"type": "image", "image": Image(image)}, + ], + }, + ]), + Chat(messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image in one sentence:"}, + {"type": "image", "image": Image(image)}, + ], + }, + ]), + ], + max_new_tokens=2, + ) + assert isinstance(result, list) + assert len(result) == 2 diff --git a/tests/models/test_transformers_multimodal_type_adapter.py b/tests/models/test_transformers_multimodal_type_adapter.py index 10df1b0e9..5518ef553 100644 --- a/tests/models/test_transformers_multimodal_type_adapter.py +++ b/tests/models/test_transformers_multimodal_type_adapter.py @@ -2,22 +2,24 @@ from PIL import Image as PILImage from outlines_core import Index, Vocabulary -from transformers import AutoTokenizer, LogitsProcessorList +from transformers import ( + AutoProcessor, 
+ LogitsProcessorList, +) -from outlines.inputs import Chat, Image, Video +from outlines.inputs import Audio, Chat, Image, Video from outlines.models.transformers import TransformersMultiModalTypeAdapter from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor -MODEL_NAME = "erwanf/gpt2-mini" +MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration" @pytest.fixture def adapter(): - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + processor = AutoProcessor.from_pretrained(MODEL_NAME) + tokenizer = processor.tokenizer type_adapter = TransformersMultiModalTypeAdapter(tokenizer=tokenizer) - chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}' - type_adapter.tokenizer.chat_template = chat_template return type_adapter @@ -39,13 +41,24 @@ def image(): return image +@pytest.fixture +def video(): + # Simple mock video data + return "mock_video_data" + + +@pytest.fixture +def audio(): + # Simple mock audio data + return "mock_audio_data" + + def test_transformers_multimodal_type_adapter_format_input(adapter, image): with pytest.raises(TypeError): adapter.format_input("hello") - with pytest.raises(ValueError): - with pytest.deprecated_call(): - adapter.format_input({"foo": "bar"}) + with pytest.raises(TypeError): + adapter.format_input({"foo": "bar"}) with pytest.raises(ValueError, match="All assets must be of the same type"): adapter.format_input(["foo", Image(image), Video("")]) @@ -73,6 +86,51 @@ class MockAsset: assert len(result["images"]) == 1 assert result["images"][0] == image_asset.image + chat_prompt = Chat(messages=[ + {"role": "system", "content": "foo"}, + {"role": "user", "content": [{"type": "text", "text": "bar"}, {"type": "image", "image": image_asset}]}, + ]) + result = adapter.format_input(chat_prompt) + assert isinstance(result, dict) + assert isinstance(result["text"], str) + assert isinstance(result["images"], list) + assert len(result["images"]) == 1 + assert result["images"][0] == image_asset.image + + + + +def test_transformers_multimodal_type_adapter_format_input_empty_assets(adapter): + result = adapter.format_input(["Just text prompt"]) + assert result == {"text": "Just text prompt"} + + +def test_transformers_multimodal_type_adapter_format_input_chat_invalid_asset_type(adapter, image): + class MockAsset: + pass + + chat_prompt = Chat(messages=[ + {"role": "user", "content": [ + {"type": "text", "text": "Hello"}, + {"type": "image", "image": MockAsset()} # Wrong type + ]} + ]) + + with pytest.raises(ValueError, match="Assets must be of type"): + adapter.format_input(chat_prompt) + + +def test_transformers_multimodal_type_adapter_format_input_chat_unsupported_content_type(adapter): + chat_prompt = Chat(messages=[ + {"role": "user", "content": [ + {"type": "text", "text": "Hello"}, + {"type": "unsupported", "data": "some_data"} # Unsupported type + ]} + ]) + + with pytest.raises(ValueError, match="Content must be 'text'"): + adapter.format_input(chat_prompt) + def test_transformers_multimodal_type_adapter_format_output_type( adapter, logits_processor @@ -85,3 +143,126 @@ def test_transformers_multimodal_type_adapter_format_output_type( formatted = adapter.format_output_type(None) assert formatted is None + + +def test_transformers_multimodal_type_adapter_format_input_chat_missing_asset_key(adapter, image): + image_asset = Image(image) + + # Test missing 'image' key when type is 'image' + chat_prompt = Chat(messages=[ + {"role": "user", "content": [ + {"type": "text", "text": "What's in this 
image?"}, + {"type": "image", "txt": image_asset} # Wrong key: 'txt' instead of 'image' + ]} + ]) + + with pytest.raises(ValueError, match="Item with type 'image' must contain a 'image' key"): + adapter.format_input(chat_prompt) + + # Test missing 'video' key when type is 'video' + video_asset = Video("dummy_video") + chat_prompt = Chat(messages=[ + {"role": "user", "content": [ + {"type": "text", "text": "What's in this video?"}, + {"type": "video", "vid": video_asset} # Wrong key: 'vid' instead of 'video' + ]} + ]) + + with pytest.raises(ValueError, match="Item with type 'video' must contain a 'video' key"): + adapter.format_input(chat_prompt) + + +def test_transformers_multimodal_type_adapter_format_input_chat_missing_type_key(adapter, image): + image_asset = Image(image) + + chat_prompt = Chat(messages=[ + {"role": "user", "content": [ + {"text": "What's in this image?"}, # Missing 'type' key + {"type": "image", "image": image_asset} + ]} + ]) + + with pytest.raises(ValueError, match="Each item in the content list must be a dictionary with a 'type' key"): + adapter.format_input(chat_prompt) + + +def test_transformers_multimodal_type_adapter_format_input_invalid_content_type(adapter): + chat_prompt = Chat(messages=[ + {"role": "user", "content": 42} # Invalid content type (integer) + ]) + + with pytest.raises(ValueError, match="Invalid content type"): + adapter.format_input(chat_prompt) + + # Test with another invalid type + chat_prompt = Chat(messages=[ + {"role": "user", "content": {"invalid": "dict"}} # Invalid content type (dict not in list) + ]) + + with pytest.raises(ValueError, match="Invalid content type"): + adapter.format_input(chat_prompt) + + +def test_transformers_multimodal_type_adapter_format_asset_for_template(adapter, image, video, audio): + # Test Image asset + image_asset = Image(image) + formatted_image = adapter._format_asset_for_template(image_asset) + assert formatted_image == {"type": "image", "image": image_asset} + + # Test Video asset + video_asset = Video(video) + formatted_video = adapter._format_asset_for_template(video_asset) + assert formatted_video == {"type": "video", "video": video_asset} + + # Test Audio asset + audio_asset = Audio(audio) + formatted_audio = adapter._format_asset_for_template(audio_asset) + assert formatted_audio == {"type": "audio", "audio": audio_asset} + + +def test_transformers_multimodal_type_adapter_format_asset_for_template_invalid_type(adapter): + class MockUnsupportedAsset: + pass + + # This test requires accessing the private method directly since the error + # would normally be caught earlier in the validation chain + unsupported_asset = MockUnsupportedAsset() + + with pytest.raises(ValueError, match="Assets must be of type `Image`, `Video` or `Audio`"): + adapter._format_asset_for_template(unsupported_asset) + + +def test_transformers_multimodal_type_adapter_multiple_assets_in_single_item(adapter, image): + image_asset = Image(image) + video_asset = Video("dummy_video") + + chat_prompt = Chat(messages=[ + {"role": "user", "content": [ + {"type": "text", "text": "What's in this?"}, + {"type": "image", "image": image_asset, "video": video_asset} # Multiple asset types + ]} + ]) + + with pytest.raises(ValueError, match="Found item with multiple keys:"): + adapter.format_input(chat_prompt) + + + +def test_transformers_multimodal_type_adapter_correct_multiple_assets_usage(adapter, image): + image_asset1 = Image(image) + image_asset2 = Image(image) + + # Correct way: separate dictionary items for each asset + chat_prompt = 
Chat(messages=[ + {"role": "user", "content": [ + {"type": "text", "text": "What's in these images?"}, + {"type": "image", "image": image_asset1}, + {"type": "image", "image": image_asset2} + ]} + ]) + + result = adapter.format_input(chat_prompt) + assert isinstance(result, dict) + assert "text" in result + assert "images" in result + assert len(result["images"]) == 2 diff --git a/tests/test_inputs.py b/tests/test_inputs.py index d020a48c9..e3b93c91f 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -141,6 +141,13 @@ def test_chat_add_system_message(image_input): assert chat.messages[0]["role"] == "system" assert chat.messages[0]["content"] == ["prompt", image_input] + # Add a list of dict items with explicit types + chat = Chat(messages=[]) + chat.add_system_message([{"type": "text", "text": "prompt"}, {"type": "image", "image": image_input}]) + assert len(chat.messages) == 1 + assert chat.messages[0]["role"] == "system" + assert chat.messages[0]["content"] == [{"type": "text", "text": "prompt"}, {"type": "image", "image": image_input}] + def test_add_user_message_string(image_input): # Add a string @@ -157,6 +164,13 @@ def test_add_user_message_string(image_input): assert chat.messages[0]["role"] == "user" assert chat.messages[0]["content"] == ["prompt", image_input] + # Add a list of dict items with explicit types + chat = Chat(messages=[]) + chat.add_user_message([{"type": "text", "text": "prompt"}, {"type": "image", "image": image_input}]) + assert len(chat.messages) == 1 + assert chat.messages[0]["role"] == "user" + assert chat.messages[0]["content"] == [{"type": "text", "text": "prompt"}, {"type": "image", "image": image_input}] + def test_add_assistant_message_string(image_input): # Add a string @@ -172,3 +186,10 @@ def test_add_assistant_message_string(image_input): assert len(chat.messages) == 1 assert chat.messages[0]["role"] == "assistant" assert chat.messages[0]["content"] == ["prompt", image_input] + + # Add a list of dict items with explicit types + chat = Chat(messages=[]) + chat.add_assistant_message([{"type": "text", "text": "prompt"}, {"type": "image", "image": image_input}]) + assert len(chat.messages) == 1 + assert chat.messages[0]["role"] == "assistant" + assert chat.messages[0]["content"] == [{"type": "text", "text": "prompt"}, {"type": "image", "image": image_input}]