Skip to content

Commit a032cb7

Browse files
committed
Don't pass pre-allocated GPU tensor to CPU decoding
1 parent 8e55bd4 commit a032cb7

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,18 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
169169
outputDims_.width,
170170
outputDims_.height);
171171

172+
172173
outputTensor = preAllocatedOutputTensor.value_or(
173174
allocateEmptyHWCTensor(outputDims_, torch::kCPU));
174175

176+
175177
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
176178
createSwsContext(swsFrameContext, avFrame->colorspace);
177179
prevSwsFrameContext_ = swsFrameContext;
178180
}
179181
int resultHeight =
180182
convertAVFrameToTensorUsingSwScale(avFrame, outputTensor);
183+
181184
// If this check failed, it would mean that the frame wasn't reshaped to
182185
// the expected height.
183186
// TODO: Can we do the same check for width?

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,18 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
231231

232232
FrameOutput cpuFrameOutput;
233233
cpuInterface->convertAVFrameToFrameOutput(
234-
avFrame, cpuFrameOutput, preAllocatedOutputTensor);
234+
avFrame, cpuFrameOutput);
235+
236+
// TODO: explain that the pre-allocated tensor is on the GPU, but we need
237+
// to do the decoding on the CPU, and we can't pass the pre-allocated tensor
238+
// to do it. BUT WHY did it work before?
239+
if (preAllocatedOutputTensor.has_value()) {
240+
preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
241+
frameOutput.data = preAllocatedOutputTensor.value();
242+
} else {
243+
frameOutput.data = cpuFrameOutput.data.to(device_);
244+
}
235245

236-
frameOutput.data = cpuFrameOutput.data.to(device_);
237246
return;
238247
}
239248

test/test_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,13 @@ def test_get_frame_with_info_at_index(self, device):
146146

147147
@pytest.mark.parametrize("device", all_supported_devices())
148148
def test_get_frames_at_indices(self, device):
149+
print("test_get_frames_at_indices")
149150
decoder = create_from_file(str(NASA_VIDEO.path))
151+
print("decoder created")
150152
add_video_stream(decoder, device=device)
153+
print("stream added")
151154
frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180])
155+
print("frames retrieved")
152156
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
153157
reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
154158
INDEX_OF_FRAME_AT_6_SECONDS

0 commit comments

Comments
 (0)