Skip to content

Commit f58816a

Browse files
committed
ops tests
1 parent 29edb40 commit f58816a

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

test/test_decoders.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,10 @@
4343
SINE_MONO_S32,
4444
SINE_MONO_S32_44100,
4545
SINE_MONO_S32_8000,
46+
cleanup_device_str,
4647
)
4748

4849

49-
def cleanup_device_str(device: str) -> str:
50-
# Remove any custom cuda device suffixes like ":custom_nvdec"
51-
if device.startswith("cuda:"):
52-
return device.split(":")[0] + ":" + device.split(":")[1]
53-
return device
54-
55-
5650
class TestDecoder:
5751
@pytest.mark.parametrize(
5852
"Decoder, asset",

test/test_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
SINE_MONO_S32,
5252
SINE_MONO_S32_44100,
5353
SINE_MONO_S32_8000,
54+
cleanup_device_str,
5455
)
5556

5657
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -63,6 +64,7 @@ class TestVideoDecoderOps:
6364
def test_seek_and_next(self, device):
6465
decoder = create_from_file(str(NASA_VIDEO.path))
6566
add_video_stream(decoder, device=device)
67+
device = cleanup_device_str(device)
6668
frame0, _, _ = get_next_frame(decoder)
6769
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
6870
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -80,6 +82,7 @@ def test_seek_and_next(self, device):
8082
def test_seek_to_negative_pts(self, device):
8183
decoder = create_from_file(str(NASA_VIDEO.path))
8284
add_video_stream(decoder, device=device)
85+
device = cleanup_device_str(device)
8386
frame0, _, _ = get_next_frame(decoder)
8487
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
8588
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -92,6 +95,7 @@ def test_seek_to_negative_pts(self, device):
9295
def test_get_frame_at_pts(self, device):
9396
decoder = create_from_file(str(NASA_VIDEO.path))
9497
add_video_stream(decoder, device=device)
98+
device = cleanup_device_str(device)
9599
# This frame has pts=6.006 and duration=0.033367, so it should be visible
96100
# at timestamps in the range [6.006, 6.039367) (not including the last timestamp).
97101
frame6, _, _ = get_frame_at_pts(decoder, 6.006)
@@ -116,6 +120,7 @@ def test_get_frame_at_pts(self, device):
116120
def test_get_frame_at_index(self, device):
117121
decoder = create_from_file(str(NASA_VIDEO.path))
118122
add_video_stream(decoder, device=device)
123+
device = cleanup_device_str(device)
119124
frame0, _, _ = get_frame_at_index(decoder, frame_index=0)
120125
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
121126
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -134,6 +139,7 @@ def test_get_frame_at_index(self, device):
134139
def test_get_frame_with_info_at_index(self, device):
135140
decoder = create_from_file(str(NASA_VIDEO.path))
136141
add_video_stream(decoder, device=device)
142+
device = cleanup_device_str(device)
137143
frame6, pts, duration = get_frame_at_index(decoder, frame_index=180)
138144
reference_frame6 = NASA_VIDEO.get_frame_data_by_index(
139145
INDEX_OF_FRAME_AT_6_SECONDS
@@ -146,6 +152,7 @@ def test_get_frame_with_info_at_index(self, device):
146152
def test_get_frames_at_indices(self, device):
147153
decoder = create_from_file(str(NASA_VIDEO.path))
148154
add_video_stream(decoder, device=device)
155+
device = cleanup_device_str(device)
149156
frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180])
150157
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
151158
reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
@@ -158,6 +165,7 @@ def test_get_frames_at_indices(self, device):
158165
def test_get_frames_at_indices_unsorted_indices(self, device):
159166
decoder = create_from_file(str(NASA_VIDEO.path))
160167
_add_video_stream(decoder, device=device)
168+
device = cleanup_device_str(device)
161169

162170
frame_indices = [2, 0, 1, 0, 2]
163171

@@ -185,6 +193,7 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
185193
def test_get_frames_at_indices_negative_indices(self, device):
186194
decoder = create_from_file(str(NASA_VIDEO.path))
187195
add_video_stream(decoder, device=device)
196+
device = cleanup_device_str(device)
188197
frames389and387and1, *_ = get_frames_at_indices(
189198
decoder, frame_indices=[-1, -3, -389]
190199
)
@@ -199,6 +208,7 @@ def test_get_frames_at_indices_negative_indices(self, device):
199208
def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
200209
decoder = create_from_file(str(NASA_VIDEO.path))
201210
add_video_stream(decoder, device=device)
211+
device = cleanup_device_str(device)
202212
with pytest.raises(
203213
IndexError,
204214
match="negative indices must have an absolute value less than the number of frames",
@@ -211,6 +221,7 @@ def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
211221
def test_get_frames_by_pts(self, device):
212222
decoder = create_from_file(str(NASA_VIDEO.path))
213223
_add_video_stream(decoder, device=device)
224+
device = cleanup_device_str(device)
214225

215226
# Note: 13.01 should give the last video frame for the NASA video
216227
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]
@@ -243,6 +254,7 @@ def test_pts_apis_against_index_ref(self, device):
243254
# we get the expected frame.
244255
decoder = create_from_file(str(NASA_VIDEO.path))
245256
add_video_stream(decoder, device=device)
257+
device = cleanup_device_str(device)
246258

247259
metadata = get_json_metadata(decoder)
248260
metadata_dict = json.loads(metadata)
@@ -294,6 +306,7 @@ def test_pts_apis_against_index_ref(self, device):
294306
def test_get_frames_in_range(self, device):
295307
decoder = create_from_file(str(NASA_VIDEO.path))
296308
add_video_stream(decoder, device=device)
309+
device = cleanup_device_str(device)
297310

298311
# ensure that the degenerate case of a range of size 1 works
299312
ref_frame0 = NASA_VIDEO.get_frame_data_by_range(0, 1)
@@ -334,6 +347,7 @@ def test_get_frames_in_range(self, device):
334347
def test_throws_exception_at_eof(self, device):
335348
decoder = create_from_file(str(NASA_VIDEO.path))
336349
add_video_stream(decoder, device=device)
350+
device = cleanup_device_str(device)
337351

338352
seek_to_pts(decoder, 12.979633)
339353
last_frame, _, _ = get_next_frame(decoder)
@@ -362,6 +376,7 @@ def test_compile_seek_and_next(self, device):
362376
@torch.compile(fullgraph=True, backend="eager")
363377
def get_frame1_and_frame_time6(decoder):
364378
add_video_stream(decoder, device=device)
379+
device = cleanup_device_str(device)
365380
frame0, _, _ = get_next_frame(decoder)
366381
seek_to_pts(decoder, 6.0)
367382
frame_time6, _, _ = get_next_frame(decoder)
@@ -405,6 +420,7 @@ def test_create_decoder(self, create_from, device):
405420
raise ValueError("Oops, double check the parametrization of this test!")
406421

407422
add_video_stream(decoder, device=device)
423+
device = cleanup_device_str(device)
408424
frame0, _, _ = get_next_frame(decoder)
409425
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
410426
assert_frames_equal(frame0, reference_frame0.to(device))
@@ -510,6 +526,7 @@ def test_seek_mode_custom_frame_mappings(self, device):
510526
decoder = create_from_file(
511527
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
512528
)
529+
device = cleanup_device_str(device)
513530
add_video_stream(
514531
decoder,
515532
device=device,
@@ -1042,6 +1059,7 @@ def seek(self, offset: int, whence: int) -> int:
10421059
)
10431060
decoder = create_from_file_like(file_counter, "approximate")
10441061
add_video_stream(decoder, device=device)
1062+
device = cleanup_device_str(device)
10451063

10461064
frame0, *_ = get_next_frame(decoder)
10471065
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)

test/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ def all_supported_devices():
3030
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda), pytest.param("cuda:0:custom_nvdec", marks=pytest.mark.needs_cuda))
3131

3232

33+
def cleanup_device_str(device: str) -> str:
34+
# Remove any custom cuda device suffixes like ":custom_nvdec"
35+
# To be called before calling `.to(device)` on a tensor.
36+
# TODO THIS IS AWFUL.
37+
if device.startswith("cuda:"):
38+
return device.split(":")[0] + ":" + device.split(":")[1]
39+
return device
40+
41+
3342
def get_ffmpeg_major_version():
3443
ffmpeg_version = get_ffmpeg_library_versions()["ffmpeg_version"]
3544
# When building FFmpeg from source there can be a `n` prefix in the version

0 commit comments

Comments
 (0)