Skip to content

Commit ac7b387

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into nvdec_receive_send
2 parents f58816a + c2e202d commit ac7b387

File tree

8 files changed

+103
-41
lines changed

8 files changed

+103
-41
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ static bool g_cpu = registerDeviceInterface(
1515

1616
} // namespace
1717

18+
CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
19+
int inputWidth,
20+
int inputHeight,
21+
AVPixelFormat inputFormat,
22+
int outputWidth,
23+
int outputHeight)
24+
: inputWidth(inputWidth),
25+
inputHeight(inputHeight),
26+
inputFormat(inputFormat),
27+
outputWidth(outputWidth),
28+
outputHeight(outputHeight) {}
29+
1830
bool CpuDeviceInterface::SwsFrameContext::operator==(
1931
const CpuDeviceInterface::SwsFrameContext& other) const {
2032
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
@@ -97,13 +109,12 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
97109
// And we sometimes re-create them because it's possible for frame
98110
// resolution to change mid-stream. Finally, we want to reuse the colorspace
99111
// conversion objects as much as possible for performance reasons.
100-
SwsFrameContext swsFrameContext;
101-
102-
swsFrameContext.inputWidth = avFrame->width;
103-
swsFrameContext.inputHeight = avFrame->height;
104-
swsFrameContext.inputFormat = frameFormat;
105-
swsFrameContext.outputWidth = expectedOutputWidth;
106-
swsFrameContext.outputHeight = expectedOutputHeight;
112+
SwsFrameContext swsFrameContext(
113+
avFrame->width,
114+
avFrame->height,
115+
frameFormat,
116+
expectedOutputWidth,
117+
expectedOutputHeight);
107118

108119
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
109120
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
@@ -128,22 +139,20 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
128139
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
129140
// See comment above in swscale branch about the filterGraphContext_
130141
// creation.
131-
FiltersContext filtersContext;
132-
133-
filtersContext.inputWidth = avFrame->width;
134-
filtersContext.inputHeight = avFrame->height;
135-
filtersContext.inputFormat = frameFormat;
136-
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
137-
filtersContext.outputWidth = expectedOutputWidth;
138-
filtersContext.outputHeight = expectedOutputHeight;
139-
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
140-
filtersContext.timeBase = timeBase;
141-
142142
std::stringstream filters;
143143
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
144144
filters << ":sws_flags=bilinear";
145145

146-
filtersContext.filtergraphStr = filters.str();
146+
FiltersContext filtersContext(
147+
avFrame->width,
148+
avFrame->height,
149+
frameFormat,
150+
avFrame->sample_aspect_ratio,
151+
expectedOutputWidth,
152+
expectedOutputHeight,
153+
AV_PIX_FMT_RGB24,
154+
filters.str(),
155+
timeBase);
147156

148157
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
149158
filterGraphContext_ =

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,19 @@ class CpuDeviceInterface : public DeviceInterface {
4343
const UniqueAVFrame& avFrame);
4444

4545
struct SwsFrameContext {
46-
int inputWidth;
47-
int inputHeight;
48-
AVPixelFormat inputFormat;
49-
int outputWidth;
50-
int outputHeight;
46+
int inputWidth = 0;
47+
int inputHeight = 0;
48+
AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
49+
int outputWidth = 0;
50+
int outputHeight = 0;
51+
52+
SwsFrameContext() = default;
53+
SwsFrameContext(
54+
int inputWidth,
55+
int inputHeight,
56+
AVPixelFormat inputFormat,
57+
int outputWidth,
58+
int outputHeight);
5159
bool operator==(const SwsFrameContext&) const;
5260
bool operator!=(const SwsFrameContext&) const;
5361
};

src/torchcodec/_core/DeviceInterface.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,6 @@ struct DeviceInterfaceKey {
3636
: deviceType(type), variant(var) {}
3737
};
3838

39-
// Note that all these device functions should only be called if the device is
40-
// not a CPU device. CPU device functions are already implemented in the
41-
// SingleStreamDecoder implementation.
42-
// These functions should only be called from within an if block like this:
43-
// if (device.type() != torch::kCPU) {
44-
// deviceFunction(device, ...);
45-
// }
46-
4739
class DeviceInterface {
4840
public:
4941
DeviceInterface(const torch::Device& device) : device_(device) {}

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,26 @@ extern "C" {
1313

1414
namespace facebook::torchcodec {
1515

16+
FiltersContext::FiltersContext(
17+
int inputWidth,
18+
int inputHeight,
19+
AVPixelFormat inputFormat,
20+
AVRational inputAspectRatio,
21+
int outputWidth,
22+
int outputHeight,
23+
AVPixelFormat outputFormat,
24+
const std::string& filtergraphStr,
25+
AVRational timeBase)
26+
: inputWidth(inputWidth),
27+
inputHeight(inputHeight),
28+
inputFormat(inputFormat),
29+
inputAspectRatio(inputAspectRatio),
30+
outputWidth(outputWidth),
31+
outputHeight(outputHeight),
32+
outputFormat(outputFormat),
33+
filtergraphStr(filtergraphStr),
34+
timeBase(timeBase) {}
35+
1636
bool operator==(const AVRational& lhs, const AVRational& rhs) {
1737
return lhs.num == rhs.num && lhs.den == rhs.den;
1838
}

src/torchcodec/_core/FilterGraph.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,24 @@ struct FiltersContext {
1919
int outputWidth = 0;
2020
int outputHeight = 0;
2121
AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
22-
2322
std::string filtergraphStr;
2423
AVRational timeBase = {0, 0};
2524
UniqueAVBufferRef hwFramesCtx;
2625

26+
FiltersContext() = default;
27+
FiltersContext(FiltersContext&&) = default;
28+
FiltersContext& operator=(FiltersContext&&) = default;
29+
FiltersContext(
30+
int inputWidth,
31+
int inputHeight,
32+
AVPixelFormat inputFormat,
33+
AVRational inputAspectRatio,
34+
int outputWidth,
35+
int outputHeight,
36+
AVPixelFormat outputFormat,
37+
const std::string& filtergraphStr,
38+
AVRational timeBase);
39+
2740
bool operator==(const FiltersContext&) const;
2841
bool operator!=(const FiltersContext&) const;
2942
};

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,6 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
244244
return;
245245
}
246246

247-
for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) {
248-
// We want to scan and update the metadata of all streams.
249-
TORCH_CHECK(
250-
formatContext_->streams[i]->discard != AVDISCARD_ALL,
251-
"Did you add a stream before you called for a scan?");
252-
}
253-
254247
AutoAVPacket autoAVPacket;
255248
while (true) {
256249
ReferenceAVPacket packet(autoAVPacket);
@@ -1439,7 +1432,11 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
14391432
formatContext_->streams[activeStreamIndex_]->time_base);
14401433
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
14411434
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
1442-
} else if (deviceInterface_) {
1435+
} else {
1436+
TORCH_CHECK(
1437+
deviceInterface_ != nullptr,
1438+
"No device interface available for video decoding. This ",
1439+
"shouldn't happen, please report.");
14431440
deviceInterface_->convertAVFrameToFrameOutput(
14441441
streamInfo.videoStreamOptions,
14451442
streamInfo.timeBase,

src/torchcodec/decoders/_video_decoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
247247
Returns:
248248
FrameBatch: The frames at the given indices.
249249
"""
250+
if isinstance(indices, torch.Tensor):
251+
# TODO we should avoid converting tensors to lists and just let the
252+
# core ops and C++ code natively accept tensors. See
253+
# https://github.com/pytorch/torchcodec/issues/879
254+
indices = indices.to(torch.int).tolist()
255+
250256
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
251257
self._decoder, frame_indices=indices
252258
)
@@ -322,6 +328,12 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
322328
Returns:
323329
FrameBatch: The frames that are played at ``seconds``.
324330
"""
331+
if isinstance(seconds, torch.Tensor):
332+
# TODO we should avoid converting tensors to lists and just let the
333+
# core ops and C++ code natively accept tensors. See
334+
# https://github.com/pytorch/torchcodec/issues/879
335+
seconds = seconds.to(torch.float).tolist()
336+
325337
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
326338
self._decoder, timestamps=seconds
327339
)

test/test_decoders.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,17 @@ def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
14071407
custom_frame_mappings=custom_frame_mappings,
14081408
)
14091409

1410+
def test_get_frames_at_tensor_indices(self):
1411+
# Non-regression test for tensor support in get_frames_at() and
1412+
# get_frames_played_at()
1413+
decoder = VideoDecoder(NASA_VIDEO.path)
1414+
1415+
decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.int))
1416+
decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.float))
1417+
1418+
decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.int))
1419+
decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.float))
1420+
14101421

14111422
class TestAudioDecoder:
14121423
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

0 commit comments

Comments
 (0)