55
55
SINE_MONO_S32 ,
56
56
SINE_MONO_S32_44100 ,
57
57
SINE_MONO_S32_8000 ,
58
+ unsplit_device_str ,
58
59
)
59
60
60
61
torch ._dynamo .config .capture_dynamic_output_shape_ops = True
@@ -66,7 +67,8 @@ class TestVideoDecoderOps:
66
67
@pytest .mark .parametrize ("device" , all_supported_devices ())
67
68
def test_seek_and_next (self , device ):
68
69
decoder = create_from_file (str (NASA_VIDEO .path ))
69
- add_video_stream (decoder , device = device )
70
+ device , device_variant = unsplit_device_str (device )
71
+ add_video_stream (decoder , device = device , device_variant = device_variant )
70
72
frame0 , _ , _ = get_next_frame (decoder )
71
73
reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
72
74
assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -83,7 +85,8 @@ def test_seek_and_next(self, device):
83
85
@pytest .mark .parametrize ("device" , all_supported_devices ())
84
86
def test_seek_to_negative_pts (self , device ):
85
87
decoder = create_from_file (str (NASA_VIDEO .path ))
86
- add_video_stream (decoder , device = device )
88
+ device , device_variant = unsplit_device_str (device )
89
+ add_video_stream (decoder , device = device , device_variant = device_variant )
87
90
frame0 , _ , _ = get_next_frame (decoder )
88
91
reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
89
92
assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -95,7 +98,8 @@ def test_seek_to_negative_pts(self, device):
95
98
@pytest .mark .parametrize ("device" , all_supported_devices ())
96
99
def test_get_frame_at_pts (self , device ):
97
100
decoder = create_from_file (str (NASA_VIDEO .path ))
98
- add_video_stream (decoder , device = device )
101
+ device , device_variant = unsplit_device_str (device )
102
+ add_video_stream (decoder , device = device , device_variant = device_variant )
99
103
# This frame has pts=6.006 and duration=0.033367, so it should be visible
100
104
# at timestamps in the range [6.006, 6.039367) (not including the last timestamp).
101
105
frame6 , _ , _ = get_frame_at_pts (decoder , 6.006 )
@@ -119,7 +123,8 @@ def test_get_frame_at_pts(self, device):
119
123
@pytest .mark .parametrize ("device" , all_supported_devices ())
120
124
def test_get_frame_at_index (self , device ):
121
125
decoder = create_from_file (str (NASA_VIDEO .path ))
122
- add_video_stream (decoder , device = device )
126
+ device , device_variant = unsplit_device_str (device )
127
+ add_video_stream (decoder , device = device , device_variant = device_variant )
123
128
frame0 , _ , _ = get_frame_at_index (decoder , frame_index = 0 )
124
129
reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
125
130
assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -137,7 +142,8 @@ def test_get_frame_at_index(self, device):
137
142
@pytest .mark .parametrize ("device" , all_supported_devices ())
138
143
def test_get_frame_with_info_at_index (self , device ):
139
144
decoder = create_from_file (str (NASA_VIDEO .path ))
140
- add_video_stream (decoder , device = device )
145
+ device , device_variant = unsplit_device_str (device )
146
+ add_video_stream (decoder , device = device , device_variant = device_variant )
141
147
frame6 , pts , duration = get_frame_at_index (decoder , frame_index = 180 )
142
148
reference_frame6 = NASA_VIDEO .get_frame_data_by_index (
143
149
INDEX_OF_FRAME_AT_6_SECONDS
@@ -149,7 +155,8 @@ def test_get_frame_with_info_at_index(self, device):
149
155
@pytest .mark .parametrize ("device" , all_supported_devices ())
150
156
def test_get_frames_at_indices (self , device ):
151
157
decoder = create_from_file (str (NASA_VIDEO .path ))
152
- add_video_stream (decoder , device = device )
158
+ device , device_variant = unsplit_device_str (device )
159
+ add_video_stream (decoder , device = device , device_variant = device_variant )
153
160
frames0and180 , * _ = get_frames_at_indices (decoder , frame_indices = [0 , 180 ])
154
161
reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
155
162
reference_frame180 = NASA_VIDEO .get_frame_data_by_index (
@@ -161,7 +168,8 @@ def test_get_frames_at_indices(self, device):
161
168
@pytest .mark .parametrize ("device" , all_supported_devices ())
162
169
def test_get_frames_at_indices_unsorted_indices (self , device ):
163
170
decoder = create_from_file (str (NASA_VIDEO .path ))
164
- _add_video_stream (decoder , device = device )
171
+ device , device_variant = unsplit_device_str (device )
172
+ add_video_stream (decoder , device = device , device_variant = device_variant )
165
173
166
174
frame_indices = [2 , 0 , 1 , 0 , 2 ]
167
175
@@ -188,7 +196,8 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
188
196
@pytest .mark .parametrize ("device" , all_supported_devices ())
189
197
def test_get_frames_at_indices_negative_indices (self , device ):
190
198
decoder = create_from_file (str (NASA_VIDEO .path ))
191
- add_video_stream (decoder , device = device )
199
+ device , device_variant = unsplit_device_str (device )
200
+ add_video_stream (decoder , device = device , device_variant = device_variant )
192
201
frames389and387and1 , * _ = get_frames_at_indices (
193
202
decoder , frame_indices = [- 1 , - 3 , - 389 ]
194
203
)
@@ -202,7 +211,8 @@ def test_get_frames_at_indices_negative_indices(self, device):
202
211
@pytest .mark .parametrize ("device" , all_supported_devices ())
203
212
def test_get_frames_at_indices_fail_on_invalid_negative_indices (self , device ):
204
213
decoder = create_from_file (str (NASA_VIDEO .path ))
205
- add_video_stream (decoder , device = device )
214
+ device , device_variant = unsplit_device_str (device )
215
+ add_video_stream (decoder , device = device , device_variant = device_variant )
206
216
with pytest .raises (
207
217
IndexError ,
208
218
match = "negative indices must have an absolute value less than the number of frames" ,
@@ -214,7 +224,8 @@ def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
214
224
@pytest .mark .parametrize ("device" , all_supported_devices ())
215
225
def test_get_frames_by_pts (self , device ):
216
226
decoder = create_from_file (str (NASA_VIDEO .path ))
217
- _add_video_stream (decoder , device = device )
227
+ device , device_variant = unsplit_device_str (device )
228
+ add_video_stream (decoder , device = device , device_variant = device_variant )
218
229
219
230
# Note: 13.01 should give the last video frame for the NASA video
220
231
timestamps = [2 , 0 , 1 , 0 + 1e-3 , 13.01 , 2 + 1e-3 ]
@@ -246,7 +257,8 @@ def test_pts_apis_against_index_ref(self, device):
246
257
# APIs exactly where those frames are supposed to start. We assert that
247
258
# we get the expected frame.
248
259
decoder = create_from_file (str (NASA_VIDEO .path ))
249
- add_video_stream (decoder , device = device )
260
+ device , device_variant = unsplit_device_str (device )
261
+ add_video_stream (decoder , device = device , device_variant = device_variant )
250
262
251
263
metadata = get_json_metadata (decoder )
252
264
metadata_dict = json .loads (metadata )
@@ -297,7 +309,8 @@ def test_pts_apis_against_index_ref(self, device):
297
309
@pytest .mark .parametrize ("device" , all_supported_devices ())
298
310
def test_get_frames_in_range (self , device ):
299
311
decoder = create_from_file (str (NASA_VIDEO .path ))
300
- add_video_stream (decoder , device = device )
312
+ device , device_variant = unsplit_device_str (device )
313
+ add_video_stream (decoder , device = device , device_variant = device_variant )
301
314
302
315
# ensure that the degenerate case of a range of size 1 works
303
316
ref_frame0 = NASA_VIDEO .get_frame_data_by_range (0 , 1 )
@@ -337,7 +350,8 @@ def test_get_frames_in_range(self, device):
337
350
@pytest .mark .parametrize ("device" , all_supported_devices ())
338
351
def test_throws_exception_at_eof (self , device ):
339
352
decoder = create_from_file (str (NASA_VIDEO .path ))
340
- add_video_stream (decoder , device = device )
353
+ device , device_variant = unsplit_device_str (device )
354
+ add_video_stream (decoder , device = device , device_variant = device_variant )
341
355
342
356
seek_to_pts (decoder , 12.979633 )
343
357
last_frame , _ , _ = get_next_frame (decoder )
@@ -352,7 +366,8 @@ def test_throws_exception_at_eof(self, device):
352
366
@pytest .mark .parametrize ("device" , all_supported_devices ())
353
367
def test_throws_exception_if_seek_too_far (self , device ):
354
368
decoder = create_from_file (str (NASA_VIDEO .path ))
355
- add_video_stream (decoder , device = device )
369
+ device , device_variant = unsplit_device_str (device )
370
+ add_video_stream (decoder , device = device , device_variant = device_variant )
356
371
# pts=12.979633 is the last frame in the video.
357
372
seek_to_pts (decoder , 12.979633 + 1.0e-4 )
358
373
with pytest .raises (IndexError , match = "no more frames" ):
@@ -363,9 +378,11 @@ def test_compile_seek_and_next(self, device):
363
378
# TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. Right now
364
379
# compilation fails because it can't handle tensors of size unknown at
365
380
# compile-time.
381
+ device , device_variant = unsplit_device_str (device )
382
+
366
383
@torch .compile (fullgraph = True , backend = "eager" )
367
384
def get_frame1_and_frame_time6 (decoder ):
368
- add_video_stream (decoder , device = device )
385
+ add_video_stream (decoder , device = device , device_variant = device_variant )
369
386
frame0 , _ , _ = get_next_frame (decoder )
370
387
seek_to_pts (decoder , 6.0 )
371
388
frame_time6 , _ , _ = get_next_frame (decoder )
@@ -408,7 +425,8 @@ def test_create_decoder(self, create_from, device):
408
425
else :
409
426
raise ValueError ("Oops, double check the parametrization of this test!" )
410
427
411
- add_video_stream (decoder , device = device )
428
+ device , device_variant = unsplit_device_str (device )
429
+ add_video_stream (decoder , device = device , device_variant = device_variant )
412
430
frame0 , _ , _ = get_next_frame (decoder )
413
431
reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
414
432
assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -536,9 +554,11 @@ def test_seek_mode_custom_frame_mappings(self, device):
536
554
decoder = create_from_file (
537
555
str (NASA_VIDEO .path ), seek_mode = "custom_frame_mappings"
538
556
)
557
+ device , device_variant = unsplit_device_str (device )
539
558
add_video_stream (
540
559
decoder ,
541
560
device = device ,
561
+ device_variant = device_variant ,
542
562
stream_index = stream_index ,
543
563
custom_frame_mappings = NASA_VIDEO .get_custom_frame_mappings (
544
564
stream_index = stream_index
@@ -1077,7 +1097,8 @@ def seek(self, offset: int, whence: int) -> int:
1077
1097
open (NASA_VIDEO .path , mode = "rb" , buffering = buffering )
1078
1098
)
1079
1099
decoder = create_from_file_like (file_counter , "approximate" )
1080
- add_video_stream (decoder , device = device )
1100
+ device , device_variant = unsplit_device_str (device )
1101
+ add_video_stream (decoder , device = device , device_variant = device_variant )
1081
1102
1082
1103
frame0 , * _ = get_next_frame (decoder )
1083
1104
reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
0 commit comments