5151 SINE_MONO_S32 ,
5252 SINE_MONO_S32_44100 ,
5353 SINE_MONO_S32_8000 ,
54+ cleanup_device_str ,
5455)
5556
5657torch ._dynamo .config .capture_dynamic_output_shape_ops = True
@@ -63,6 +64,7 @@ class TestVideoDecoderOps:
6364 def test_seek_and_next (self , device ):
6465 decoder = create_from_file (str (NASA_VIDEO .path ))
6566 add_video_stream (decoder , device = device )
67+ device = cleanup_device_str (device )
6668 frame0 , _ , _ = get_next_frame (decoder )
6769 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
6870 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -80,6 +82,7 @@ def test_seek_and_next(self, device):
8082 def test_seek_to_negative_pts (self , device ):
8183 decoder = create_from_file (str (NASA_VIDEO .path ))
8284 add_video_stream (decoder , device = device )
85+ device = cleanup_device_str (device )
8386 frame0 , _ , _ = get_next_frame (decoder )
8487 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
8588 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -92,6 +95,7 @@ def test_seek_to_negative_pts(self, device):
9295 def test_get_frame_at_pts (self , device ):
9396 decoder = create_from_file (str (NASA_VIDEO .path ))
9497 add_video_stream (decoder , device = device )
98+ device = cleanup_device_str (device )
9599 # This frame has pts=6.006 and duration=0.033367, so it should be visible
96100 # at timestamps in the range [6.006, 6.039367) (not including the last timestamp).
97101 frame6 , _ , _ = get_frame_at_pts (decoder , 6.006 )
@@ -116,6 +120,7 @@ def test_get_frame_at_pts(self, device):
116120 def test_get_frame_at_index (self , device ):
117121 decoder = create_from_file (str (NASA_VIDEO .path ))
118122 add_video_stream (decoder , device = device )
123+ device = cleanup_device_str (device )
119124 frame0 , _ , _ = get_frame_at_index (decoder , frame_index = 0 )
120125 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
121126 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -134,6 +139,7 @@ def test_get_frame_at_index(self, device):
134139 def test_get_frame_with_info_at_index (self , device ):
135140 decoder = create_from_file (str (NASA_VIDEO .path ))
136141 add_video_stream (decoder , device = device )
142+ device = cleanup_device_str (device )
137143 frame6 , pts , duration = get_frame_at_index (decoder , frame_index = 180 )
138144 reference_frame6 = NASA_VIDEO .get_frame_data_by_index (
139145 INDEX_OF_FRAME_AT_6_SECONDS
@@ -146,6 +152,7 @@ def test_get_frame_with_info_at_index(self, device):
146152 def test_get_frames_at_indices (self , device ):
147153 decoder = create_from_file (str (NASA_VIDEO .path ))
148154 add_video_stream (decoder , device = device )
155+ device = cleanup_device_str (device )
149156 frames0and180 , * _ = get_frames_at_indices (decoder , frame_indices = [0 , 180 ])
150157 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
151158 reference_frame180 = NASA_VIDEO .get_frame_data_by_index (
@@ -158,6 +165,7 @@ def test_get_frames_at_indices(self, device):
158165 def test_get_frames_at_indices_unsorted_indices (self , device ):
159166 decoder = create_from_file (str (NASA_VIDEO .path ))
160167 _add_video_stream (decoder , device = device )
168+ device = cleanup_device_str (device )
161169
162170 frame_indices = [2 , 0 , 1 , 0 , 2 ]
163171
@@ -185,6 +193,7 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
185193 def test_get_frames_at_indices_negative_indices (self , device ):
186194 decoder = create_from_file (str (NASA_VIDEO .path ))
187195 add_video_stream (decoder , device = device )
196+ device = cleanup_device_str (device )
188197 frames389and387and1 , * _ = get_frames_at_indices (
189198 decoder , frame_indices = [- 1 , - 3 , - 389 ]
190199 )
@@ -199,6 +208,7 @@ def test_get_frames_at_indices_negative_indices(self, device):
199208 def test_get_frames_at_indices_fail_on_invalid_negative_indices (self , device ):
200209 decoder = create_from_file (str (NASA_VIDEO .path ))
201210 add_video_stream (decoder , device = device )
211+ device = cleanup_device_str (device )
202212 with pytest .raises (
203213 IndexError ,
204214 match = "negative indices must have an absolute value less than the number of frames" ,
@@ -211,6 +221,7 @@ def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
211221 def test_get_frames_by_pts (self , device ):
212222 decoder = create_from_file (str (NASA_VIDEO .path ))
213223 _add_video_stream (decoder , device = device )
224+ device = cleanup_device_str (device )
214225
215226 # Note: 13.01 should give the last video frame for the NASA video
216227 timestamps = [2 , 0 , 1 , 0 + 1e-3 , 13.01 , 2 + 1e-3 ]
@@ -243,6 +254,7 @@ def test_pts_apis_against_index_ref(self, device):
243254 # we get the expected frame.
244255 decoder = create_from_file (str (NASA_VIDEO .path ))
245256 add_video_stream (decoder , device = device )
257+ device = cleanup_device_str (device )
246258
247259 metadata = get_json_metadata (decoder )
248260 metadata_dict = json .loads (metadata )
@@ -294,6 +306,7 @@ def test_pts_apis_against_index_ref(self, device):
294306 def test_get_frames_in_range (self , device ):
295307 decoder = create_from_file (str (NASA_VIDEO .path ))
296308 add_video_stream (decoder , device = device )
309+ device = cleanup_device_str (device )
297310
298311 # ensure that the degenerate case of a range of size 1 works
299312 ref_frame0 = NASA_VIDEO .get_frame_data_by_range (0 , 1 )
@@ -334,6 +347,7 @@ def test_get_frames_in_range(self, device):
334347 def test_throws_exception_at_eof (self , device ):
335348 decoder = create_from_file (str (NASA_VIDEO .path ))
336349 add_video_stream (decoder , device = device )
350+ device = cleanup_device_str (device )
337351
338352 seek_to_pts (decoder , 12.979633 )
339353 last_frame , _ , _ = get_next_frame (decoder )
@@ -362,6 +376,7 @@ def test_compile_seek_and_next(self, device):
362376 @torch .compile (fullgraph = True , backend = "eager" )
363377 def get_frame1_and_frame_time6 (decoder ):
364378 add_video_stream (decoder , device = device )
379+ device = cleanup_device_str (device )
365380 frame0 , _ , _ = get_next_frame (decoder )
366381 seek_to_pts (decoder , 6.0 )
367382 frame_time6 , _ , _ = get_next_frame (decoder )
@@ -405,6 +420,7 @@ def test_create_decoder(self, create_from, device):
405420 raise ValueError ("Oops, double check the parametrization of this test!" )
406421
407422 add_video_stream (decoder , device = device )
423+ device = cleanup_device_str (device )
408424 frame0 , _ , _ = get_next_frame (decoder )
409425 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
410426 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -510,6 +526,7 @@ def test_seek_mode_custom_frame_mappings(self, device):
510526 decoder = create_from_file (
511527 str (NASA_VIDEO .path ), seek_mode = "custom_frame_mappings"
512528 )
529+ device = cleanup_device_str (device )
513530 add_video_stream (
514531 decoder ,
515532 device = device ,
@@ -1042,6 +1059,7 @@ def seek(self, offset: int, whence: int) -> int:
10421059 )
10431060 decoder = create_from_file_like (file_counter , "approximate" )
10441061 add_video_stream (decoder , device = device )
1062+ device = cleanup_device_str (device )
10451063
10461064 frame0 , * _ = get_next_frame (decoder )
10471065 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
0 commit comments