2
2
3
3
from PIL import Image as PILImage
4
4
from outlines_core import Index , Vocabulary
5
- from transformers import AutoTokenizer , LogitsProcessorList
5
+ from transformers import (
6
+ AutoProcessor ,
7
+ LogitsProcessorList ,
8
+ )
6
9
7
- from outlines .inputs import Chat , Image , Video
10
+ from outlines .inputs import Audio , Chat , Image , Video
8
11
from outlines .models .transformers import TransformersMultiModalTypeAdapter
9
12
from outlines .backends .outlines_core import OutlinesCoreLogitsProcessor
10
13
11
14
12
- MODEL_NAME = "erwanf/gpt2-mini "
15
+ MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration "
13
16
14
17
15
18
@pytest .fixture
16
19
def adapter ():
17
- tokenizer = AutoTokenizer .from_pretrained (MODEL_NAME )
20
+ processor = AutoProcessor .from_pretrained (MODEL_NAME )
21
+ tokenizer = processor .tokenizer
18
22
type_adapter = TransformersMultiModalTypeAdapter (tokenizer = tokenizer )
19
- chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}'
20
- type_adapter .tokenizer .chat_template = chat_template
21
23
22
24
return type_adapter
23
25
@@ -39,6 +41,18 @@ def image():
39
41
return image
40
42
41
43
44
+ @pytest .fixture
45
+ def video ():
46
+ # Simple mock video data
47
+ return "mock_video_data"
48
+
49
+
50
+ @pytest .fixture
51
+ def audio ():
52
+ # Simple mock audio data
53
+ return "mock_audio_data"
54
+
55
+
42
56
def test_transformers_multimodal_type_adapter_format_input (adapter , image ):
43
57
with pytest .raises (TypeError ):
44
58
adapter .format_input ("hello" )
@@ -189,6 +203,23 @@ def test_transformers_multimodal_type_adapter_format_input_invalid_content_type(
189
203
adapter .format_input (chat_prompt )
190
204
191
205
206
+ def test_transformers_multimodal_type_adapter_format_asset_for_template (adapter , image , video , audio ):
207
+ # Test Image asset
208
+ image_asset = Image (image )
209
+ formatted_image = adapter ._format_asset_for_template (image_asset )
210
+ assert formatted_image == {"type" : "image" , "image" : image_asset }
211
+
212
+ # Test Video asset
213
+ video_asset = Video (video )
214
+ formatted_video = adapter ._format_asset_for_template (video_asset )
215
+ assert formatted_video == {"type" : "video" , "video" : video_asset }
216
+
217
+ # Test Audio asset
218
+ audio_asset = Audio (audio )
219
+ formatted_audio = adapter ._format_asset_for_template (audio_asset )
220
+ assert formatted_audio == {"type" : "audio" , "audio" : audio_asset }
221
+
222
+
192
223
def test_transformers_multimodal_type_adapter_format_asset_for_template_invalid_type (adapter ):
193
224
class MockUnsupportedAsset :
194
225
pass
0 commit comments