Skip to content

Commit 3626854

Browse files
committed
Make swscale and filtergraph look more similar
1 parent fb06f87 commit 3626854

File tree

2 files changed

+58
-41
lines changed

2 files changed

+58
-41
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 52 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -178,33 +178,13 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
178178

179179
auto colorConversionLibrary = getColorConversionLibrary(outputDims);
180180
torch::Tensor outputTensor;
181-
enum AVPixelFormat frameFormat =
182-
static_cast<enum AVPixelFormat>(avFrame->format);
183181

184182
if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
185-
// We need to compare the current frame context with our previous frame
186-
// context. If they are different, then we need to re-create our colorspace
187-
// conversion objects. We create our colorspace conversion objects late so
188-
// that we don't have to depend on the unreliable metadata in the header.
189-
// And we sometimes re-create them because it's possible for frame
190-
// resolution to change mid-stream. Finally, we want to reuse the colorspace
191-
// conversion objects as much as possible for performance reasons.
192-
SwsFrameContext swsFrameContext(
193-
avFrame->width,
194-
avFrame->height,
195-
frameFormat,
196-
outputDims.width,
197-
outputDims.height);
198-
199183
outputTensor = preAllocatedOutputTensor.value_or(
200184
allocateEmptyHWCTensor(outputDims, torch::kCPU));
201185

202-
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
203-
createSwsContext(swsFrameContext, avFrame->colorspace);
204-
prevSwsFrameContext_ = swsFrameContext;
205-
}
206186
int resultHeight =
207-
convertAVFrameToTensorUsingSwScale(avFrame, outputTensor);
187+
convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);
208188

209189
// If this check failed, it would mean that the frame wasn't reshaped to
210190
// the expected height.
@@ -218,23 +198,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
218198

219199
frameOutput.data = outputTensor;
220200
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
221-
FiltersContext filtersContext(
222-
avFrame->width,
223-
avFrame->height,
224-
frameFormat,
225-
avFrame->sample_aspect_ratio,
226-
outputDims.width,
227-
outputDims.height,
228-
AV_PIX_FMT_RGB24,
229-
filters_,
230-
timeBase_);
231-
232-
if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
233-
filterGraph_ =
234-
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
235-
prevFiltersContext_ = std::move(filtersContext);
236-
}
237-
outputTensor = rgbAVFrameToTensor(filterGraph_->convert(avFrame));
201+
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame, outputDims);
238202

239203
// Similarly to above, if this check fails it means the frame wasn't
240204
// reshaped to its expected dimensions by filtergraph.
@@ -267,7 +231,30 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
267231

268232
int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
269233
const UniqueAVFrame& avFrame,
270-
torch::Tensor& outputTensor) {
234+
torch::Tensor& outputTensor,
235+
const FrameDims& outputDims) {
236+
enum AVPixelFormat frameFormat =
237+
static_cast<enum AVPixelFormat>(avFrame->format);
238+
239+
// We need to compare the current frame context with our previous frame
240+
// context. If they are different, then we need to re-create our colorspace
241+
// conversion objects. We create our colorspace conversion objects late so
242+
// that we don't have to depend on the unreliable metadata in the header.
243+
// And we sometimes re-create them because it's possible for frame
244+
// resolution to change mid-stream. Finally, we want to reuse the colorspace
245+
// conversion objects as much as possible for performance reasons.
246+
SwsFrameContext swsFrameContext(
247+
avFrame->width,
248+
avFrame->height,
249+
frameFormat,
250+
outputDims.width,
251+
outputDims.height);
252+
253+
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
254+
createSwsContext(swsFrameContext, avFrame->colorspace);
255+
prevSwsFrameContext_ = swsFrameContext;
256+
}
257+
271258
uint8_t* pointers[4] = {
272259
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
273260
int expectedOutputWidth = outputTensor.sizes()[1];
@@ -293,7 +280,7 @@ void CpuDeviceInterface::createSwsContext(
293280
swsFrameContext.outputWidth,
294281
swsFrameContext.outputHeight,
295282
AV_PIX_FMT_RGB24,
296-
SWS_BILINEAR,
283+
swsFlags_,
297284
nullptr,
298285
nullptr,
299286
nullptr);
@@ -328,4 +315,29 @@ void CpuDeviceInterface::createSwsContext(
328315
swsContext_.reset(swsContext);
329316
}
330317

318+
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
319+
const UniqueAVFrame& avFrame,
320+
const FrameDims& outputDims) {
321+
enum AVPixelFormat frameFormat =
322+
static_cast<enum AVPixelFormat>(avFrame->format);
323+
324+
FiltersContext filtersContext(
325+
avFrame->width,
326+
avFrame->height,
327+
frameFormat,
328+
avFrame->sample_aspect_ratio,
329+
outputDims.width,
330+
outputDims.height,
331+
AV_PIX_FMT_RGB24,
332+
filters_,
333+
timeBase_);
334+
335+
if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
336+
filterGraph_ =
337+
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
338+
prevFiltersContext_ = std::move(filtersContext);
339+
}
340+
return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
341+
}
342+
331343
} // namespace facebook::torchcodec

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,12 @@ class CpuDeviceInterface : public DeviceInterface {
3939
private:
4040
int convertAVFrameToTensorUsingSwScale(
4141
const UniqueAVFrame& avFrame,
42-
torch::Tensor& outputTensor);
42+
torch::Tensor& outputTensor,
43+
const FrameDims& outputDims);
44+
45+
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
46+
const UniqueAVFrame& avFrame,
47+
const FrameDims& outputDims);
4348

4449
ColorConversionLibrary getColorConversionLibrary(
4550
const FrameDims& inputFrameDims) const;

0 commit comments

Comments
 (0)