Skip to content

Commit 7ca81c7

Browse files
author
pytorchbot
committed
2025-09-09 nightly release (ee77f57)
1 parent 0f840e5 commit 7ca81c7

File tree

4 files changed

+28
-6
lines changed

4 files changed

+28
-6
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
247247
Returns:
248248
FrameBatch: The frames at the given indices.
249249
"""
250+
if isinstance(indices, torch.Tensor):
251+
# TODO we should avoid converting tensors to lists and just let the
252+
# core ops and C++ code natively accept tensors. See
253+
# https://github.com/pytorch/torchcodec/issues/879
254+
indices = indices.to(torch.int).tolist()
255+
250256
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
251257
self._decoder, frame_indices=indices
252258
)
@@ -322,6 +328,12 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
322328
Returns:
323329
FrameBatch: The frames that are played at ``seconds``.
324330
"""
331+
if isinstance(seconds, torch.Tensor):
332+
# TODO we should avoid converting tensors to lists and just let the
333+
# core ops and C++ code natively accept tensors. See
334+
# https://github.com/pytorch/torchcodec/issues/879
335+
seconds = seconds.to(torch.float).tolist()
336+
325337
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
326338
self._decoder, timestamps=seconds
327339
)

test/resources/nasa_13013.mp4.stream3.frame000289.pt

Lines changed: 0 additions & 1 deletion
This file was deleted.

test/test_decoders.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode):
180180
ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device)
181181
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device)
182182
ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180).to(device)
183-
ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289).to(device)
183+
ref_frame_last = NASA_VIDEO.get_frame_data_by_index(389).to(device)
184184

185185
assert_frames_equal(ref_frame0, decoder[0])
186186
assert_frames_equal(ref_frame1, decoder[1])
@@ -193,7 +193,7 @@ def test_getitem_numpy_int(self):
193193
ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
194194
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1)
195195
ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180)
196-
ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289)
196+
ref_frame_last = NASA_VIDEO.get_frame_data_by_index(389)
197197

198198
# test against numpy.int64
199199
assert_frames_equal(ref_frame0, decoder[numpy.int64(0)])
@@ -404,7 +404,7 @@ def test_iteration(self, device, seek_mode):
404404
ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device)
405405
ref_frame35 = NASA_VIDEO.get_frame_data_by_index(35).to(device)
406406
ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180).to(device)
407-
ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289).to(device)
407+
ref_frame_last = NASA_VIDEO.get_frame_data_by_index(389).to(device)
408408

409409
# Access an arbitrary frame to make sure that the later iteration
410410
# still works as expected. The underlying C++ decoder object is
@@ -1390,6 +1390,17 @@ def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
13901390
custom_frame_mappings=custom_frame_mappings,
13911391
)
13921392

1393+
def test_get_frames_at_tensor_indices(self):
1394+
# Non-regression test for tensor support in get_frames_at() and
1395+
# get_frames_played_at()
1396+
decoder = VideoDecoder(NASA_VIDEO.path)
1397+
1398+
decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.int))
1399+
decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.float))
1400+
1401+
decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.int))
1402+
decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.float))
1403+
13931404

13941405
class TestAudioDecoder:
13951406
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

test/test_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def test_throws_exception_at_eof(self, device):
337337

338338
seek_to_pts(decoder, 12.979633)
339339
last_frame, _, _ = get_next_frame(decoder)
340-
reference_last_frame = NASA_VIDEO.get_frame_data_by_index(289)
340+
reference_last_frame = NASA_VIDEO.get_frame_data_by_index(389)
341341
assert_frames_equal(last_frame, reference_last_frame.to(device))
342342
with pytest.raises(IndexError, match="no more frames"):
343343
get_next_frame(decoder)
@@ -1059,7 +1059,7 @@ def seek(self, offset: int, whence: int) -> int:
10591059
seek_to_pts(decoder, 12.979633)
10601060

10611061
frame_last, *_ = get_next_frame(decoder)
1062-
reference_frame_last = NASA_VIDEO.get_frame_data_by_index(289)
1062+
reference_frame_last = NASA_VIDEO.get_frame_data_by_index(389)
10631063
assert_frames_equal(frame_last, reference_frame_last.to(device))
10641064

10651065
assert file_counter.num_seeks > initialization_seeks

0 commit comments

Comments
 (0)