Skip to content

Commit 7813005

Browse files
committed
Deal with variable resolution and lying metadata - again
1 parent 48e3ea3 commit 7813005

File tree

7 files changed

+189
-118
lines changed

7 files changed

+189
-118
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 95 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,70 @@ void CpuDeviceInterface::initialize(
5151
const VideoStreamOptions& videoStreamOptions,
5252
const std::vector<std::unique_ptr<Transform>>& transforms,
5353
const AVRational& timeBase,
54-
const FrameDims& outputDims) {
54+
[[maybe_unused]] const FrameDims& metadataDims,
55+
const std::optional<FrameDims>& resizedOutputDims) {
5556
videoStreamOptions_ = videoStreamOptions;
5657
timeBase_ = timeBase;
57-
outputDims_ = outputDims;
58-
59-
// We want to use swscale for color conversion if possible because it is
60-
// faster than filtergraph. The following are the conditions we need to meet
61-
// to use it.
58+
resizedOutputDims_ = resizedOutputDims;
6259

6360
// We can only use swscale when we have a single resize transform. Note that
6461
// this means swscale will not support the case of having several,
6562
// back-to-back resizes. There's no strong reason to even do that, but if
6663
// someone does, it's more correct to implement that with filtergraph.
67-
bool areTransformsSwScaleCompatible = transforms.empty() ||
64+
//
65+
// We calculate this value during initialization but we don't refer to it until
66+
// getColorConversionLibrary() is called. Calculating this value during
67+
// initialization saves us from having to save all of the transforms.
68+
areTransformsSwScaleCompatible_ = transforms.empty() ||
6869
(transforms.size() == 1 && transforms[0]->isResize());
6970

70-
// swscale requires widths to be multiples of 32:
71-
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
72-
bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0;
73-
7471
// Note that we do not expose this capability in the public API, only through
7572
// the core API.
76-
bool userRequestedSwScale = videoStreamOptions_.colorConversionLibrary ==
73+
//
74+
// Same as above, we calculate this value during initialization and refer to
75+
// it in getColorConversionLibrary().
76+
userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
7777
ColorConversionLibrary::SWSCALE;
7878

79+
// We can only use swscale when we have a single resize transform. Note that
80+
// we decide whether or not to actually use swscale at the last
81+
// possible moment, when we actually convert the frame. This is because we
82+
// need to know the actual frame dimensions.
83+
if (transforms.size() == 1 && transforms[0]->isResize()) {
84+
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
85+
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
86+
swsFlags_ = resize->getSwsFlags();
87+
}
88+
89+
// If we have any transforms, replace filters_ with the filter strings from
90+
// the transforms. As noted above, we decide between swscale and filtergraph
91+
// when we actually decode a frame.
92+
std::stringstream filters;
93+
bool first = true;
94+
for (const auto& transform : transforms) {
95+
if (!first) {
96+
filters << ",";
97+
}
98+
filters << transform->getFilterGraphCpu();
99+
first = false;
100+
}
101+
if (!transforms.empty()) {
102+
filters_ = filters.str();
103+
}
104+
105+
initialized_ = true;
106+
}
107+
108+
ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
109+
const FrameDims& outputDims) {
110+
// swscale requires widths to be multiples of 32:
111+
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
112+
bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;
113+
114+
// We want to use swscale for color conversion if possible because it is
115+
// faster than filtergraph. The following are the conditions we need to meet
116+
// to use it.
117+
//
79118
// Note that we treat the transform limitation differently from the width
80119
// limitation. That is, we consider the transforms being compatible with
81120
// swscale as a hard requirement. If the transforms are not compatible,
@@ -86,38 +125,12 @@ void CpuDeviceInterface::initialize(
86125
// behavior. Since we don't expose the ability to choose swscale or
87126
// filtergraph in our public API, this is probably okay. It's also the only
88127
// way that we can be certain we are testing one versus the other.
89-
if (areTransformsSwScaleCompatible &&
90-
(userRequestedSwScale || isWidthSwScaleCompatible)) {
91-
colorConversionLibrary_ = ColorConversionLibrary::SWSCALE;
92-
93-
// We established above that if the transforms are swscale compatible and
94-
// non-empty, then they must have only one transform, and that transform is
95-
// ResizeTransform.
96-
if (!transforms.empty()) {
97-
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
98-
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
99-
swsFlags_ = resize->getSwsFlags();
100-
}
128+
if (areTransformsSwScaleCompatible_ &&
129+
(userRequestedSwScale_ || isWidthSwScaleCompatible)) {
130+
return ColorConversionLibrary::SWSCALE;
101131
} else {
102-
colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH;
103-
104-
// If we have any transforms, replace filters_ with the filter strings from
105-
// the transforms.
106-
std::stringstream filters;
107-
bool first = true;
108-
for (const auto& transform : transforms) {
109-
if (!first) {
110-
filters << ",";
111-
}
112-
filters << transform->getFilterGraphCpu();
113-
first = false;
114-
}
115-
if (!transforms.empty()) {
116-
filters_ = filters.str();
117-
}
132+
return ColorConversionLibrary::FILTERGRAPH;
118133
}
119-
120-
initialized_ = true;
121134
}
122135

123136
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
@@ -134,24 +147,42 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
134147
FrameOutput& frameOutput,
135148
std::optional<torch::Tensor> preAllocatedOutputTensor) {
136149
TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
150+
151+
// Note that we ignore the dimensions from the metadata; we don't even bother
152+
// storing them. The resized dimensions take priority. If we don't have any,
153+
// then we use the dimensions from the actual decoded frame. We use the actual
154+
// decoded frame and not the metadata for two reasons:
155+
//
156+
// 1. Metadata may be wrong. If we have access to more accurate information, we
157+
// should use it.
158+
// 2. Video streams can have variable resolution. This fact is not captured
159+
// in the stream metadata.
160+
//
161+
// Both cases cause problems for our batch APIs, as we allocate
162+
// FrameBatchOutputs based on the stream metadata. But single-frame APIs
163+
// can still work in such situations, so they should.
164+
auto outputDims =
165+
resizedOutputDims_.value_or(FrameDims(avFrame->width, avFrame->height));
166+
137167
if (preAllocatedOutputTensor.has_value()) {
138168
auto shape = preAllocatedOutputTensor.value().sizes();
139169
TORCH_CHECK(
140-
(shape.size() == 3) && (shape[0] == outputDims_.height) &&
141-
(shape[1] == outputDims_.width) && (shape[2] == 3),
170+
(shape.size() == 3) && (shape[0] == outputDims.height) &&
171+
(shape[1] == outputDims.width) && (shape[2] == 3),
142172
"Expected pre-allocated tensor of shape ",
143-
outputDims_.height,
173+
outputDims.height,
144174
"x",
145-
outputDims_.width,
175+
outputDims.width,
146176
"x3, got ",
147177
shape);
148178
}
149179

180+
auto colorConversionLibrary = getColorConversionLibrary(outputDims);
150181
torch::Tensor outputTensor;
151182
enum AVPixelFormat frameFormat =
152183
static_cast<enum AVPixelFormat>(avFrame->format);
153184

154-
if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) {
185+
if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
155186
// We need to compare the current frame context with our previous frame
156187
// context. If they are different, then we need to re-create our colorspace
157188
// conversion objects. We create our colorspace conversion objects late so
@@ -163,11 +194,11 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
163194
avFrame->width,
164195
avFrame->height,
165196
frameFormat,
166-
outputDims_.width,
167-
outputDims_.height);
197+
outputDims.width,
198+
outputDims.height);
168199

169200
outputTensor = preAllocatedOutputTensor.value_or(
170-
allocateEmptyHWCTensor(outputDims_, torch::kCPU));
201+
allocateEmptyHWCTensor(outputDims, torch::kCPU));
171202

172203
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
173204
createSwsContext(swsFrameContext, avFrame->colorspace);
@@ -180,42 +211,42 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
180211
// the expected height.
181212
// TODO: Can we do the same check for width?
182213
TORCH_CHECK(
183-
resultHeight == outputDims_.height,
184-
"resultHeight != outputDims_.height: ",
214+
resultHeight == outputDims.height,
215+
"resultHeight != outputDims.height: ",
185216
resultHeight,
186217
" != ",
187-
outputDims_.height);
218+
outputDims.height);
188219

189220
frameOutput.data = outputTensor;
190-
} else if (colorConversionLibrary_ == ColorConversionLibrary::FILTERGRAPH) {
221+
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
191222
FiltersContext filtersContext(
192223
avFrame->width,
193224
avFrame->height,
194225
frameFormat,
195226
avFrame->sample_aspect_ratio,
196-
outputDims_.width,
197-
outputDims_.height,
227+
outputDims.width,
228+
outputDims.height,
198229
AV_PIX_FMT_RGB24,
199230
filters_,
200231
timeBase_);
201232

202-
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
203-
filterGraphContext_ =
233+
if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
234+
filterGraph_ =
204235
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
205236
prevFiltersContext_ = std::move(filtersContext);
206237
}
207-
outputTensor = rgbAVFrameToTensor(filterGraphContext_->convert(avFrame));
238+
outputTensor = rgbAVFrameToTensor(filterGraph_->convert(avFrame));
208239

209240
// Similarly to above, if this check fails it means the frame wasn't
210241
// reshaped to its expected dimensions by filtergraph.
211242
auto shape = outputTensor.sizes();
212243
TORCH_CHECK(
213-
(shape.size() == 3) && (shape[0] == outputDims_.height) &&
214-
(shape[1] == outputDims_.width) && (shape[2] == 3),
244+
(shape.size() == 3) && (shape[0] == outputDims.height) &&
245+
(shape[1] == outputDims.width) && (shape[2] == 3),
215246
"Expected output tensor of shape ",
216-
outputDims_.height,
247+
outputDims.height,
217248
"x",
218-
outputDims_.width,
249+
outputDims.width,
219250
"x3, got ",
220251
shape);
221252

@@ -231,7 +262,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
231262
TORCH_CHECK(
232263
false,
233264
"Invalid color conversion library: ",
234-
static_cast<int>(colorConversionLibrary_));
265+
static_cast<int>(colorConversionLibrary));
235266
}
236267
}
237268

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class CpuDeviceInterface : public DeviceInterface {
2828
const VideoStreamOptions& videoStreamOptions,
2929
const std::vector<std::unique_ptr<Transform>>& transforms,
3030
const AVRational& timeBase,
31-
const FrameDims& outputDims) override;
31+
[[maybe_unused]] const FrameDims& metadataDims,
32+
const std::optional<FrameDims>& resizedOutputDims) override;
3233

3334
void convertAVFrameToFrameOutput(
3435
UniqueAVFrame& avFrame,
@@ -41,6 +42,9 @@ class CpuDeviceInterface : public DeviceInterface {
4142
const UniqueAVFrame& avFrame,
4243
torch::Tensor& outputTensor);
4344

45+
ColorConversionLibrary getColorConversionLibrary(
46+
const FrameDims& inputFrameDims);
47+
4448
struct SwsFrameContext {
4549
int inputWidth = 0;
4650
int inputHeight = 0;
@@ -64,28 +68,44 @@ class CpuDeviceInterface : public DeviceInterface {
6468
const enum AVColorSpace colorspace);
6569

6670
VideoStreamOptions videoStreamOptions_;
67-
ColorConversionLibrary colorConversionLibrary_;
6871
AVRational timeBase_;
69-
FrameDims outputDims_;
70-
71-
// If we use swscale for resizing, the flags control the resizing algorithm.
72-
// We default to bilinear. Users can override this with a ResizeTransform.
73-
int swsFlags_ = SWS_BILINEAR;
72+
std::optional<FrameDims> resizedOutputDims_;
73+
74+
// Color-conversion objects. Only one of filterGraph_ and swsContext_ should
75+
// be non-null. Which one we use is controlled by colorConversionLibrary_.
76+
//
77+
// Creating both filterGraph_ and swsContext_ is relatively expensive, so we
78+
// reuse them across frames. However, it is possible that subsequent frames
79+
// are different enough (change in dimensions) that we can't reuse the color
80+
// conversion object. We store the relevant frame context from the frame used
81+
// to create the object last time. We always compare the current frame's info
82+
// against the previous one to determine if we need to recreate the color
83+
// conversion object.
84+
//
85+
// TODO: The names of these fields are confusing, as the actual color
86+
// conversion object for Sws has "context" in the name, and we use
87+
// "context" for the structs we store to know if we need to recreate a
88+
// color conversion object. We should clean that up.
89+
std::unique_ptr<FilterGraph> filterGraph_;
90+
FiltersContext prevFiltersContext_;
91+
UniqueSwsContext swsContext_;
92+
SwsFrameContext prevSwsFrameContext_;
7493

75-
// The copy filter just copies the input to the output. Computationally, it
76-
// should be a no-op. If we get no user-provided transforms, we will use the
77-
// copy filter.
94+
// The filter we supply to filterGraph_, if it is used. The copy filter just
95+
// copies the input to the output. Computationally, it should be a no-op. If
96+
// we get no user-provided transforms, we will use the copy filter. Otherwise,
97+
// we will construct the string from the transforms.
7898
std::string filters_ = "copy";
7999

80-
// color-conversion fields. Only one of FilterGraphContext and
81-
// UniqueSwsContext should be non-null.
82-
std::unique_ptr<FilterGraph> filterGraphContext_;
83-
UniqueSwsContext swsContext_;
100+
// The flags we supply to swsContext_, if it is used. The flags control the
101+
// resizing algorithm. We default to bilinear. Users can override this with a
102+
// ResizeTransform.
103+
int swsFlags_ = SWS_BILINEAR;
84104

85-
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
86-
// be created before decoding a new frame.
87-
SwsFrameContext prevSwsFrameContext_;
88-
FiltersContext prevFiltersContext_;
105+
// Values set during initialization and referred to in
106+
// getColorConversionLibrary().
107+
bool areTransformsSwScaleCompatible_;
108+
bool userRequestedSwScale_;
89109

90110
bool initialized_ = false;
91111
};

0 commit comments

Comments
 (0)