Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 279 additions & 0 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,4 +511,283 @@ void AudioEncoder::flushBuffers() {

encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
}

namespace {

// Checks that `frames` is a uint8 tensor with exactly 4 dimensions whose
// second dimension has size 3, then returns a contiguous version of it.
// Each violated requirement is reported through TORCH_CHECK.
torch::Tensor validateFrames(const torch::Tensor& frames) {
  const auto dtype = frames.dtype();
  TORCH_CHECK(
      dtype == torch::kUInt8, "frames must have uint8 dtype, got ", dtype);

  const auto numDims = frames.dim();
  TORCH_CHECK(
      numDims == 4,
      "frames must have 4 dimensions (N, C, H, W), got ",
      numDims);

  // Dimension 1 is only read once the rank check above has passed.
  const auto numChannels = frames.sizes()[1];
  TORCH_CHECK(
      numChannels == 3,
      "frame must have 3 channels (R, G, B), got ",
      numChannels);

  // TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
  return frames.contiguous();
}

} // namespace

// Flushes and closes the output I/O context if one is still attached to the
// format context, then clears the pointer stored on the format context.
VideoEncoder::~VideoEncoder() {
  if (!avFormatContext_) {
    return;
  }

  auto& ioContext = avFormatContext_->pb;
  if (!ioContext) {
    return;
  }

  avio_flush(ioContext);
  avio_close(ioContext);
  ioContext = nullptr;
}

// Prepares an encoder that writes `frames` to the file `fileName`.
//
// frames: checked by validateFrames(); a contiguous version is stored.
// frameRate: input frame rate, must be > 0.
// fileName: destination path. It is handed to FFmpeg both to pick the output
//     format and to open the file for writing.
// videoStreamOptions: encoding options forwarded to initializeEncoder().
//
// Errors are reported through TORCH_CHECK: invalid frames or frame rate, a
// format context that cannot be allocated for `fileName`, a destination that
// cannot be opened, or any failure raised by initializeEncoder().
VideoEncoder::VideoEncoder(
    const torch::Tensor& frames,
    int frameRate,
    std::string_view fileName,
    const VideoStreamOptions& videoStreamOptions)
    : frames_(validateFrames(frames)), inFrameRate_(frameRate) {
  // Reject an unusable frame rate before any file is created on disk: it is
  // used as the denominator of the encoder time base.
  TORCH_CHECK(frameRate > 0, "frame_rate=", frameRate, " must be > 0.");

  setFFmpegLogLevel();

  // A std::string_view is not guaranteed to be null-terminated, but the FFmpeg
  // functions below take C strings. Own a terminated copy for those calls.
  const std::string fileNameStr(fileName);

  // Allocate output format context
  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, nullptr, fileNameStr.c_str());

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "The destination file is ",
      fileName,
      ", check the desired extension? ",
      getFFMPEGErrorStringFromErrorCode(status));
  avFormatContext_.reset(avFormatContext);

  status =
      avio_open(&avFormatContext_->pb, fileNameStr.c_str(), AVIO_FLAG_WRITE);
  TORCH_CHECK(
      status >= 0,
      "avio_open failed. The destination file is ",
      fileName,
      ", make sure it's a valid path? ",
      getFFMPEGErrorStringFromErrorCode(status));
  // TODO-VideoEncoder: Add tests for above fileName related checks

  // ~VideoEncoder() does not run when the constructor throws, so the I/O
  // context opened above has to be closed here if the encoder setup fails.
  try {
    initializeEncoder(videoStreamOptions);
  } catch (...) {
    avio_close(avFormatContext_->pb);
    avFormatContext_->pb = nullptr;
    throw;
  }
}

void VideoEncoder::initializeEncoder(
const VideoStreamOptions& videoStreamOptions) {
const AVCodec* avCodec =
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
TORCH_CHECK(avCodec != nullptr, "Video codec not found");

AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
avCodecContext_.reset(avCodecContext);

// Set encoding options
// TODO-VideoEncoder: Allow bitrate to be set
std::optional<int> desiredBitRate = videoStreamOptions.bitRate;
if (desiredBitRate.has_value()) {
TORCH_CHECK(
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
}
avCodecContext_->bit_rate = desiredBitRate.value_or(0);

// Store dimension order and input pixel format
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
auto sizes = frames_.sizes();
inPixelFormat_ = AV_PIX_FMT_GBRP;
inHeight_ = sizes[2];
inWidth_ = sizes[3];

// Use specified dimensions or input dimensions
// TODO-VideoEncoder: Allow height and width to be set
outWidth_ = videoStreamOptions.width.value_or(inWidth_);
outHeight_ = videoStreamOptions.height.value_or(inHeight_);

// Use YUV420P as default output format
// TODO-VideoEncoder: Enable other pixel formats
outPixelFormat_ = AV_PIX_FMT_YUV420P;

// Configure codec parameters
avCodecContext_->codec_id = avCodec->id;
avCodecContext_->width = outWidth_;
avCodecContext_->height = outHeight_;
avCodecContext_->pix_fmt = outPixelFormat_;
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
avCodecContext_->time_base = {1, inFrameRate_};
avCodecContext_->framerate = {inFrameRate_, 1};

// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
if (videoStreamOptions.gopSize.has_value()) {
avCodecContext_->gop_size = *videoStreamOptions.gopSize;
} else {
avCodecContext_->gop_size = 12; // Default GOP size
}

if (videoStreamOptions.maxBFrames.has_value()) {
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames;
} else {
avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression
}

int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
TORCH_CHECK(
status == AVSUCCESS,
"avcodec_open2 failed: ",
getFFMPEGErrorStringFromErrorCode(status));

AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");

// Set the stream time base to encode correct frame timestamps
avStream->time_base = avCodecContext_->time_base;
status = avcodec_parameters_from_context(
avStream->codecpar, avCodecContext_.get());
TORCH_CHECK(
status == AVSUCCESS,
"avcodec_parameters_from_context failed: ",
getFFMPEGErrorStringFromErrorCode(status));
streamIndex_ = avStream->index;
}

void VideoEncoder::encode() {
// To be on the safe side we enforce that encode() can only be called once
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
encodeWasCalled_ = true;

int status = avformat_write_header(avFormatContext_.get(), nullptr);
TORCH_CHECK(
status == AVSUCCESS,
"Error in avformat_write_header: ",
getFFMPEGErrorStringFromErrorCode(status));

AutoAVPacket autoAVPacket;
int numFrames = frames_.sizes()[0];
for (int i = 0; i < numFrames; ++i) {
torch::Tensor currFrame = frames_[i];
UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i);
encodeFrame(autoAVPacket, avFrame);
}

flushBuffers();

status = av_write_trailer(avFormatContext_.get());
TORCH_CHECK(
status == AVSUCCESS,
"Error in av_write_trailer: ",
getFFMPEGErrorStringFromErrorCode(status));
}

// Converts one frame tensor into an AVFrame in the output pixel format and
// output dimensions, using a lazily created and cached scaling context.
//
// frame: a single frame; assumed to be a contiguous (3, H, W) uint8 tensor
//     with channels in R, G, B order, as produced by indexing the validated
//     frames tensor.
// frameIndex: stored as the pts of the returned frame.
//
// Returns the newly allocated, converted frame. Allocation or conversion
// failures are reported through TORCH_CHECK.
UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
    const torch::Tensor& frame,
    int frameIndex) {
  // Initialize and cache scaling context if it does not exist
  if (!swsContext_) {
    swsContext_.reset(sws_getContext(
        inWidth_,
        inHeight_,
        inPixelFormat_,
        outWidth_,
        outHeight_,
        outPixelFormat_,
        SWS_BILINEAR,
        nullptr,
        nullptr,
        nullptr));
    TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context");
  }

  // Destination frame: owns its buffers and is what gets sent to the encoder.
  UniqueAVFrame outputFrame(av_frame_alloc());
  TORCH_CHECK(outputFrame != nullptr, "Failed to allocate AVFrame");
  outputFrame->format = outPixelFormat_;
  outputFrame->width = outWidth_;
  outputFrame->height = outHeight_;
  outputFrame->pts = frameIndex;

  int status = av_frame_get_buffer(outputFrame.get(), 0);
  TORCH_CHECK(status >= 0, "Failed to allocate frame buffer");

  // Source frame: only describes the tensor's memory, it does not own it.
  UniqueAVFrame inputFrame(av_frame_alloc());
  TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame");
  inputFrame->format = inPixelFormat_;
  inputFrame->width = inWidth_;
  inputFrame->height = inHeight_;

  // TODO-VideoEncoder: Reorder tensor if in NHWC format
  // The tensor stores its planes as R, G, B while AV_PIX_FMT_GBRP expects
  // them as G, B, R, so the plane pointers are permuted accordingly.
  // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format
  const int planeSize = inHeight_ * inWidth_;
  uint8_t* const redPlane = static_cast<uint8_t*>(frame.data_ptr());
  uint8_t* const greenPlane = redPlane + planeSize;
  uint8_t* const bluePlane = redPlane + (2 * planeSize);
  inputFrame->data[0] = greenPlane;
  inputFrame->data[1] = bluePlane;
  inputFrame->data[2] = redPlane;

  // Every plane has one byte per pixel, so a row is inWidth_ bytes long.
  for (int plane = 0; plane < 3; ++plane) {
    inputFrame->linesize[plane] = inWidth_;
  }

  status = sws_scale(
      swsContext_.get(),
      inputFrame->data,
      inputFrame->linesize,
      0,
      inputFrame->height,
      outputFrame->data,
      outputFrame->linesize);
  TORCH_CHECK(status == outHeight_, "sws_scale failed");
  return outputFrame;
}

// Sends one frame to the encoder and writes every packet the encoder makes
// available to the output.
//
// autoAVPacket: packet storage reused for each received packet.
// avFrame: frame to encode. A null frame signals end of input (this is how
//     flushBuffers() drains the encoder).
//
// Returns once the encoder needs more input, or once it reports end of stream,
// in which case the muxer's buffered packets are flushed as well. Failures are
// reported through TORCH_CHECK.
void VideoEncoder::encodeFrame(
    AutoAVPacket& autoAVPacket,
    const UniqueAVFrame& avFrame) {
  auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Error while sending frame: ",
      getFFMPEGErrorStringFromErrorCode(status));

  while (true) {
    ReferenceAVPacket packet(autoAVPacket);
    status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
    if (status == AVERROR(EAGAIN)) {
      // The encoder needs more input before it can output another packet.
      return;
    }
    if (status == AVERROR_EOF) {
      // Flush remaining buffered packets
      status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
      TORCH_CHECK(
          status == AVSUCCESS,
          "Failed to flush packet: ",
          getFFMPEGErrorStringFromErrorCode(status));
      return;
    }
    TORCH_CHECK(
        status >= 0,
        "Error receiving packet: ",
        getFFMPEGErrorStringFromErrorCode(status));

    // Packets come out of the encoder with timestamps in the codec time base.
    // The muxer may have replaced the stream time base when the header was
    // written, so convert to the stream's actual time base before writing.
    const AVStream* avStream = avFormatContext_->streams[streamIndex_];
    av_packet_rescale_ts(
        packet.get(), avCodecContext_->time_base, avStream->time_base);

    packet->stream_index = streamIndex_;

    status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
    TORCH_CHECK(
        status == AVSUCCESS,
        "Error in av_interleaved_write_frame: ",
        getFFMPEGErrorStringFromErrorCode(status));
  }
}

void VideoEncoder::flushBuffers() {
AutoAVPacket autoAVPacket;
// Send null frame to signal end of input
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
}

} // namespace facebook::torchcodec
42 changes: 41 additions & 1 deletion src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class AudioEncoder {
bool encodeWasCalled_ = false;
int64_t lastEncodedAVFramePts_ = 0;
};
} // namespace facebook::torchcodec

/* clang-format off */
//
Expand Down Expand Up @@ -121,3 +120,44 @@ class AudioEncoder {
//
//
/* clang-format on */

// Encodes a tensor of video frames into a file.
//
// Usage: construct with the frames, frame rate, destination file name and
// stream options, then call encode() exactly once.
class VideoEncoder {
 public:
  // Flushes and closes the output I/O context if it is still open.
  ~VideoEncoder();

  // frames: uint8 tensor with 4 dimensions (N, C, H, W) and 3 channels.
  // frameRate: frame rate of the input frames.
  // fileName: destination file.
  // videoStreamOptions: encoding options (bit rate, GOP size, max B-frames,
  //     output width and height).
  VideoEncoder(
      const torch::Tensor& frames,
      int frameRate,
      std::string_view fileName,
      const VideoStreamOptions& videoStreamOptions);

  // Encodes all frames and writes them to the destination. Calling it a
  // second time is an error.
  void encode();

 private:
  // Configures and opens the codec context and creates the output stream.
  void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
  // Converts a single frame tensor to an AVFrame in the output pixel format;
  // frameIndex becomes the frame's pts.
  UniqueAVFrame convertTensorToAVFrame(
      const torch::Tensor& frame,
      int frameIndex);
  // Sends a frame to the encoder and writes the resulting packets.
  void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame);
  // Signals end of input to the encoder and writes the remaining packets.
  void flushBuffers();

  UniqueEncodingAVFormatContext avFormatContext_;
  UniqueAVCodecContext avCodecContext_;
  // Index of the output stream; -1 until initializeEncoder() has created it.
  int streamIndex_ = -1;
  // Scaling/conversion context, created on first use.
  UniqueSwsContext swsContext_;

  // Validated, contiguous copy of the frames passed to the constructor.
  const torch::Tensor frames_;
  // Input frame rate; always set by the constructor.
  int inFrameRate_ = -1;

  // Input geometry and pixel format, set by initializeEncoder().
  int inWidth_ = -1;
  int inHeight_ = -1;
  AVPixelFormat inPixelFormat_ = AV_PIX_FMT_NONE;

  // Output geometry and pixel format, set by initializeEncoder().
  int outWidth_ = -1;
  int outHeight_ = -1;
  AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;

  bool encodeWasCalled_ = false;
};

} // namespace facebook::torchcodec
5 changes: 5 additions & 0 deletions src/torchcodec/_core/StreamOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ struct VideoStreamOptions {
std::optional<ColorConversionLibrary> colorConversionLibrary;
// By default we use CPU for decoding for both C++ and python users.
torch::Device device = torch::kCPU;

// Encoding options
std::optional<int> bitRate;
std::optional<int> gopSize;
std::optional<int> maxBFrames;
};

struct AudioStreamOptions {
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
encode_audio_to_file,
encode_audio_to_file_like,
encode_audio_to_tensor,
encode_video_to_file,
get_ffmpeg_library_versions,
get_frame_at_index,
get_frame_at_pts,
Expand Down
16 changes: 16 additions & 0 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
m.def(
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
m.def(
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
m.def(
Expand Down Expand Up @@ -397,6 +399,19 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
return makeOpsAudioFramesOutput(result);
}

// Custom-op entry point: encodes `frames` to the file `file_name` using
// default video stream options.
//
// frame_rate arrives as int64 and is converted with validateInt64ToInt()
// (presumably a range-checked narrowing to int; see that helper).
void encode_video_to_file(
    const at::Tensor& frames,
    int64_t frame_rate,
    std::string_view file_name) {
  VideoStreamOptions videoStreamOptions;
  const int frameRate = validateInt64ToInt(frame_rate, "frame_rate");
  VideoEncoder encoder(frames, frameRate, file_name, videoStreamOptions);
  encoder.encode();
}

void encode_audio_to_file(
const at::Tensor& samples,
int64_t sample_rate,
Expand Down Expand Up @@ -701,6 +716,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {

TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("encode_audio_to_file", &encode_audio_to_file);
m.impl("encode_video_to_file", &encode_video_to_file);
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
m.impl("seek_to_pts", &seek_to_pts);
m.impl("add_video_stream", &add_video_stream);
Expand Down
Loading
Loading