Skip to content

Commit 3f445bb

Browse files
authored
BETA CUDA interface: Add TODOs and more explicit initialization (#918)
1 parent 6d72f11 commit 3f445bb

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,27 +109,29 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
109109
caps.nMaxMBCount);
110110

111111
// Decoder creation parameters, taken from DALI
112-
CUVIDDECODECREATEINFO decoder_info = {};
113-
decoder_info.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8;
114-
decoder_info.ChromaFormat = videoFormat->chroma_format;
115-
decoder_info.CodecType = videoFormat->codec;
116-
decoder_info.ulHeight = videoFormat->coded_height;
117-
decoder_info.ulWidth = videoFormat->coded_width;
118-
decoder_info.ulMaxHeight = videoFormat->coded_height;
119-
decoder_info.ulMaxWidth = videoFormat->coded_width;
120-
decoder_info.ulTargetHeight =
112+
CUVIDDECODECREATEINFO decoderParams = {};
113+
decoderParams.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8;
114+
decoderParams.ChromaFormat = videoFormat->chroma_format;
115+
decoderParams.OutputFormat = cudaVideoSurfaceFormat_NV12;
116+
decoderParams.ulCreationFlags = cudaVideoCreate_Default;
117+
decoderParams.CodecType = videoFormat->codec;
118+
decoderParams.ulHeight = videoFormat->coded_height;
119+
decoderParams.ulWidth = videoFormat->coded_width;
120+
decoderParams.ulMaxHeight = videoFormat->coded_height;
121+
decoderParams.ulMaxWidth = videoFormat->coded_width;
122+
decoderParams.ulTargetHeight =
121123
videoFormat->display_area.bottom - videoFormat->display_area.top;
122-
decoder_info.ulTargetWidth =
124+
decoderParams.ulTargetWidth =
123125
videoFormat->display_area.right - videoFormat->display_area.left;
124-
decoder_info.ulNumDecodeSurfaces = videoFormat->min_num_decode_surfaces;
125-
decoder_info.ulNumOutputSurfaces = 2;
126-
decoder_info.display_area.left = videoFormat->display_area.left;
127-
decoder_info.display_area.right = videoFormat->display_area.right;
128-
decoder_info.display_area.top = videoFormat->display_area.top;
129-
decoder_info.display_area.bottom = videoFormat->display_area.bottom;
126+
decoderParams.ulNumDecodeSurfaces = videoFormat->min_num_decode_surfaces;
127+
decoderParams.ulNumOutputSurfaces = 2;
128+
decoderParams.display_area.left = videoFormat->display_area.left;
129+
decoderParams.display_area.right = videoFormat->display_area.right;
130+
decoderParams.display_area.top = videoFormat->display_area.top;
131+
decoderParams.display_area.bottom = videoFormat->display_area.bottom;
130132

131133
CUvideodecoder* decoder = new CUvideodecoder();
132-
result = cuvidCreateDecoder(decoder, &decoder_info);
134+
result = cuvidCreateDecoder(decoder, &decoderParams);
133135
TORCH_CHECK(
134136
result == CUDA_SUCCESS, "Failed to create NVDEC decoder: ", result);
135137
return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{});
@@ -360,6 +362,10 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
360362
CUVIDPARSERDISPINFO dispInfo = readyFrames_.front();
361363
readyFrames_.pop();
362364

365+
// TODONVDEC P1 we need to set the procParams.output_stream field to the
366+
// current CUDA stream and ensure proper synchronization. There's a related
367+
// NVDECTODO in CudaDeviceInterface.cpp where we do the necessary
368+
// synchronization for NPP.
363369
CUVIDPROCPARAMS procParams = {};
364370
procParams.progressive_frame = dispInfo.progressive_frame;
365371
procParams.top_field_first = dispInfo.top_field_first;

0 commit comments

Comments
 (0)