|
5 | 5 | #include <mutex>
|
6 | 6 |
|
7 | 7 | #include "src/torchcodec/_core/Cache.h"
|
| 8 | +#include "src/torchcodec/_core/CpuDeviceInterface.h" |
8 | 9 | #include "src/torchcodec/_core/CudaDeviceInterface.h"
|
9 | 10 | #include "src/torchcodec/_core/FFMPEGCommon.h"
|
10 | 11 |
|
@@ -230,7 +231,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
|
230 | 231 | reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
|
231 | 232 | AVPixelFormat actualFormat = hwFramesCtx->sw_format;
|
232 | 233 |
|
233 |
| - // NV12 conversion is implemented directly with NPP, no need for filters. |
| 234 | + // If the frame is already in NV12 format, we don't need to do anything. |
234 | 235 | if (actualFormat == AV_PIX_FMT_NV12) {
|
235 | 236 | return std::move(avFrame);
|
236 | 237 | }
|
@@ -310,35 +311,64 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
|
310 | 311 | UniqueAVFrame& avFrame,
|
311 | 312 | FrameOutput& frameOutput,
|
312 | 313 | std::optional<torch::Tensor> preAllocatedOutputTensor) {
|
| 314 | + if (preAllocatedOutputTensor.has_value()) { |
| 315 | + auto shape = preAllocatedOutputTensor.value().sizes(); |
| 316 | + TORCH_CHECK( |
| 317 | + (shape.size() == 3) && (shape[0] == outputDims_.height) && |
| 318 | + (shape[1] == outputDims_.width) && (shape[2] == 3), |
| 319 | + "Expected tensor of shape ", |
| 320 | + outputDims_.height, |
| 321 | + "x", |
| 322 | + outputDims_.width, |
| 323 | + "x3, got ", |
| 324 | + shape); |
| 325 | + } |
| 326 | + |
| 327 | + // All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by |
| 328 | + // converting them to NV12. |
313 | 329 | avFrame = maybeConvertAVFrameToNV12(avFrame);
|
314 | 330 |
|
315 |
| - // The filtered frame might be on CPU if CPU fallback has happenned on filter |
316 |
| - // graph level. For example, that's how we handle color format conversion |
317 |
| - // on FFmpeg 4.4 where scale_cuda did not have this supported implemented yet. |
318 | 331 | if (avFrame->format != AV_PIX_FMT_CUDA) {
|
319 | 332 | // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
|
320 |
| - // the GPU. In this branch, the frame is on the CPU: this is what NVDEC |
321 |
| - // gives us if it wasn't able to decode a frame, for whatever reason. |
322 |
| - // Typically that happens if the video's encoder isn't supported by NVDEC. |
323 |
| - // Below, we choose to convert the frame's color-space using the CPU |
324 |
| - // codepath, and send it back to the GPU at the very end. |
| 333 | + // the GPU. In this branch, the frame is on the CPU. There are two possible |
| 334 | + // reasons: |
| 335 | + // |
| 336 | + // 1. During maybeConvertAVFrameToNV12(), we had a non-NV12 format frame |
| 337 | + // and we're on FFmpeg 4.4 or earlier. In such cases, we had to use CPU |
| 338 | + // filters and we just converted the frame to RGB24. |
| 339 | + // 2. This is what NVDEC gave us if it wasn't able to decode a frame, for |
| 340 | + // whatever reason. Typically that happens if the video's encoder isn't |
| 341 | + // supported by NVDEC. |
325 | 342 | //
|
326 |
| - // TODO: A possibly better solution would be to send the frame to the GPU |
327 |
| - // first, and do the color conversion there. |
| 343 | + // In both cases, we have a frame on the CPU, and we need a CPU device to |
| 344 | + // handle it. We send the frame back to the CUDA device when we're done. |
328 | 345 | //
|
329 |
| - // TODO: If we're going to keep this around, we should probably cache it? |
330 |
| - auto cpuInterface = createDeviceInterface(torch::Device(torch::kCPU)); |
| 346 | + // TODO: Perhaps we should cache cpuInterface? |
| 347 | + auto cpuInterface = std::make_unique<CpuDeviceInterface>(torch::kCPU); |
331 | 348 | TORCH_CHECK(
|
332 | 349 | cpuInterface != nullptr, "Failed to create CPU device interface");
|
333 | 350 | cpuInterface->initialize(
|
334 | 351 | nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_);
|
335 | 352 |
|
| 353 | + enum AVPixelFormat frameFormat = |
| 354 | + static_cast<enum AVPixelFormat>(avFrame->format); |
| 355 | + |
336 | 356 | FrameOutput cpuFrameOutput;
|
337 |
| - cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); |
338 | 357 |
|
339 |
| - // TODO: explain that the pre-allocated tensor is on the GPU, but we need |
340 |
| - // to do the decoding on the CPU, and we can't pass the pre-allocated tensor |
341 |
| - // to do it. BUT WHY did it work before? |
| 358 | + if (frameFormat == AV_PIX_FMT_RGB24 && |
| 359 | + avFrame->width == outputDims_.width && |
| 360 | + avFrame->height == outputDims_.height) { |
| 361 | + // Reason 1 above. The frame is already in the format and dimensions that |
| 362 | + // we need, we just need to convert it to a tensor. |
| 363 | + cpuFrameOutput.data = cpuInterface->toTensor(avFrame); |
| 364 | + } else { |
| 365 | + // Reason 2 above. We need to do a full conversion. |
| 366 | + cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); |
| 367 | + } |
| 368 | + |
| 369 | + // Finally, we need to send the frame back to the GPU. Note that the |
| 370 | + // pre-allocated tensor is on the GPU, so we can't send that to the CPU |
| 371 | + // device interface. We copy it over here. |
342 | 372 | if (preAllocatedOutputTensor.has_value()) {
|
343 | 373 | preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
|
344 | 374 | frameOutput.data = preAllocatedOutputTensor.value();
|
@@ -372,16 +402,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
|
372 | 402 | torch::Tensor& dst = frameOutput.data;
|
373 | 403 | if (preAllocatedOutputTensor.has_value()) {
|
374 | 404 | dst = preAllocatedOutputTensor.value();
|
375 |
| - auto shape = dst.sizes(); |
376 |
| - TORCH_CHECK( |
377 |
| - (shape.size() == 3) && (shape[0] == outputDims_.height) && |
378 |
| - (shape[1] == outputDims_.width) && (shape[2] == 3), |
379 |
| - "Expected tensor of shape ", |
380 |
| - outputDims_.height, |
381 |
| - "x", |
382 |
| - outputDims_.width, |
383 |
| - "x3, got ", |
384 |
| - shape); |
385 | 405 | } else {
|
386 | 406 | dst = allocateEmptyHWCTensor(outputDims_, device_);
|
387 | 407 | }
|
|
0 commit comments