Skip to content

Commit fc5468e

Browse files
committed
Test to ensure transforms are not used with non-CPU
1 parent dda2649 commit fc5468e

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,10 @@ void SingleStreamDecoder::addVideoStream(
466466
std::vector<Transform*>& transforms,
467467
const VideoStreamOptions& videoStreamOptions,
468468
std::optional<FrameMappings> customFrameMappings) {
469+
TORCH_CHECK(
470+
transforms.empty() || videoStreamOptions.device == torch::kCPU,
471+
"Transforms are only supported for CPU devices.");
472+
469473
addStream(
470474
streamIndex,
471475
AVMEDIA_TYPE_VIDEO,

test/test_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,15 @@ def test_color_conversion_library_with_scaling(
614614
assert_frames_equal(filtergraph_frame0, swscale_frame0)
615615
assert filtergraph_frame0.shape == (3, target_height, target_width)
616616

617+
@needs_cuda
618+
def test_scaling_on_cuda_fails(self):
619+
decoder = create_from_file(str(NASA_VIDEO.path))
620+
with pytest.raises(
621+
RuntimeError,
622+
match="Transforms are only supported for CPU devices.",
623+
):
624+
add_video_stream(decoder, device="cuda", width=100, height=100)
625+
617626
@pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
618627
@pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
619628
def test_color_conversion_library_with_dimension_order(

0 commit comments

Comments
 (0)