diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9e6b073ad..a7fc567cd 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -511,4 +511,283 @@ void AudioEncoder::flushBuffers() { encodeFrame(autoAVPacket, UniqueAVFrame(nullptr)); } + +namespace { + +torch::Tensor validateFrames(const torch::Tensor& frames) { + TORCH_CHECK( + frames.dtype() == torch::kUInt8, + "frames must have uint8 dtype, got ", + frames.dtype()); + TORCH_CHECK( + frames.dim() == 4, + "frames must have 4 dimensions (N, C, H, W), got ", + frames.dim()); + TORCH_CHECK( + frames.sizes()[1] == 3, + "frame must have 3 channels (R, G, B), got ", + frames.sizes()[1]); + // TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted + return frames.contiguous(); +} + +} // namespace + +VideoEncoder::~VideoEncoder() { + if (avFormatContext_ && avFormatContext_->pb) { + avio_flush(avFormatContext_->pb); + avio_close(avFormatContext_->pb); + avFormatContext_->pb = nullptr; + } +} + +VideoEncoder::VideoEncoder( + const torch::Tensor& frames, + int frameRate, + std::string_view fileName, + const VideoStreamOptions& videoStreamOptions) + : frames_(validateFrames(frames)), inFrameRate_(frameRate) { + setFFmpegLogLevel(); + + // Allocate output format context + AVFormatContext* avFormatContext = nullptr; + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, nullptr, fileName.data()); + + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "The destination file is ", + fileName, + ", check the desired extension? ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + TORCH_CHECK( + status >= 0, + "avio_open failed. The destination file is ", + fileName, + ", make sure it's a valid path? 
", + getFFMPEGErrorStringFromErrorCode(status)); + // TODO-VideoEncoder: Add tests for above fileName related checks + + initializeEncoder(videoStreamOptions); +} + +void VideoEncoder::initializeEncoder( + const VideoStreamOptions& videoStreamOptions) { + const AVCodec* avCodec = + avcodec_find_encoder(avFormatContext_->oformat->video_codec); + TORCH_CHECK(avCodec != nullptr, "Video codec not found"); + + AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); + TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); + avCodecContext_.reset(avCodecContext); + + // Set encoding options + // TODO-VideoEncoder: Allow bitrate to be set + std::optional desiredBitRate = videoStreamOptions.bitRate; + if (desiredBitRate.has_value()) { + TORCH_CHECK( + *desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0."); + } + avCodecContext_->bit_rate = desiredBitRate.value_or(0); + + // Store dimension order and input pixel format + // TODO-VideoEncoder: Remove assumption that tensor in NCHW format + auto sizes = frames_.sizes(); + inPixelFormat_ = AV_PIX_FMT_GBRP; + inHeight_ = sizes[2]; + inWidth_ = sizes[3]; + + // Use specified dimensions or input dimensions + // TODO-VideoEncoder: Allow height and width to be set + outWidth_ = videoStreamOptions.width.value_or(inWidth_); + outHeight_ = videoStreamOptions.height.value_or(inHeight_); + + // Use YUV420P as default output format + // TODO-VideoEncoder: Enable other pixel formats + outPixelFormat_ = AV_PIX_FMT_YUV420P; + + // Configure codec parameters + avCodecContext_->codec_id = avCodec->id; + avCodecContext_->width = outWidth_; + avCodecContext_->height = outHeight_; + avCodecContext_->pix_fmt = outPixelFormat_; + // TODO-VideoEncoder: Verify that frame_rate and time_base are correct + avCodecContext_->time_base = {1, inFrameRate_}; + avCodecContext_->framerate = {inFrameRate_, 1}; + + // TODO-VideoEncoder: Allow GOP size and max B-frames to be set + if 
(videoStreamOptions.gopSize.has_value()) { + avCodecContext_->gop_size = *videoStreamOptions.gopSize; + } else { + avCodecContext_->gop_size = 12; // Default GOP size + } + + if (videoStreamOptions.maxBFrames.has_value()) { + avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames; + } else { + avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression + } + + int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "avcodec_open2 failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + + AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr); + TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); + + // Set the stream time base to encode correct frame timestamps + avStream->time_base = avCodecContext_->time_base; + status = avcodec_parameters_from_context( + avStream->codecpar, avCodecContext_.get()); + TORCH_CHECK( + status == AVSUCCESS, + "avcodec_parameters_from_context failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + streamIndex_ = avStream->index; +} + +void VideoEncoder::encode() { + // To be on the safe side we enforce that encode() can only be called once + TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); + encodeWasCalled_ = true; + + int status = avformat_write_header(avFormatContext_.get(), nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "Error in avformat_write_header: ", + getFFMPEGErrorStringFromErrorCode(status)); + + AutoAVPacket autoAVPacket; + int numFrames = frames_.sizes()[0]; + for (int i = 0; i < numFrames; ++i) { + torch::Tensor currFrame = frames_[i]; + UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i); + encodeFrame(autoAVPacket, avFrame); + } + + flushBuffers(); + + status = av_write_trailer(avFormatContext_.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Error in av_write_trailer: ", + getFFMPEGErrorStringFromErrorCode(status)); +} + +UniqueAVFrame VideoEncoder::convertTensorToAVFrame( + const 
torch::Tensor& frame, + int frameIndex) { + // Initialize and cache scaling context if it does not exist + if (!swsContext_) { + swsContext_.reset(sws_getContext( + inWidth_, + inHeight_, + inPixelFormat_, + outWidth_, + outHeight_, + outPixelFormat_, + SWS_BILINEAR, + nullptr, + nullptr, + nullptr)); + TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); + } + + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + + // Set output frame properties + avFrame->format = outPixelFormat_; + avFrame->width = outWidth_; + avFrame->height = outHeight_; + avFrame->pts = frameIndex; + + int status = av_frame_get_buffer(avFrame.get(), 0); + TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); + + // Need to convert/scale the frame + // Create temporary frame with input format + UniqueAVFrame inputFrame(av_frame_alloc()); + TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); + + inputFrame->format = inPixelFormat_; + inputFrame->width = inWidth_; + inputFrame->height = inHeight_; + + uint8_t* tensorData = static_cast<uint8_t*>(frame.data_ptr()); + + // TODO-VideoEncoder: Reorder tensor if in NHWC format + int channelSize = inHeight_ * inWidth_; + // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format + // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format + inputFrame->data[0] = tensorData + channelSize; + inputFrame->data[1] = tensorData + (2 * channelSize); + inputFrame->data[2] = tensorData; + + inputFrame->linesize[0] = inWidth_; + inputFrame->linesize[1] = inWidth_; + inputFrame->linesize[2] = inWidth_; + + status = sws_scale( + swsContext_.get(), + inputFrame->data, + inputFrame->linesize, + 0, + inputFrame->height, + avFrame->data, + avFrame->linesize); + TORCH_CHECK(status == outHeight_, "sws_scale failed"); + return avFrame; +} + +void VideoEncoder::encodeFrame( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame) { + auto status = 
avcodec_send_frame(avCodecContext_.get(), avFrame.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Error while sending frame: ", + getFFMPEGErrorStringFromErrorCode(status)); + + while (true) { + ReferenceAVPacket packet(autoAVPacket); + status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); + if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { + if (status == AVERROR_EOF) { + // Flush remaining buffered packets + status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "Failed to flush packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + } + return; + } + TORCH_CHECK( + status >= 0, + "Error receiving packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + + packet->stream_index = streamIndex_; + + status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Error in av_interleaved_write_frame: ", + getFFMPEGErrorStringFromErrorCode(status)); + } +} + +void VideoEncoder::flushBuffers() { + AutoAVPacket autoAVPacket; + // Send null frame to signal end of input + encodeFrame(autoAVPacket, UniqueAVFrame(nullptr)); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 1f4bbb5d6..b9b0f4f31 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -57,7 +57,6 @@ class AudioEncoder { bool encodeWasCalled_ = false; int64_t lastEncodedAVFramePts_ = 0; }; -} // namespace facebook::torchcodec /* clang-format off */ // @@ -121,3 +120,44 @@ class AudioEncoder { // // /* clang-format on */ + +class VideoEncoder { + public: + ~VideoEncoder(); + + VideoEncoder( + const torch::Tensor& frames, + int frameRate, + std::string_view fileName, + const VideoStreamOptions& videoStreamOptions); + + void encode(); + + private: + void initializeEncoder(const VideoStreamOptions& videoStreamOptions); + UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& 
frame, + int frameIndex); + void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); + void flushBuffers(); + + UniqueEncodingAVFormatContext avFormatContext_; + UniqueAVCodecContext avCodecContext_; + int streamIndex_; + UniqueSwsContext swsContext_; + + const torch::Tensor frames_; + int inFrameRate_; + + int inWidth_ = -1; + int inHeight_ = -1; + AVPixelFormat inPixelFormat_ = AV_PIX_FMT_NONE; + + int outWidth_ = -1; + int outHeight_ = -1; + AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; + + bool encodeWasCalled_ = false; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index d600aa0ac..19cc5126c 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -38,6 +38,11 @@ struct VideoStreamOptions { std::optional<ColorConversionLibrary> colorConversionLibrary; // By default we use CPU for decoding for both C++ and python users. torch::Device device = torch::kCPU; + + // Encoding options + std::optional<int64_t> bitRate; + std::optional<int64_t> gopSize; + std::optional<int64_t> maxBFrames; }; struct AudioStreamOptions { diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 3d340bff8..24e54af0e 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -25,6 +25,7 @@ encode_audio_to_file, encode_audio_to_file_like, encode_audio_to_tensor, + encode_video_to_file, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index c646ed54a..fcf8e8cca 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -31,6 +31,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? 
desired_sample_rate=None) -> ()"); + m.def( + "encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()"); m.def( "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor"); m.def( @@ -397,6 +399,19 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( return makeOpsAudioFramesOutput(result); } +void encode_video_to_file( + const at::Tensor& frames, + int64_t frame_rate, + std::string_view file_name) { + VideoStreamOptions videoStreamOptions; + VideoEncoder( + frames, + validateInt64ToInt(frame_rate, "frame_rate"), + file_name, + videoStreamOptions) + .encode(); +} + void encode_audio_to_file( const at::Tensor& samples, int64_t sample_rate, @@ -701,6 +716,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("encode_audio_to_file", &encode_audio_to_file); + m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_audio_to_tensor", &encode_audio_to_tensor); m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 807ed265d..fbba5b689 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -92,6 +92,9 @@ def load_torchcodec_shared_libraries(): encode_audio_to_file = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.encode_audio_to_file.default ) +encode_video_to_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.encode_video_to_file.default +) encode_audio_to_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.encode_audio_to_tensor.default ) @@ -227,6 +230,15 @@ def encode_audio_to_file_abstract( return +@register_fake("torchcodec_ns::encode_video_to_file") +def encode_video_to_file_abstract( + frames: torch.Tensor, + frame_rate: int, + filename: str, +) -> None: + return + + 
@register_fake("torchcodec_ns::encode_audio_to_tensor") def encode_audio_to_tensor_abstract( samples: torch.Tensor, diff --git a/test/test_ops.py b/test/test_ops.py index d718d7e00..9f3a043c3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -28,6 +28,7 @@ create_from_file_like, create_from_tensor, encode_audio_to_file, + encode_video_to_file, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, @@ -48,6 +49,7 @@ NASA_AUDIO_MP3, NASA_VIDEO, needs_cuda, + psnr, SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, @@ -1224,5 +1226,63 @@ def test_bad_input(self, tmp_path): ) +class TestVideoEncoderOps: + + def test_bad_input(self, tmp_path): + output_file = str(tmp_path / ".mp4") + + with pytest.raises( + RuntimeError, match="frames must have uint8 dtype, got float" + ): + encode_video_to_file( + frames=torch.rand((10, 3, 60, 60), dtype=torch.float), + frame_rate=10, + filename=output_file, + ) + + with pytest.raises( + RuntimeError, match=r"frames must have 4 dimensions \(N, C, H, W\), got 3" + ): + encode_video_to_file( + frames=torch.randint(high=1, size=(3, 60, 60), dtype=torch.uint8), + frame_rate=10, + filename=output_file, + ) + + with pytest.raises( + RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2" + ): + encode_video_to_file( + frames=torch.randint(high=1, size=(10, 2, 60, 60), dtype=torch.uint8), + frame_rate=10, + filename=output_file, + ) + + def decode(self, file_path) -> torch.Tensor: + decoder = create_from_file(str(file_path), seek_mode="approximate") + add_video_stream(decoder) + frames, *_ = get_frames_in_range(decoder, start=0, stop=60) + return frames + + @pytest.mark.parametrize("format", ("mov", "mp4", "avi")) + # TODO-VideoEncoder: enable additional formats (mkv, webm) + def test_video_encoder_test_round_trip(self, tmp_path, format): + # TODO-VideoEncoder: Test with FFmpeg's testsrc2 video + asset = NASA_VIDEO + + # Test that decode(encode(decode(asset))) == decode(asset) + source_frames = 
self.decode(str(asset.path)).data + + encoded_path = str(tmp_path / f"encoder_output.{format}") + frame_rate = 30 # Frame rate is fixed with num frames decoded + encode_video_to_file(source_frames, frame_rate, encoded_path) + round_trip_frames = self.decode(encoded_path).data + + # Check that PSNR for decode(encode(samples)) is above 30 + for s_frame, rt_frame in zip(source_frames, round_trip_frames): + res = psnr(s_frame, rt_frame) + assert res > 30 + + if __name__ == "__main__": pytest.main()