@@ -137,6 +137,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
137
137
return UniqueCUvideodecoder (decoder, CUvideoDecoderDeleter{});
138
138
}
139
139
140
+ cudaVideoCodec validateCodecSupport (AVCodecID codecId) {
141
+ switch (codecId) {
142
+ case AV_CODEC_ID_H264:
143
+ return cudaVideoCodec_H264;
144
+ case AV_CODEC_ID_HEVC:
145
+ return cudaVideoCodec_HEVC;
146
+ // TODONVDEC P0: support more codecs
147
+ // case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
148
+ // case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
149
+ // case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
150
+ // case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
151
+ // case AV_CODEC_ID_MJPEG: return cudaVideoCodec_JPEG;
152
+ default : {
153
+ TORCH_CHECK (false , " Unsupported codec type: " , avcodec_get_name (codecId));
154
+ }
155
+ }
156
+ }
157
+
140
158
} // namespace
141
159
142
160
BetaCudaDeviceInterface::BetaCudaDeviceInterface (const torch::Device& device)
@@ -162,36 +180,100 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
162
180
}
163
181
}
164
182
165
- void BetaCudaDeviceInterface::initialize (const AVStream* avStream) {
183
+ void BetaCudaDeviceInterface::initialize (
184
+ const AVStream* avStream,
185
+ const UniqueDecodingAVFormatContext& avFormatCtx) {
166
186
torch::Tensor dummyTensorForCudaInitialization = torch::empty (
167
187
{1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
168
188
169
- TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
170
- timeBase_ = avStream->time_base ;
171
-
172
189
auto cudaDevice = torch::Device (torch::kCUDA );
173
190
defaultCudaInterface_ =
174
191
std::unique_ptr<DeviceInterface>(createDeviceInterface (cudaDevice));
175
192
AVCodecContext dummyCodecContext = {};
176
- defaultCudaInterface_->initialize (avStream);
193
+ defaultCudaInterface_->initialize (avStream, avFormatCtx );
177
194
defaultCudaInterface_->registerHardwareDeviceWithCodec (&dummyCodecContext);
178
195
179
- const AVCodecParameters* codecpar = avStream->codecpar ;
180
- TORCH_CHECK (codecpar != nullptr , " CodecParameters cannot be null" );
196
+ TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
197
+ timeBase_ = avStream->time_base ;
198
+
199
+ const AVCodecParameters* codecPar = avStream->codecpar ;
200
+ TORCH_CHECK (codecPar != nullptr , " CodecParameters cannot be null" );
201
+
202
+ initializeBSF (codecPar, avFormatCtx);
203
+
204
+ // Create parser. Default values that aren't obvious are taken from DALI.
205
+ CUVIDPARSERPARAMS parserParams = {};
206
+ parserParams.CodecType = validateCodecSupport (codecPar->codec_id );
207
+ parserParams.ulMaxNumDecodeSurfaces = 8 ;
208
+ parserParams.ulMaxDisplayDelay = 0 ;
209
+ // Callback setup, all are triggered by the parser within a call
210
+ // to cuvidParseVideoData
211
+ parserParams.pUserData = this ;
212
+ parserParams.pfnSequenceCallback = pfnSequenceCallback;
213
+ parserParams.pfnDecodePicture = pfnDecodePictureCallback;
214
+ parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
181
215
216
+ CUresult result = cuvidCreateVideoParser (&videoParser_, &parserParams);
182
217
TORCH_CHECK (
183
- // TODONVDEC P0 support more
184
- avStream->codecpar ->codec_id == AV_CODEC_ID_H264,
185
- " Can only do H264 for now" );
218
+ result == CUDA_SUCCESS, " Failed to create video parser: " , result);
219
+ }
186
220
221
+ void BetaCudaDeviceInterface::initializeBSF (
222
+ const AVCodecParameters* codecPar,
223
+ const UniqueDecodingAVFormatContext& avFormatCtx) {
187
224
// Setup bit stream filters (BSF):
188
225
// https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
189
- // This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
190
- // now we apply BSF unconditionally, but it should be optional and dependent
191
- // on codec and container.
192
- const AVBitStreamFilter* avBSF = av_bsf_get_by_name (" h264_mp4toannexb" );
226
+ // This is only needed for some formats, like H264 or HEVC.
227
+
228
+ TORCH_CHECK (codecPar != nullptr , " codecPar cannot be null" );
229
+ TORCH_CHECK (avFormatCtx != nullptr , " AVFormatContext cannot be null" );
230
+ TORCH_CHECK (
231
+ avFormatCtx->iformat != nullptr ,
232
+ " AVFormatContext->iformat cannot be null" );
233
+ std::string filterName;
234
+
235
+ // Matching logic is taken from DALI
236
+ switch (codecPar->codec_id ) {
237
+ case AV_CODEC_ID_H264: {
238
+ const std::string formatName = avFormatCtx->iformat ->long_name
239
+ ? avFormatCtx->iformat ->long_name
240
+ : " " ;
241
+
242
+ if (formatName == " QuickTime / MOV" ||
243
+ formatName == " FLV (Flash Video)" ||
244
+ formatName == " Matroska / WebM" || formatName == " raw H.264 video" ) {
245
+ filterName = " h264_mp4toannexb" ;
246
+ }
247
+ break ;
248
+ }
249
+
250
+ case AV_CODEC_ID_HEVC: {
251
+ const std::string formatName = avFormatCtx->iformat ->long_name
252
+ ? avFormatCtx->iformat ->long_name
253
+ : " " ;
254
+
255
+ if (formatName == " QuickTime / MOV" ||
256
+ formatName == " FLV (Flash Video)" ||
257
+ formatName == " Matroska / WebM" || formatName == " raw HEVC video" ) {
258
+ filterName = " hevc_mp4toannexb" ;
259
+ }
260
+ break ;
261
+ }
262
+
263
+ default :
264
+ // No bitstream filter needed for other codecs
265
+ // TODONVDEC P1 MPEG4 will need one!
266
+ break ;
267
+ }
268
+
269
+ if (filterName.empty ()) {
270
+ // Only initialize BSF if we actually need one
271
+ return ;
272
+ }
273
+
274
+ const AVBitStreamFilter* avBSF = av_bsf_get_by_name (filterName.c_str ());
193
275
TORCH_CHECK (
194
- avBSF != nullptr , " Failed to find h264_mp4toannexb bitstream filter" );
276
+ avBSF != nullptr , " Failed to find bitstream filter: " , filterName );
195
277
196
278
AVBSFContext* avBSFContext = nullptr ;
197
279
int retVal = av_bsf_alloc (avBSF, &avBSFContext);
@@ -202,7 +284,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
202
284
203
285
bitstreamFilter_.reset (avBSFContext);
204
286
205
- retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecpar );
287
+ retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecPar );
206
288
TORCH_CHECK (
207
289
retVal >= AVSUCCESS,
208
290
" Failed to copy codec parameters: " ,
@@ -213,22 +295,6 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
213
295
retVal == AVSUCCESS,
214
296
" Failed to initialize bitstream filter: " ,
215
297
getFFMPEGErrorStringFromErrorCode (retVal));
216
-
217
- // Create parser. Default values that aren't obvious are taken from DALI.
218
- CUVIDPARSERPARAMS parserParams = {};
219
- parserParams.CodecType = cudaVideoCodec_H264;
220
- parserParams.ulMaxNumDecodeSurfaces = 8 ;
221
- parserParams.ulMaxDisplayDelay = 0 ;
222
- // Callback setup, all are triggered by the parser within a call
223
- // to cuvidParseVideoData
224
- parserParams.pUserData = this ;
225
- parserParams.pfnSequenceCallback = pfnSequenceCallback;
226
- parserParams.pfnDecodePicture = pfnDecodePictureCallback;
227
- parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
228
-
229
- CUresult result = cuvidCreateVideoParser (&videoParser_, &parserParams);
230
- TORCH_CHECK (
231
- result == CUDA_SUCCESS, " Failed to create video parser: " , result);
232
298
}
233
299
234
300
// This callback is called by the parser within cuvidParseVideoData when there
0 commit comments