Skip to content

Commit 1177bb4

Browse files
authored
BETA CUDA interface: integrate with our existing tests and fix EOF hang (#921)
1 parent bcbb889 commit 1177bb4

File tree

4 files changed

+89
-21
lines changed

4 files changed

+89
-21
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ int BetaCudaDeviceInterface::frameReadyInDisplayOrder(
424424
int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
425425
if (readyFrames_.empty()) {
426426
// No frame found, instruct caller to try again later after sending more
427-
// packets.
428-
return AVERROR(EAGAIN);
427+
// packets, or to stop if EOF was already sent.
428+
return eofSent_ ? AVERROR_EOF : AVERROR(EAGAIN);
429429
}
430430
CUVIDPARSERDISPINFO dispInfo = readyFrames_.front();
431431
readyFrames_.pop();

test/test_decoders.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
SINE_MONO_S32_8000,
4646
TEST_SRC_2_720P,
4747
TEST_SRC_2_720P_H265,
48+
unsplit_device_str,
4849
)
4950

5051

@@ -178,6 +179,7 @@ def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode):
178179
device=device,
179180
seek_mode=seek_mode,
180181
)
182+
device, _ = unsplit_device_str(device)
181183

182184
ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device)
183185
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device)
@@ -223,6 +225,7 @@ def test_getitem_numpy_int(self):
223225
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
224226
def test_getitem_slice(self, device, seek_mode):
225227
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
228+
device, _ = unsplit_device_str(device)
226229

227230
# ensure that the degenerate case of a range of size 1 works
228231

@@ -400,6 +403,7 @@ def test_getitem_fails(self, device, seek_mode):
400403
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
401404
def test_iteration(self, device, seek_mode):
402405
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
406+
device, _ = unsplit_device_str(device)
403407

404408
ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device)
405409
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device)
@@ -447,6 +451,7 @@ def test_iteration_slow(self):
447451
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
448452
def test_get_frame_at(self, device, seek_mode):
449453
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
454+
device, _ = unsplit_device_str(device)
450455

451456
ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device)
452457
frame9 = decoder.get_frame_at(9)
@@ -510,6 +515,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
510515
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
511516
def test_get_frames_at(self, device, seek_mode):
512517
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
518+
device, _ = unsplit_device_str(device)
513519

514520
# test positive and negative frame index
515521
frames = decoder.get_frames_at([35, 25, -1, -2])
@@ -585,6 +591,7 @@ def test_get_frame_at_av1(self, device):
585591
pytest.skip("AV1 decoding on CUDA is not supported internally")
586592

587593
decoder = VideoDecoder(AV1_VIDEO.path, device=device)
594+
device, _ = unsplit_device_str(device)
588595
ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10)
589596
ref_frame_info10 = AV1_VIDEO.get_frame_info(10)
590597
decoded_frame10 = decoder.get_frame_at(10)
@@ -596,6 +603,7 @@ def test_get_frame_at_av1(self, device):
596603
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
597604
def test_get_frame_played_at(self, device, seek_mode):
598605
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
606+
device, _ = unsplit_device_str(device)
599607

600608
ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device)
601609
assert_frames_equal(
@@ -635,8 +643,8 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
635643
@pytest.mark.parametrize("device", all_supported_devices())
636644
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
637645
def test_get_frames_played_at(self, device, seek_mode):
638-
639646
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
647+
device, _ = unsplit_device_str(device)
640648

641649
# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
642650
# index 35. We use those indices as reference to test against.
@@ -695,6 +703,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
695703
device=device,
696704
seek_mode=seek_mode,
697705
)
706+
device, _ = unsplit_device_str(device)
698707

699708
# test degenerate case where we only actually get 1 frame
700709
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
@@ -799,6 +808,7 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
799808
device=device,
800809
seek_mode=seek_mode,
801810
)
811+
device, _ = unsplit_device_str(device)
802812

803813
# high range ends get capped to num_frames
804814
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000)
@@ -874,6 +884,7 @@ def test_get_frames_with_missing_num_frames_metadata(
874884
device=device,
875885
seek_mode=seek_mode,
876886
)
887+
device, _ = unsplit_device_str(device)
877888

878889
assert decoder.metadata.num_frames_from_header is None
879890
assert decoder.metadata.num_frames_from_content is None
@@ -942,6 +953,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
942953
device=device,
943954
seek_mode=seek_mode,
944955
)
956+
device, _ = unsplit_device_str(device)
945957

946958
# Note that we are comparing the results of VideoDecoder's method:
947959
# get_frames_played_in_range()
@@ -1134,6 +1146,7 @@ def test_get_key_frame_indices(self, device):
11341146
@pytest.mark.parametrize("device", all_supported_devices())
11351147
def test_compile(self, device):
11361148
decoder = VideoDecoder(NASA_VIDEO.path, device=device)
1149+
device, _ = unsplit_device_str(device)
11371150

11381151
@contextlib.contextmanager
11391152
def restore_capture_scalar_outputs():
@@ -1271,6 +1284,19 @@ def test_10bit_videos(self, device, asset):
12711284
# This just validates that we can decode 10-bit videos.
12721285
# TODO validate against the ref that the decoded frames are correct
12731286

1287+
if device == "cuda:0:beta":
1288+
# This fails on our BETA interface on asset 0 (only!) with:
1289+
#
1290+
# RuntimeError: Codec configuration not supported on this GPU.
1291+
# Codec: 4, chroma format: 1, bit depth: 10
1292+
#
1293+
# I don't remember but I suspect asset 0 is actually the one that
1294+
# fallsback to the CPU path on the default CUDA interface (that
1295+
# would make sense)
1296+
# We should investigate if and how we could make that fallback
1297+
# happen for the BETA interface.
1298+
pytest.skip("TODONVDEC P2 - investigate and unskip")
1299+
12741300
decoder = VideoDecoder(asset.path, device=device)
12751301
decoder.get_frame_at(10)
12761302

@@ -1316,6 +1342,7 @@ def test_custom_frame_mappings_json_and_bytes(
13161342
device=device,
13171343
custom_frame_mappings=custom_frame_mappings,
13181344
)
1345+
device, _ = unsplit_device_str(device)
13191346
frame_0 = decoder.get_frame_at(0)
13201347
frame_5 = decoder.get_frame_at(5)
13211348
assert_frames_equal(

test/test_ops.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
SINE_MONO_S32,
5656
SINE_MONO_S32_44100,
5757
SINE_MONO_S32_8000,
58+
unsplit_device_str,
5859
)
5960

6061
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -66,7 +67,8 @@ class TestVideoDecoderOps:
6667
@pytest.mark.parametrize("device", all_supported_devices())
6768
def test_seek_and_next(self, device):
6869
decoder = create_from_file(str(NASA_VIDEO.path))
69-
add_video_stream(decoder, device=device)
70+
device, device_variant = unsplit_device_str(device)
71+
add_video_stream(decoder, device=device, device_variant=device_variant)
7072
frame0, _, _ = get_next_frame(decoder)
7173
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
7274
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -83,7 +85,8 @@ def test_seek_and_next(self, device):
8385
@pytest.mark.parametrize("device", all_supported_devices())
8486
def test_seek_to_negative_pts(self, device):
8587
decoder = create_from_file(str(NASA_VIDEO.path))
86-
add_video_stream(decoder, device=device)
88+
device, device_variant = unsplit_device_str(device)
89+
add_video_stream(decoder, device=device, device_variant=device_variant)
8790
frame0, _, _ = get_next_frame(decoder)
8891
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
8992
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -95,7 +98,8 @@ def test_seek_to_negative_pts(self, device):
9598
@pytest.mark.parametrize("device", all_supported_devices())
9699
def test_get_frame_at_pts(self, device):
97100
decoder = create_from_file(str(NASA_VIDEO.path))
98-
add_video_stream(decoder, device=device)
101+
device, device_variant = unsplit_device_str(device)
102+
add_video_stream(decoder, device=device, device_variant=device_variant)
99103
# This frame has pts=6.006 and duration=0.033367, so it should be visible
100104
# at timestamps in the range [6.006, 6.039367) (not including the last timestamp).
101105
frame6, _, _ = get_frame_at_pts(decoder, 6.006)
@@ -119,7 +123,8 @@ def test_get_frame_at_pts(self, device):
119123
@pytest.mark.parametrize("device", all_supported_devices())
120124
def test_get_frame_at_index(self, device):
121125
decoder = create_from_file(str(NASA_VIDEO.path))
122-
add_video_stream(decoder, device=device)
126+
device, device_variant = unsplit_device_str(device)
127+
add_video_stream(decoder, device=device, device_variant=device_variant)
123128
frame0, _, _ = get_frame_at_index(decoder, frame_index=0)
124129
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
125130
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -137,7 +142,8 @@ def test_get_frame_at_index(self, device):
137142
@pytest.mark.parametrize("device", all_supported_devices())
138143
def test_get_frame_with_info_at_index(self, device):
139144
decoder = create_from_file(str(NASA_VIDEO.path))
140-
add_video_stream(decoder, device=device)
145+
device, device_variant = unsplit_device_str(device)
146+
add_video_stream(decoder, device=device, device_variant=device_variant)
141147
frame6, pts, duration = get_frame_at_index(decoder, frame_index=180)
142148
reference_frame6 = NASA_VIDEO.get_frame_data_by_index(
143149
INDEX_OF_FRAME_AT_6_SECONDS
@@ -149,7 +155,8 @@ def test_get_frame_with_info_at_index(self, device):
149155
@pytest.mark.parametrize("device", all_supported_devices())
150156
def test_get_frames_at_indices(self, device):
151157
decoder = create_from_file(str(NASA_VIDEO.path))
152-
add_video_stream(decoder, device=device)
158+
device, device_variant = unsplit_device_str(device)
159+
add_video_stream(decoder, device=device, device_variant=device_variant)
153160
frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180])
154161
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
155162
reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
@@ -161,7 +168,8 @@ def test_get_frames_at_indices(self, device):
161168
@pytest.mark.parametrize("device", all_supported_devices())
162169
def test_get_frames_at_indices_unsorted_indices(self, device):
163170
decoder = create_from_file(str(NASA_VIDEO.path))
164-
_add_video_stream(decoder, device=device)
171+
device, device_variant = unsplit_device_str(device)
172+
add_video_stream(decoder, device=device, device_variant=device_variant)
165173

166174
frame_indices = [2, 0, 1, 0, 2]
167175

@@ -188,7 +196,8 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
188196
@pytest.mark.parametrize("device", all_supported_devices())
189197
def test_get_frames_at_indices_negative_indices(self, device):
190198
decoder = create_from_file(str(NASA_VIDEO.path))
191-
add_video_stream(decoder, device=device)
199+
device, device_variant = unsplit_device_str(device)
200+
add_video_stream(decoder, device=device, device_variant=device_variant)
192201
frames389and387and1, *_ = get_frames_at_indices(
193202
decoder, frame_indices=[-1, -3, -389]
194203
)
@@ -202,7 +211,8 @@ def test_get_frames_at_indices_negative_indices(self, device):
202211
@pytest.mark.parametrize("device", all_supported_devices())
203212
def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
204213
decoder = create_from_file(str(NASA_VIDEO.path))
205-
add_video_stream(decoder, device=device)
214+
device, device_variant = unsplit_device_str(device)
215+
add_video_stream(decoder, device=device, device_variant=device_variant)
206216
with pytest.raises(
207217
IndexError,
208218
match="negative indices must have an absolute value less than the number of frames",
@@ -214,7 +224,8 @@ def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
214224
@pytest.mark.parametrize("device", all_supported_devices())
215225
def test_get_frames_by_pts(self, device):
216226
decoder = create_from_file(str(NASA_VIDEO.path))
217-
_add_video_stream(decoder, device=device)
227+
device, device_variant = unsplit_device_str(device)
228+
add_video_stream(decoder, device=device, device_variant=device_variant)
218229

219230
# Note: 13.01 should give the last video frame for the NASA video
220231
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]
@@ -246,7 +257,8 @@ def test_pts_apis_against_index_ref(self, device):
246257
# APIs exactly where those frames are supposed to start. We assert that
247258
# we get the expected frame.
248259
decoder = create_from_file(str(NASA_VIDEO.path))
249-
add_video_stream(decoder, device=device)
260+
device, device_variant = unsplit_device_str(device)
261+
add_video_stream(decoder, device=device, device_variant=device_variant)
250262

251263
metadata = get_json_metadata(decoder)
252264
metadata_dict = json.loads(metadata)
@@ -297,7 +309,8 @@ def test_pts_apis_against_index_ref(self, device):
297309
@pytest.mark.parametrize("device", all_supported_devices())
298310
def test_get_frames_in_range(self, device):
299311
decoder = create_from_file(str(NASA_VIDEO.path))
300-
add_video_stream(decoder, device=device)
312+
device, device_variant = unsplit_device_str(device)
313+
add_video_stream(decoder, device=device, device_variant=device_variant)
301314

302315
# ensure that the degenerate case of a range of size 1 works
303316
ref_frame0 = NASA_VIDEO.get_frame_data_by_range(0, 1)
@@ -337,7 +350,8 @@ def test_get_frames_in_range(self, device):
337350
@pytest.mark.parametrize("device", all_supported_devices())
338351
def test_throws_exception_at_eof(self, device):
339352
decoder = create_from_file(str(NASA_VIDEO.path))
340-
add_video_stream(decoder, device=device)
353+
device, device_variant = unsplit_device_str(device)
354+
add_video_stream(decoder, device=device, device_variant=device_variant)
341355

342356
seek_to_pts(decoder, 12.979633)
343357
last_frame, _, _ = get_next_frame(decoder)
@@ -352,7 +366,8 @@ def test_throws_exception_at_eof(self, device):
352366
@pytest.mark.parametrize("device", all_supported_devices())
353367
def test_throws_exception_if_seek_too_far(self, device):
354368
decoder = create_from_file(str(NASA_VIDEO.path))
355-
add_video_stream(decoder, device=device)
369+
device, device_variant = unsplit_device_str(device)
370+
add_video_stream(decoder, device=device, device_variant=device_variant)
356371
# pts=12.979633 is the last frame in the video.
357372
seek_to_pts(decoder, 12.979633 + 1.0e-4)
358373
with pytest.raises(IndexError, match="no more frames"):
@@ -363,9 +378,11 @@ def test_compile_seek_and_next(self, device):
363378
# TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. Right now
364379
# compilation fails because it can't handle tensors of size unknown at
365380
# compile-time.
381+
device, device_variant = unsplit_device_str(device)
382+
366383
@torch.compile(fullgraph=True, backend="eager")
367384
def get_frame1_and_frame_time6(decoder):
368-
add_video_stream(decoder, device=device)
385+
add_video_stream(decoder, device=device, device_variant=device_variant)
369386
frame0, _, _ = get_next_frame(decoder)
370387
seek_to_pts(decoder, 6.0)
371388
frame_time6, _, _ = get_next_frame(decoder)
@@ -408,7 +425,8 @@ def test_create_decoder(self, create_from, device):
408425
else:
409426
raise ValueError("Oops, double check the parametrization of this test!")
410427

411-
add_video_stream(decoder, device=device)
428+
device, device_variant = unsplit_device_str(device)
429+
add_video_stream(decoder, device=device, device_variant=device_variant)
412430
frame0, _, _ = get_next_frame(decoder)
413431
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
414432
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -536,9 +554,11 @@ def test_seek_mode_custom_frame_mappings(self, device):
536554
decoder = create_from_file(
537555
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
538556
)
557+
device, device_variant = unsplit_device_str(device)
539558
add_video_stream(
540559
decoder,
541560
device=device,
561+
device_variant=device_variant,
542562
stream_index=stream_index,
543563
custom_frame_mappings=NASA_VIDEO.get_custom_frame_mappings(
544564
stream_index=stream_index
@@ -1077,7 +1097,8 @@ def seek(self, offset: int, whence: int) -> int:
10771097
open(NASA_VIDEO.path, mode="rb", buffering=buffering)
10781098
)
10791099
decoder = create_from_file_like(file_counter, "approximate")
1080-
add_video_stream(decoder, device=device)
1100+
device, device_variant = unsplit_device_str(device)
1101+
add_video_stream(decoder, device=device, device_variant=device_variant)
10811102

10821103
frame0, *_ = get_next_frame(decoder)
10831104
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)

test/utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,27 @@ def needs_cuda(test_item):
2727

2828

2929
def all_supported_devices():
30-
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
30+
return (
31+
"cpu",
32+
pytest.param("cuda", marks=pytest.mark.needs_cuda),
33+
pytest.param("cuda:0:beta", marks=pytest.mark.needs_cuda),
34+
)
35+
36+
37+
def unsplit_device_str(device_str: str) -> str:
38+
# helper meant to be used as
39+
# device, device_variant = unsplit_device_str(device)
40+
# when `device` comes from all_supported_devices() and may be "cuda:0:beta".
41+
# It is used:
42+
# - before calling `.to(device)` where device can't be "cuda:0:beta"
43+
# - before calling add_video_stream(device=device, device_variant=device_variant)
44+
#
45+
# TODONVDEC P2: Find a less clunky way to test the BETA CUDA interface. It
46+
# will ultimately depend on how we want to publicly expose it.
47+
if device_str == "cuda:0:beta":
48+
return "cuda", "beta"
49+
else:
50+
return device_str, "default"
3151

3252

3353
def get_ffmpeg_major_version():

0 commit comments

Comments
 (0)