@@ -180,7 +180,7 @@ def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode):
180
180
ref_frame0 = NASA_VIDEO .get_frame_data_by_index (0 ).to (device )
181
181
ref_frame1 = NASA_VIDEO .get_frame_data_by_index (1 ).to (device )
182
182
ref_frame180 = NASA_VIDEO .get_frame_data_by_index (180 ).to (device )
183
- ref_frame_last = NASA_VIDEO .get_frame_data_by_index (289 ).to (device )
183
+ ref_frame_last = NASA_VIDEO .get_frame_data_by_index (389 ).to (device )
184
184
185
185
assert_frames_equal (ref_frame0 , decoder [0 ])
186
186
assert_frames_equal (ref_frame1 , decoder [1 ])
@@ -193,7 +193,7 @@ def test_getitem_numpy_int(self):
193
193
ref_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
194
194
ref_frame1 = NASA_VIDEO .get_frame_data_by_index (1 )
195
195
ref_frame180 = NASA_VIDEO .get_frame_data_by_index (180 )
196
- ref_frame_last = NASA_VIDEO .get_frame_data_by_index (289 )
196
+ ref_frame_last = NASA_VIDEO .get_frame_data_by_index (389 )
197
197
198
198
# test against numpy.int64
199
199
assert_frames_equal (ref_frame0 , decoder [numpy .int64 (0 )])
@@ -404,7 +404,7 @@ def test_iteration(self, device, seek_mode):
404
404
ref_frame9 = NASA_VIDEO .get_frame_data_by_index (9 ).to (device )
405
405
ref_frame35 = NASA_VIDEO .get_frame_data_by_index (35 ).to (device )
406
406
ref_frame180 = NASA_VIDEO .get_frame_data_by_index (180 ).to (device )
407
- ref_frame_last = NASA_VIDEO .get_frame_data_by_index (289 ).to (device )
407
+ ref_frame_last = NASA_VIDEO .get_frame_data_by_index (389 ).to (device )
408
408
409
409
# Access an arbitrary frame to make sure that the later iteration
410
410
# still works as expected. The underlying C++ decoder object is
@@ -1390,6 +1390,17 @@ def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
1390
1390
custom_frame_mappings = custom_frame_mappings ,
1391
1391
)
1392
1392
1393
+ def test_get_frames_at_tensor_indices (self ):
1394
+ # Non-regression test for tensor support in get_frames_at() and
1395
+ # get_frames_played_at()
1396
+ decoder = VideoDecoder (NASA_VIDEO .path )
1397
+
1398
+ decoder .get_frames_at (torch .tensor ([0 , 10 ], dtype = torch .int ))
1399
+ decoder .get_frames_at (torch .tensor ([0 , 10 ], dtype = torch .float ))
1400
+
1401
+ decoder .get_frames_played_at (torch .tensor ([0 , 1 ], dtype = torch .int ))
1402
+ decoder .get_frames_played_at (torch .tensor ([0 , 1 ], dtype = torch .float ))
1403
+
1393
1404
1394
1405
class TestAudioDecoder :
1395
1406
@pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 , SINE_MONO_S32 ))
0 commit comments