Skip to content

Commit 48e3ea3

Browse files
committed
Better comments; refactor toTensor
1 parent fc5468e commit 48e3ea3

File tree

5 files changed

+27
-31
lines changed

5 files changed

+27
-31
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,9 @@ void CpuDeviceInterface::initialize(
5656
timeBase_ = timeBase;
5757
outputDims_ = outputDims;
5858

59-
// TODO: rationalize comment below with new stuff.
60-
// By default, we want to use swscale for color conversion because it is
61-
// faster. However, it has width requirements, so we may need to fall back
62-
// to filtergraph. We also need to respect what was requested from the
63-
// options; we respect the options unconditionally, so it's possible for
64-
// swscale's width requirements to be violated. We don't expose the ability to
65-
// choose color conversion library publicly; we only use this ability
66-
// internally.
59+
// We want to use swscale for color conversion if possible because it is
60+
// faster than filtergraph. The following are the conditions we need to meet
61+
// to use it.
6762

6863
// We can only use swscale when we have a single resize transform. Note that
6964
// this means swscale will not support the case of having several,
@@ -76,12 +71,14 @@ void CpuDeviceInterface::initialize(
7671
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
7772
bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0;
7873

74+
// Note that we do not expose this capability in the public API, only through
75+
// the core API.
7976
bool userRequestedSwScale = videoStreamOptions_.colorConversionLibrary ==
8077
ColorConversionLibrary::SWSCALE;
8178

8279
// Note that we treat the transform limitation differently from the width
8380
// limitation. That is, we consider the transforms being compatible with
84-
// sws_scale as a hard requirement. If the transforms are not compatiable,
81+
// swscale as a hard requirement. If the transforms are not compatible,
8582
// then we will end up not applying the transforms, and that is wrong.
8683
//
8784
// The width requirement, however, is a soft requirement. Even if we don't
@@ -94,7 +91,7 @@ void CpuDeviceInterface::initialize(
9491
colorConversionLibrary_ = ColorConversionLibrary::SWSCALE;
9592

9693
// We established above that if the transforms are swscale compatible and
97-
// non-empty, then they must have only one transforms, and that transform is
94+
// non-empty, then they must have only one transform, and that transform is
9895
// ResizeTransform.
9996
if (!transforms.empty()) {
10097
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
@@ -207,7 +204,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
207204
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
208205
prevFiltersContext_ = std::move(filtersContext);
209206
}
210-
outputTensor = toTensor(filterGraphContext_->convert(avFrame));
207+
outputTensor = rgbAVFrameToTensor(filterGraphContext_->convert(avFrame));
211208

212209
// Similarly to above, if this check fails it means the frame wasn't
213210
// reshaped to its expected dimensions by filtergraph.
@@ -256,21 +253,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
256253
return resultHeight;
257254
}
258255

259-
torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) {
260-
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
261-
262-
int height = avFrame->height;
263-
int width = avFrame->width;
264-
std::vector<int64_t> shape = {height, width, 3};
265-
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
266-
AVFrame* avFrameClone = av_frame_clone(avFrame.get());
267-
auto deleter = [avFrameClone](void*) {
268-
UniqueAVFrame avFrameToDelete(avFrameClone);
269-
};
270-
return torch::from_blob(
271-
avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
272-
}
273-
274256
void CpuDeviceInterface::createSwsContext(
275257
const SwsFrameContext& swsFrameContext,
276258
const enum AVColorSpace colorspace) {

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class CpuDeviceInterface : public DeviceInterface {
3636
std::optional<torch::Tensor> preAllocatedOutputTensor =
3737
std::nullopt) override;
3838

39-
torch::Tensor toTensor(const UniqueAVFrame& avFrame);
40-
4139
private:
4240
int convertAVFrameToTensorUsingSwScale(
4341
const UniqueAVFrame& avFrame,

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include <mutex>
66

77
#include "src/torchcodec/_core/Cache.h"
8-
#include "src/torchcodec/_core/CpuDeviceInterface.h"
98
#include "src/torchcodec/_core/CudaDeviceInterface.h"
109
#include "src/torchcodec/_core/FFMPEGCommon.h"
1110

@@ -344,7 +343,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
344343
// handle it. We send the frame back to the CUDA device when we're done.
345344
//
346345
// TODO: Perhaps we should cache cpuInterface?
347-
auto cpuInterface = std::make_unique<CpuDeviceInterface>(torch::kCPU);
346+
auto cpuInterface = createDeviceInterface(torch::kCPU);
348347
TORCH_CHECK(
349348
cpuInterface != nullptr, "Failed to create CPU device interface");
350349
cpuInterface->initialize(
@@ -360,7 +359,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
360359
avFrame->height == outputDims_.height) {
361360
// Reason 1 above. The frame is already in the format and dimensions that
362361
// we need, we just need to convert it to a tensor.
363-
cpuFrameOutput.data = cpuInterface->toTensor(avFrame);
362+
cpuFrameOutput.data = rgbAVFrameToTensor(avFrame);
364363
} else {
365364
// Reason 2 above. We need to do a full conversion.
366365
cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,19 @@ std::unique_ptr<DeviceInterface> createDeviceInterface(
7676
return std::unique_ptr<DeviceInterface>(deviceMap[deviceType](device));
7777
}
7878

79+
torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame) {
80+
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
81+
82+
int height = avFrame->height;
83+
int width = avFrame->width;
84+
std::vector<int64_t> shape = {height, width, 3};
85+
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
86+
AVFrame* avFrameClone = av_frame_clone(avFrame.get());
87+
auto deleter = [avFrameClone](void*) {
88+
UniqueAVFrame avFrameToDelete(avFrameClone);
89+
};
90+
return torch::from_blob(
91+
avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
92+
}
93+
7994
} // namespace facebook::torchcodec

src/torchcodec/_core/DeviceInterface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,6 @@ torch::Device createTorchDevice(const std::string device);
6060
std::unique_ptr<DeviceInterface> createDeviceInterface(
6161
const torch::Device& device);
6262

63+
torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame);
64+
6365
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)