Skip to content

Commit 1a5dd96

Browse files
committed
Migrate multimodal type adapter tests to use AutoProcessor for a multimodal scenario, and add Audio and Video tests
1 parent bef4906 commit 1a5dd96

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

tests/models/test_transformers_multimodal_type_adapter.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,24 @@
22

33
from PIL import Image as PILImage
44
from outlines_core import Index, Vocabulary
5-
from transformers import AutoTokenizer, LogitsProcessorList
5+
from transformers import (
6+
AutoProcessor,
7+
LogitsProcessorList,
8+
)
69

7-
from outlines.inputs import Chat, Image, Video
10+
from outlines.inputs import Audio, Chat, Image, Video
811
from outlines.models.transformers import TransformersMultiModalTypeAdapter
912
from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor
1013

1114

12-
MODEL_NAME = "erwanf/gpt2-mini"
15+
MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration"
1316

1417

1518
@pytest.fixture
1619
def adapter():
17-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20+
processor = AutoProcessor.from_pretrained(MODEL_NAME)
21+
tokenizer = processor.tokenizer
1822
type_adapter = TransformersMultiModalTypeAdapter(tokenizer=tokenizer)
19-
chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}'
20-
type_adapter.tokenizer.chat_template = chat_template
2123

2224
return type_adapter
2325

@@ -39,6 +41,18 @@ def image():
3941
return image
4042

4143

44+
@pytest.fixture
45+
def video():
46+
# Simple mock video data
47+
return "mock_video_data"
48+
49+
50+
@pytest.fixture
51+
def audio():
52+
# Simple mock audio data
53+
return "mock_audio_data"
54+
55+
4256
def test_transformers_multimodal_type_adapter_format_input(adapter, image):
4357
with pytest.raises(TypeError):
4458
adapter.format_input("hello")
@@ -189,6 +203,23 @@ def test_transformers_multimodal_type_adapter_format_input_invalid_content_type(
189203
adapter.format_input(chat_prompt)
190204

191205

206+
def test_transformers_multimodal_type_adapter_format_asset_for_template(adapter, image, video, audio):
207+
# Test Image asset
208+
image_asset = Image(image)
209+
formatted_image = adapter._format_asset_for_template(image_asset)
210+
assert formatted_image == {"type": "image", "image": image_asset}
211+
212+
# Test Video asset
213+
video_asset = Video(video)
214+
formatted_video = adapter._format_asset_for_template(video_asset)
215+
assert formatted_video == {"type": "video", "video": video_asset}
216+
217+
# Test Audio asset
218+
audio_asset = Audio(audio)
219+
formatted_audio = adapter._format_asset_for_template(audio_asset)
220+
assert formatted_audio == {"type": "audio", "audio": audio_asset}
221+
222+
192223
def test_transformers_multimodal_type_adapter_format_asset_for_template_invalid_type(adapter):
193224
class MockUnsupportedAsset:
194225
pass

0 commit comments

Comments
 (0)