Skip to content

Commit b1af8ce

Browse files
authored
BETA CUDA interface: H265 support (#919)
1 parent 3f445bb commit b1af8ce

11 files changed

+161
-48
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 98 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
137137
return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{});
138138
}
139139

140+
cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
141+
switch (codecId) {
142+
case AV_CODEC_ID_H264:
143+
return cudaVideoCodec_H264;
144+
case AV_CODEC_ID_HEVC:
145+
return cudaVideoCodec_HEVC;
146+
// TODONVDEC P0: support more codecs
147+
// case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
148+
// case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
149+
// case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
150+
// case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
151+
// case AV_CODEC_ID_MJPEG: return cudaVideoCodec_JPEG;
152+
default: {
153+
TORCH_CHECK(false, "Unsupported codec type: ", avcodec_get_name(codecId));
154+
}
155+
}
156+
}
157+
140158
} // namespace
141159

142160
BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -162,36 +180,100 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
162180
}
163181
}
164182

165-
void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
183+
void BetaCudaDeviceInterface::initialize(
184+
const AVStream* avStream,
185+
const UniqueDecodingAVFormatContext& avFormatCtx) {
166186
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
167187
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
168188

169-
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
170-
timeBase_ = avStream->time_base;
171-
172189
auto cudaDevice = torch::Device(torch::kCUDA);
173190
defaultCudaInterface_ =
174191
std::unique_ptr<DeviceInterface>(createDeviceInterface(cudaDevice));
175192
AVCodecContext dummyCodecContext = {};
176-
defaultCudaInterface_->initialize(avStream);
193+
defaultCudaInterface_->initialize(avStream, avFormatCtx);
177194
defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext);
178195

179-
const AVCodecParameters* codecpar = avStream->codecpar;
180-
TORCH_CHECK(codecpar != nullptr, "CodecParameters cannot be null");
196+
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
197+
timeBase_ = avStream->time_base;
198+
199+
const AVCodecParameters* codecPar = avStream->codecpar;
200+
TORCH_CHECK(codecPar != nullptr, "CodecParameters cannot be null");
201+
202+
initializeBSF(codecPar, avFormatCtx);
203+
204+
// Create parser. Default values that aren't obvious are taken from DALI.
205+
CUVIDPARSERPARAMS parserParams = {};
206+
parserParams.CodecType = validateCodecSupport(codecPar->codec_id);
207+
parserParams.ulMaxNumDecodeSurfaces = 8;
208+
parserParams.ulMaxDisplayDelay = 0;
209+
// Callback setup, all are triggered by the parser within a call
210+
// to cuvidParseVideoData
211+
parserParams.pUserData = this;
212+
parserParams.pfnSequenceCallback = pfnSequenceCallback;
213+
parserParams.pfnDecodePicture = pfnDecodePictureCallback;
214+
parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
181215

216+
CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams);
182217
TORCH_CHECK(
183-
// TODONVDEC P0 support more
184-
avStream->codecpar->codec_id == AV_CODEC_ID_H264,
185-
"Can only do H264 for now");
218+
result == CUDA_SUCCESS, "Failed to create video parser: ", result);
219+
}
186220

221+
void BetaCudaDeviceInterface::initializeBSF(
222+
const AVCodecParameters* codecPar,
223+
const UniqueDecodingAVFormatContext& avFormatCtx) {
187224
// Setup bit stream filters (BSF):
188225
// https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
189-
// This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
190-
// now we apply BSF unconditionally, but it should be optional and dependent
191-
// on codec and container.
192-
const AVBitStreamFilter* avBSF = av_bsf_get_by_name("h264_mp4toannexb");
226+
// This is only needed for some formats, like H264 or HEVC.
227+
228+
TORCH_CHECK(codecPar != nullptr, "codecPar cannot be null");
229+
TORCH_CHECK(avFormatCtx != nullptr, "AVFormatContext cannot be null");
230+
TORCH_CHECK(
231+
avFormatCtx->iformat != nullptr,
232+
"AVFormatContext->iformat cannot be null");
233+
std::string filterName;
234+
235+
// Matching logic is taken from DALI
236+
switch (codecPar->codec_id) {
237+
case AV_CODEC_ID_H264: {
238+
const std::string formatName = avFormatCtx->iformat->long_name
239+
? avFormatCtx->iformat->long_name
240+
: "";
241+
242+
if (formatName == "QuickTime / MOV" ||
243+
formatName == "FLV (Flash Video)" ||
244+
formatName == "Matroska / WebM" || formatName == "raw H.264 video") {
245+
filterName = "h264_mp4toannexb";
246+
}
247+
break;
248+
}
249+
250+
case AV_CODEC_ID_HEVC: {
251+
const std::string formatName = avFormatCtx->iformat->long_name
252+
? avFormatCtx->iformat->long_name
253+
: "";
254+
255+
if (formatName == "QuickTime / MOV" ||
256+
formatName == "FLV (Flash Video)" ||
257+
formatName == "Matroska / WebM" || formatName == "raw HEVC video") {
258+
filterName = "hevc_mp4toannexb";
259+
}
260+
break;
261+
}
262+
263+
default:
264+
// No bitstream filter needed for other codecs
265+
// TODONVDEC P1 MPEG4 will need one!
266+
break;
267+
}
268+
269+
if (filterName.empty()) {
270+
// Only initialize BSF if we actually need one
271+
return;
272+
}
273+
274+
const AVBitStreamFilter* avBSF = av_bsf_get_by_name(filterName.c_str());
193275
TORCH_CHECK(
194-
avBSF != nullptr, "Failed to find h264_mp4toannexb bitstream filter");
276+
avBSF != nullptr, "Failed to find bitstream filter: ", filterName);
195277

196278
AVBSFContext* avBSFContext = nullptr;
197279
int retVal = av_bsf_alloc(avBSF, &avBSFContext);
@@ -202,7 +284,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
202284

203285
bitstreamFilter_.reset(avBSFContext);
204286

205-
retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecpar);
287+
retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecPar);
206288
TORCH_CHECK(
207289
retVal >= AVSUCCESS,
208290
"Failed to copy codec parameters: ",
@@ -213,22 +295,6 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
213295
retVal == AVSUCCESS,
214296
"Failed to initialize bitstream filter: ",
215297
getFFMPEGErrorStringFromErrorCode(retVal));
216-
217-
// Create parser. Default values that aren't obvious are taken from DALI.
218-
CUVIDPARSERPARAMS parserParams = {};
219-
parserParams.CodecType = cudaVideoCodec_H264;
220-
parserParams.ulMaxNumDecodeSurfaces = 8;
221-
parserParams.ulMaxDisplayDelay = 0;
222-
// Callback setup, all are triggered by the parser within a call
223-
// to cuvidParseVideoData
224-
parserParams.pUserData = this;
225-
parserParams.pfnSequenceCallback = pfnSequenceCallback;
226-
parserParams.pfnDecodePicture = pfnDecodePictureCallback;
227-
parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
228-
229-
CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams);
230-
TORCH_CHECK(
231-
result == CUDA_SUCCESS, "Failed to create video parser: ", result);
232298
}
233299

234300
// This callback is called by the parser within cuvidParseVideoData when there

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
3737
explicit BetaCudaDeviceInterface(const torch::Device& device);
3838
virtual ~BetaCudaDeviceInterface();
3939

40-
void initialize(const AVStream* avStream) override;
40+
void initialize(
41+
const AVStream* avStream,
42+
const UniqueDecodingAVFormatContext& avFormatCtx) override;
4143

4244
void convertAVFrameToFrameOutput(
4345
UniqueAVFrame& avFrame,
@@ -61,6 +63,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
6163
private:
6264
// Apply bitstream filter, modifies packet in-place
6365
void applyBSF(ReferenceAVPacket& packet);
66+
void initializeBSF(
67+
const AVCodecParameters* codecPar,
68+
const UniqueDecodingAVFormatContext& avFormatCtx);
6469

6570
UniqueAVFrame convertCudaFrameToAVFrame(
6671
CUdeviceptr framePtr,

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4646
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
4747
}
4848

49-
void CpuDeviceInterface::initialize(const AVStream* avStream) {
49+
void CpuDeviceInterface::initialize(
50+
const AVStream* avStream,
51+
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {
5052
TORCH_CHECK(avStream != nullptr, "avStream is null");
5153
timeBase_ = avStream->time_base;
5254
}

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ class CpuDeviceInterface : public DeviceInterface {
2323
return std::nullopt;
2424
}
2525

26-
virtual void initialize(const AVStream* avStream) override;
26+
virtual void initialize(
27+
const AVStream* avStream,
28+
const UniqueDecodingAVFormatContext& avFormatCtx) override;
2729

2830
virtual void initializeVideo(
2931
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,14 +203,16 @@ CudaDeviceInterface::~CudaDeviceInterface() {
203203
}
204204
}
205205

206-
void CudaDeviceInterface::initialize(const AVStream* avStream) {
206+
void CudaDeviceInterface::initialize(
207+
const AVStream* avStream,
208+
const UniqueDecodingAVFormatContext& avFormatCtx) {
207209
TORCH_CHECK(avStream != nullptr, "avStream is null");
208210
timeBase_ = avStream->time_base;
209211

210212
cpuInterface_ = createDeviceInterface(torch::kCPU);
211213
TORCH_CHECK(
212214
cpuInterface_ != nullptr, "Failed to create CPU device interface");
213-
cpuInterface_->initialize(avStream);
215+
cpuInterface_->initialize(avStream, avFormatCtx);
214216
cpuInterface_->initializeVideo(
215217
VideoStreamOptions(),
216218
{},

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ class CudaDeviceInterface : public DeviceInterface {
2020

2121
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
2222

23-
void initialize(const AVStream* avStream) override;
23+
void initialize(
24+
const AVStream* avStream,
25+
const UniqueDecodingAVFormatContext& avFormatCtx) override;
2426

2527
void initializeVideo(
2628
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ class DeviceInterface {
5252
};
5353

5454
// Initialize the device with parameters generic to all kinds of decoding.
55-
virtual void initialize(const AVStream* avStream) = 0;
55+
virtual void initialize(
56+
const AVStream* avStream,
57+
const UniqueDecodingAVFormatContext& avFormatCtx) = 0;
5658

5759
// Initialize the device with parameters specific to video decoding. There is
5860
// a default empty implementation.

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ void SingleStreamDecoder::addStream(
439439
TORCH_CHECK(
440440
deviceInterface_ != nullptr,
441441
"Failed to create device interface. This should never happen, please report.");
442-
deviceInterface_->initialize(streamInfo.stream);
442+
deviceInterface_->initialize(streamInfo.stream, formatContext_);
443443

444444
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
445445
// addStream() which is supposed to be generic

test/resources/testsrc2_h265.mp4

890 KB
Binary file not shown.

test/test_decoders.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
SINE_MONO_S32_44100,
4545
SINE_MONO_S32_8000,
4646
TEST_SRC_2_720P,
47+
TEST_SRC_2_720P_H265,
4748
)
4849

4950

@@ -1415,7 +1416,9 @@ def test_get_frames_at_tensor_indices(self):
14151416
# assert_tensor_close_on_at_least or something like that.
14161417

14171418
@needs_cuda
1418-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1419+
@pytest.mark.parametrize(
1420+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1421+
)
14191422
@pytest.mark.parametrize("contiguous_indices", (True, False))
14201423
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14211424
def test_beta_cuda_interface_get_frame_at(
@@ -1445,7 +1448,9 @@ def test_beta_cuda_interface_get_frame_at(
14451448
assert beta_frame.duration_seconds == ref_frame.duration_seconds
14461449

14471450
@needs_cuda
1448-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1451+
@pytest.mark.parametrize(
1452+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1453+
)
14491454
@pytest.mark.parametrize("contiguous_indices", (True, False))
14501455
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14511456
def test_beta_cuda_interface_get_frames_at(
@@ -1476,7 +1481,9 @@ def test_beta_cuda_interface_get_frames_at(
14761481
)
14771482

14781483
@needs_cuda
1479-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1484+
@pytest.mark.parametrize(
1485+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1486+
)
14801487
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14811488
def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
14821489
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
@@ -1498,7 +1505,9 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
14981505
assert beta_frame.duration_seconds == ref_frame.duration_seconds
14991506

15001507
@needs_cuda
1501-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1508+
@pytest.mark.parametrize(
1509+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1510+
)
15021511
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15031512
def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
15041513
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
@@ -1521,7 +1530,9 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
15211530
)
15221531

15231532
@needs_cuda
1524-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1533+
@pytest.mark.parametrize(
1534+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1535+
)
15251536
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15261537
def test_beta_cuda_interface_backwards(self, asset, seek_mode):
15271538

@@ -1541,12 +1552,24 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode):
15411552
assert beta_frame.pts_seconds == ref_frame.pts_seconds
15421553
assert beta_frame.duration_seconds == ref_frame.duration_seconds
15431554

1555+
@needs_cuda
1556+
def test_beta_cuda_interface_small_h265(self):
1557+
# TODONVDEC P2 investigate why/how the default interface can decode this
1558+
# video.
1559+
1560+
# This is fine on the default interface - why?
1561+
VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
1562+
# But it fails on the beta interface due to input validation checks, which we took from DALI!
1563+
with pytest.raises(
1564+
RuntimeError,
1565+
match="Video is too small in at least one dimension. Provided: 128x128 vs supported:144x144",
1566+
):
1567+
VideoDecoder(H265_VIDEO.path, device="cuda:0:beta").get_frame_at(0)
1568+
15441569
@needs_cuda
15451570
def test_beta_cuda_interface_error(self):
1546-
with pytest.raises(RuntimeError, match="Can only do H264 for now"):
1571+
with pytest.raises(RuntimeError, match="Unsupported codec type: av1"):
15471572
VideoDecoder(AV1_VIDEO.path, device="cuda:0:beta")
1548-
with pytest.raises(RuntimeError, match="Can only do H264 for now"):
1549-
VideoDecoder(H265_VIDEO.path, device="cuda:0:beta")
15501573
with pytest.raises(RuntimeError, match="Unsupported device"):
15511574
VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")
15521575

0 commit comments

Comments
 (0)