Skip to content

Commit b3e2e2c

Browse files
authored
Added proper tensor support for get_frames_at()
Differential Revision: D83506846 Pull Request resolved: #915
1 parent 37e361a commit b3e2e2c

File tree

9 files changed

+53
-31
lines changed

9 files changed

+53
-31
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -602,25 +602,34 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
602602
}
603603

604604
FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
605-
const std::vector<int64_t>& frameIndices) {
605+
const torch::Tensor& frameIndices) {
606606
validateActiveStream(AVMEDIA_TYPE_VIDEO);
607607

608-
auto indicesAreSorted =
609-
std::is_sorted(frameIndices.begin(), frameIndices.end());
608+
auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
609+
610+
bool indicesAreSorted = true;
611+
for (int64_t i = 1; i < frameIndices.numel(); ++i) {
612+
if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) {
613+
indicesAreSorted = false;
614+
break;
615+
}
616+
}
610617

611618
std::vector<size_t> argsort;
612619
if (!indicesAreSorted) {
613620
// if frameIndices is [13, 10, 12, 11]
614621
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
615622
// to use to decode the frames
616623
// and argsort is [ 1, 3, 2, 0]
617-
argsort.resize(frameIndices.size());
624+
argsort.resize(frameIndices.numel());
618625
for (size_t i = 0; i < argsort.size(); ++i) {
619626
argsort[i] = i;
620627
}
621628
std::sort(
622-
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
623-
return frameIndices[a] < frameIndices[b];
629+
argsort.begin(),
630+
argsort.end(),
631+
[&frameIndicesAccessor](size_t a, size_t b) {
632+
return frameIndicesAccessor[a] < frameIndicesAccessor[b];
624633
});
625634
}
626635

@@ -629,12 +638,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
629638
const auto& streamInfo = streamInfos_[activeStreamIndex_];
630639
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
631640
FrameBatchOutput frameBatchOutput(
632-
frameIndices.size(), videoStreamOptions, streamMetadata);
641+
frameIndices.numel(), videoStreamOptions, streamMetadata);
633642

634643
auto previousIndexInVideo = -1;
635-
for (size_t f = 0; f < frameIndices.size(); ++f) {
644+
for (int64_t f = 0; f < frameIndices.numel(); ++f) {
636645
auto indexInOutput = indicesAreSorted ? f : argsort[f];
637-
auto indexInVideo = frameIndices[indexInOutput];
646+
auto indexInVideo = frameIndicesAccessor[indexInOutput];
638647

639648
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
640649
// Avoid decoding the same frame twice
@@ -776,7 +785,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
776785
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
777786
}
778787

779-
return getFramesAtIndices(frameIndices);
788+
// TODO: Support tensors natively instead of a vector to avoid a copy.
789+
return getFramesAtIndices(torch::tensor(frameIndices));
780790
}
781791

782792
FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class SingleStreamDecoder {
106106

107107
// Returns frames at the given indices for a given stream as a single stacked
108108
// Tensor.
109-
FrameBatchOutput getFramesAtIndices(const std::vector<int64_t>& frameIndices);
109+
FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices);
110110

111111
// Returns frames within a given range. The range is defined by [start, stop).
112112
// The values retrieved from the range are: [start, start+step,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
5555
m.def(
5656
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
5757
m.def(
58-
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
58+
"get_frames_at_indices(Tensor(a!) decoder, *, Tensor frame_indices) -> (Tensor, Tensor, Tensor)");
5959
m.def(
6060
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
6161
m.def(
@@ -378,11 +378,9 @@ OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
378378
// Return the frames at given indices for a given stream
379379
OpsFrameBatchOutput get_frames_at_indices(
380380
at::Tensor& decoder,
381-
at::IntArrayRef frame_indices) {
381+
const at::Tensor& frame_indices) {
382382
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
383-
std::vector<int64_t> frameIndicesVec(
384-
frame_indices.begin(), frame_indices.end());
385-
auto result = videoDecoder->getFramesAtIndices(frameIndicesVec);
383+
auto result = videoDecoder->getFramesAtIndices(frame_indices);
386384
return makeOpsFrameBatchOutput(result);
387385
}
388386

src/torchcodec/_core/ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def load_torchcodec_shared_libraries():
114114
get_next_frame = torch.ops.torchcodec_ns.get_next_frame.default
115115
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
116116
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
117-
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
117+
_get_frames_at_indices_tensor_input = (
118+
torch.ops.torchcodec_ns.get_frames_at_indices.default
119+
)
118120
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
119121
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
120122
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
@@ -198,6 +200,18 @@ def encode_audio_to_file_like(
198200
)
199201

200202

203+
def get_frames_at_indices(
204+
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
205+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
206+
if isinstance(frame_indices, torch.Tensor):
207+
# Ensure indices is the correct dtype (int64)
208+
frame_indices = frame_indices.to(torch.int64)
209+
else:
210+
# Convert list to tensor for dispatch
211+
frame_indices = torch.tensor(frame_indices)
212+
return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices)
213+
214+
201215
# ==============================
202216
# Abstract impl for the operators. Needed by torch.compile.
203217
# ==============================
@@ -371,9 +385,7 @@ def get_frame_at_index_abstract(
371385

372386
@register_fake("torchcodec_ns::get_frames_at_indices")
373387
def get_frames_at_indices_abstract(
374-
decoder: torch.Tensor,
375-
*,
376-
frame_indices: List[int],
388+
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, List[int]]
377389
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
378390
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
379391
return (

src/torchcodec/_samplers/video_clip_sampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ def _get_clips_for_index_based_sampling(
227227
clip_start_idx + i * index_based_sampler_args.video_frame_dilation
228228
for i in range(index_based_sampler_args.frames_per_clip)
229229
]
230+
# Need torch.stack to convert List[Tensor[int]] into 1D Tensor[int]
231+
batch_indexes = torch.stack(batch_indexes)
230232
frames, *_ = get_frames_at_indices(
231233
video_decoder,
232234
frame_indices=batch_indexes,

src/torchcodec/decoders/_video_decoder.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,24 +240,20 @@ def get_frame_at(self, index: int) -> Frame:
240240
duration_seconds=duration_seconds.item(),
241241
)
242242

243-
def get_frames_at(self, indices: list[int]) -> FrameBatch:
243+
def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
244244
"""Return frames at the given indices.
245245
246246
Args:
247-
indices (list of int): The indices of the frames to retrieve.
247+
indices (torch.Tensor or list of int): The indices of the frames to retrieve.
248248
249249
Returns:
250250
FrameBatch: The frames at the given indices.
251251
"""
252-
if isinstance(indices, torch.Tensor):
253-
# TODO we should avoid converting tensors to lists and just let the
254-
# core ops and C++ code natively accept tensors. See
255-
# https://github.com/pytorch/torchcodec/issues/879
256-
indices = indices.to(torch.int).tolist()
257252

258253
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
259254
self._decoder, frame_indices=indices
260255
)
256+
261257
return FrameBatch(
262258
data=data,
263259
pts_seconds=pts_seconds,

test/VideoDecoderTest.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNCHW) {
222222
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
223223
ourDecoder->addVideoStream(bestVideoStreamIndex);
224224
// Frame with index 180 corresponds to timestamp 6.006.
225-
auto output = ourDecoder->getFramesAtIndices({0, 180});
225+
auto frameIndices = torch::tensor({0, 180});
226+
auto output = ourDecoder->getFramesAtIndices(frameIndices);
226227
auto tensor = output.data;
227228
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 3, 270, 480}));
228229

@@ -246,7 +247,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) {
246247
videoStreamOptions.dimensionOrder = "NHWC";
247248
ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions);
248249
// Frame with index 180 corresponds to timestamp 6.006.
249-
auto output = ourDecoder->getFramesAtIndices({0, 180});
250+
auto frameIndices = torch::tensor({0, 180});
251+
auto output = ourDecoder->getFramesAtIndices(frameIndices);
250252
auto tensor = output.data;
251253
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 270, 480, 3}));
252254

test/test_decoders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,9 @@ def test_get_frames_at_fails(self, device, seek_mode):
569569
with pytest.raises(IndexError, match="Invalid frame index=390"):
570570
decoder.get_frames_at([390])
571571

572-
with pytest.raises(RuntimeError, match="Expected a value of type"):
572+
with pytest.raises(
573+
RuntimeError, match="expected scalar type Long but found Float"
574+
):
573575
decoder.get_frames_at([0.3])
574576

575577
@pytest.mark.parametrize("device", all_supported_devices())

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@ def seek(self, offset: int, whence: int) -> int:
12091209
torch.manual_seed(0)
12101210
indices = torch.randint(
12111211
0, len(NASA_VIDEO.frames[NASA_VIDEO.default_stream_index]), size=(50,)
1212-
).tolist()
1212+
)
12131213

12141214
frames_file_like, *_ = get_frames_at_indices(
12151215
decoder_file_like, frame_indices=indices

0 commit comments

Comments
 (0)