Skip to content

Commit dda2649

Browse files
committed
CUDA and CPU refactoring regarding NV12.
1 parent 7f88e60 commit dda2649

File tree

3 files changed

+49
-46
lines changed

3 files changed

+49
-46
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -154,23 +154,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
154154
enum AVPixelFormat frameFormat =
155155
static_cast<enum AVPixelFormat>(avFrame->format);
156156

157-
// This is an early-return optimization: if the format is already what we
158-
// need, and the dimensions are also what we need, we don't need to call
159-
// swscale or filtergraph. We can just convert the AVFrame to a tensor.
160-
if (frameFormat == AV_PIX_FMT_RGB24 && avFrame->width == outputDims_.width &&
161-
avFrame->height == outputDims_.height) {
162-
outputTensor = toTensor(avFrame);
163-
if (preAllocatedOutputTensor.has_value()) {
164-
// We have already validated that preAllocatedOutputTensor and
165-
// outputTensor have the same shape.
166-
preAllocatedOutputTensor.value().copy_(outputTensor);
167-
frameOutput.data = preAllocatedOutputTensor.value();
168-
} else {
169-
frameOutput.data = outputTensor;
170-
}
171-
return;
172-
}
173-
174157
if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) {
175158
// We need to compare the current frame context with our previous frame
176159
// context. If they are different, then we need to re-create our colorspace

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ class CpuDeviceInterface : public DeviceInterface {
3636
std::optional<torch::Tensor> preAllocatedOutputTensor =
3737
std::nullopt) override;
3838

39+
torch::Tensor toTensor(const UniqueAVFrame& avFrame);
40+
3941
private:
4042
int convertAVFrameToTensorUsingSwScale(
4143
const UniqueAVFrame& avFrame,
4244
torch::Tensor& outputTensor);
4345

44-
torch::Tensor toTensor(const UniqueAVFrame& avFrame);
45-
4646
struct SwsFrameContext {
4747
int inputWidth = 0;
4848
int inputHeight = 0;

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <mutex>
66

77
#include "src/torchcodec/_core/Cache.h"
8+
#include "src/torchcodec/_core/CpuDeviceInterface.h"
89
#include "src/torchcodec/_core/CudaDeviceInterface.h"
910
#include "src/torchcodec/_core/FFMPEGCommon.h"
1011

@@ -230,7 +231,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
230231
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
231232
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
232233

233-
// NV12 conversion is implemented directly with NPP, no need for filters.
234+
// If the frame is already in NV12 format, we don't need to do anything.
234235
if (actualFormat == AV_PIX_FMT_NV12) {
235236
return std::move(avFrame);
236237
}
@@ -310,35 +311,64 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
310311
UniqueAVFrame& avFrame,
311312
FrameOutput& frameOutput,
312313
std::optional<torch::Tensor> preAllocatedOutputTensor) {
314+
if (preAllocatedOutputTensor.has_value()) {
315+
auto shape = preAllocatedOutputTensor.value().sizes();
316+
TORCH_CHECK(
317+
(shape.size() == 3) && (shape[0] == outputDims_.height) &&
318+
(shape[1] == outputDims_.width) && (shape[2] == 3),
319+
"Expected tensor of shape ",
320+
outputDims_.height,
321+
"x",
322+
outputDims_.width,
323+
"x3, got ",
324+
shape);
325+
}
326+
327+
// All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by
328+
// converting them to NV12.
313329
avFrame = maybeConvertAVFrameToNV12(avFrame);
314330

315-
// The filtered frame might be on CPU if CPU fallback has happened on filter
316-
// graph level. For example, that's how we handle color format conversion
317-
// on FFmpeg 4.4 where scale_cuda did not have this support implemented yet.
318331
if (avFrame->format != AV_PIX_FMT_CUDA) {
319332
// The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
320-
// the GPU. In this branch, the frame is on the CPU: this is what NVDEC
321-
// gives us if it wasn't able to decode a frame, for whatever reason.
322-
// Typically that happens if the video's encoder isn't supported by NVDEC.
323-
// Below, we choose to convert the frame's color-space using the CPU
324-
// codepath, and send it back to the GPU at the very end.
333+
// the GPU. In this branch, the frame is on the CPU. There are two possible
334+
// reasons:
335+
//
336+
// 1. During maybeConvertAVFrameToNV12(), we had a non-NV12 format frame
337+
// and we're on FFmpeg 4.4 or earlier. In such cases, we had to use CPU
338+
// filters and we just converted the frame to RGB24.
339+
// 2. This is what NVDEC gives us when it isn't able to decode a frame, for
340+
// whatever reason. Typically that happens if the video's encoder isn't
341+
// supported by NVDEC.
325342
//
326-
// TODO: A possibly better solution would be to send the frame to the GPU
327-
// first, and do the color conversion there.
343+
// In both cases, we have a frame on the CPU, and we need a CPU device to
344+
// handle it. We send the frame back to the CUDA device when we're done.
328345
//
329-
// TODO: If we're going to keep this around, we should probably cache it?
330-
auto cpuInterface = createDeviceInterface(torch::Device(torch::kCPU));
346+
// TODO: Perhaps we should cache cpuInterface?
347+
auto cpuInterface = std::make_unique<CpuDeviceInterface>(torch::kCPU);
331348
TORCH_CHECK(
332349
cpuInterface != nullptr, "Failed to create CPU device interface");
333350
cpuInterface->initialize(
334351
nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_);
335352

353+
enum AVPixelFormat frameFormat =
354+
static_cast<enum AVPixelFormat>(avFrame->format);
355+
336356
FrameOutput cpuFrameOutput;
337-
cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
338357

339-
// TODO: explain that the pre-allocated tensor is on the GPU, but we need
340-
// to do the decoding on the CPU, and we can't pass the pre-allocated tensor
341-
// to do it. BUT WHY did it work before?
358+
if (frameFormat == AV_PIX_FMT_RGB24 &&
359+
avFrame->width == outputDims_.width &&
360+
avFrame->height == outputDims_.height) {
361+
// Reason 1 above. The frame is already in the format and dimensions that
362+
// we need, we just need to convert it to a tensor.
363+
cpuFrameOutput.data = cpuInterface->toTensor(avFrame);
364+
} else {
365+
// Reason 2 above. We need to do a full conversion.
366+
cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
367+
}
368+
369+
// Finally, we need to send the frame back to the GPU. Note that the
370+
// pre-allocated tensor is on the GPU, so we can't send that to the CPU
371+
// device interface. We copy it over here.
342372
if (preAllocatedOutputTensor.has_value()) {
343373
preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
344374
frameOutput.data = preAllocatedOutputTensor.value();
@@ -372,16 +402,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
372402
torch::Tensor& dst = frameOutput.data;
373403
if (preAllocatedOutputTensor.has_value()) {
374404
dst = preAllocatedOutputTensor.value();
375-
auto shape = dst.sizes();
376-
TORCH_CHECK(
377-
(shape.size() == 3) && (shape[0] == outputDims_.height) &&
378-
(shape[1] == outputDims_.width) && (shape[2] == 3),
379-
"Expected tensor of shape ",
380-
outputDims_.height,
381-
"x",
382-
outputDims_.width,
383-
"x3, got ",
384-
shape);
385405
} else {
386406
dst = allocateEmptyHWCTensor(outputDims_, device_);
387407
}

0 commit comments

Comments
 (0)