Skip to content

Commit 74b4583

Browse files
authored
[Release 0.7 Cherry-Pick] Let get_frames_at and get_frames_played_at accept tensor indices … (#882)
1 parent a47f5ed commit 74b4583

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
226226
Returns:
227227
FrameBatch: The frames at the given indices.
228228
"""
229+
if isinstance(indices, torch.Tensor):
230+
# TODO we should avoid converting tensors to lists and just let the
231+
# core ops and C++ code natively accept tensors. See
232+
# https://github.com/pytorch/torchcodec/issues/879
233+
indices = indices.to(torch.int).tolist()
234+
229235
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
230236
self._decoder, frame_indices=indices
231237
)
@@ -301,6 +307,12 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
301307
Returns:
302308
FrameBatch: The frames that are played at ``seconds``.
303309
"""
310+
if isinstance(seconds, torch.Tensor):
311+
# TODO we should avoid converting tensors to lists and just let the
312+
# core ops and C++ code natively accept tensors. See
313+
# https://github.com/pytorch/torchcodec/issues/879
314+
seconds = seconds.to(torch.float).tolist()
315+
304316
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
305317
self._decoder, timestamps=seconds
306318
)

test/test_decoders.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,17 @@ def test_10bit_videos_cpu(self, asset):
13891389
# custom_frame_mappings=custom_frame_mappings,
13901390
# )
13911391

1392+
def test_get_frames_at_tensor_indices(self):
    """Non-regression test for tensor inputs to the batch frame getters.

    ``get_frames_at()`` and ``get_frames_played_at()`` must accept a
    ``torch.Tensor`` of either integer or floating dtype, and must return
    the same frames as the equivalent plain Python list.
    """
    decoder = VideoDecoder(NASA_VIDEO.path)

    def assert_same_batch(actual, expected):
        # The same frames are requested both ways, so the batches are
        # expected to match exactly (no tolerance).
        for field in ("data", "pts_seconds", "duration_seconds"):
            torch.testing.assert_close(
                getattr(actual, field),
                getattr(expected, field),
                rtol=0,
                atol=0,
            )

    # Indices: int and float tensors must both resolve to frames 0 and 10.
    expected = decoder.get_frames_at([0, 10])
    for dtype in (torch.int, torch.float):
        actual = decoder.get_frames_at(torch.tensor([0, 10], dtype=dtype))
        assert_same_batch(actual, expected)

    # Timestamps: int and float tensors must both resolve to 0s and 1s.
    expected = decoder.get_frames_played_at([0.0, 1.0])
    for dtype in (torch.int, torch.float):
        actual = decoder.get_frames_played_at(torch.tensor([0, 1], dtype=dtype))
        assert_same_batch(actual, expected)
1402+
13921403

13931404
class TestAudioDecoder:
13941405
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

0 commit comments

Comments
 (0)