@@ -602,25 +602,34 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
602
602
}
603
603
604
604
FrameBatchOutput SingleStreamDecoder::getFramesAtIndices (
605
- const std::vector< int64_t > & frameIndices) {
605
+ const torch::Tensor & frameIndices) {
606
606
validateActiveStream (AVMEDIA_TYPE_VIDEO);
607
607
608
- auto indicesAreSorted =
609
- std::is_sorted (frameIndices.begin (), frameIndices.end ());
608
+ auto frameIndicesAccessor = frameIndices.accessor <int64_t , 1 >();
609
+
610
+ bool indicesAreSorted = true ;
611
+ for (int64_t i = 1 ; i < frameIndices.numel (); ++i) {
612
+ if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1 ]) {
613
+ indicesAreSorted = false ;
614
+ break ;
615
+ }
616
+ }
610
617
611
618
std::vector<size_t > argsort;
612
619
if (!indicesAreSorted) {
613
620
// if frameIndices is [13, 10, 12, 11]
614
621
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
615
622
// to use to decode the frames
616
623
// and argsort is [ 1, 3, 2, 0]
617
- argsort.resize (frameIndices.size ());
624
+ argsort.resize (frameIndices.numel ());
618
625
for (size_t i = 0 ; i < argsort.size (); ++i) {
619
626
argsort[i] = i;
620
627
}
621
628
std::sort (
622
- argsort.begin (), argsort.end (), [&frameIndices](size_t a, size_t b) {
623
- return frameIndices[a] < frameIndices[b];
629
+ argsort.begin (),
630
+ argsort.end (),
631
+ [&frameIndicesAccessor](size_t a, size_t b) {
632
+ return frameIndicesAccessor[a] < frameIndicesAccessor[b];
624
633
});
625
634
}
626
635
@@ -629,12 +638,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
629
638
const auto & streamInfo = streamInfos_[activeStreamIndex_];
630
639
const auto & videoStreamOptions = streamInfo.videoStreamOptions ;
631
640
FrameBatchOutput frameBatchOutput (
632
- frameIndices.size (), videoStreamOptions, streamMetadata);
641
+ frameIndices.numel (), videoStreamOptions, streamMetadata);
633
642
634
643
auto previousIndexInVideo = -1 ;
635
- for (size_t f = 0 ; f < frameIndices.size (); ++f) {
644
+ for (int64_t f = 0 ; f < frameIndices.numel (); ++f) {
636
645
auto indexInOutput = indicesAreSorted ? f : argsort[f];
637
- auto indexInVideo = frameIndices [indexInOutput];
646
+ auto indexInVideo = frameIndicesAccessor [indexInOutput];
638
647
639
648
if ((f > 0 ) && (indexInVideo == previousIndexInVideo)) {
640
649
// Avoid decoding the same frame twice
@@ -776,7 +785,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
776
785
frameIndices[i] = secondsToIndexLowerBound (frameSeconds);
777
786
}
778
787
779
- return getFramesAtIndices (frameIndices);
788
+ // TODO: Support tensors natively instead of a vector to avoid a copy.
789
+ return getFramesAtIndices (torch::tensor (frameIndices));
780
790
}
781
791
782
792
FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange (
0 commit comments