10
10
#include < string>
11
11
#include " c10/core/SymIntArrayRef.h"
12
12
#include " c10/util/Exception.h"
13
+ #include " src/torchcodec/_core/AVIOFileLikeContext.h"
13
14
#include " src/torchcodec/_core/AVIOTensorContext.h"
14
15
#include " src/torchcodec/_core/Encoder.h"
15
16
#include " src/torchcodec/_core/SingleStreamDecoder.h"
@@ -33,8 +34,12 @@ TORCH_LIBRARY(torchcodec_ns, m) {
33
34
" encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()" );
34
35
m.def (
35
36
" encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor" );
37
+ m.def (
38
+ " _encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()" );
36
39
m.def (
37
40
" create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor" );
41
+ m.def (
42
+ " _create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor" );
38
43
m.def (" _convert_to_tensor(int decoder_ptr) -> Tensor" );
39
44
m.def (
40
45
" _add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()" );
@@ -210,6 +215,24 @@ at::Tensor create_from_tensor(
210
215
return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
211
216
}
212
217
218
+ at::Tensor _create_from_file_like (
219
+ int64_t file_like_context,
220
+ std::optional<std::string_view> seek_mode) {
221
+ auto fileLikeContext =
222
+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
223
+ TORCH_CHECK (fileLikeContext != nullptr , " file_like must be a valid pointer" );
224
+ std::unique_ptr<AVIOFileLikeContext> contextHolder (fileLikeContext);
225
+
226
+ SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
227
+ if (seek_mode.has_value ()) {
228
+ realSeek = seekModeFromString (seek_mode.value ());
229
+ }
230
+
231
+ std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
232
+ std::make_unique<SingleStreamDecoder>(std::move (contextHolder), realSeek);
233
+ return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
234
+ }
235
+
213
236
at::Tensor _convert_to_tensor (int64_t decoder_ptr) {
214
237
auto decoder = reinterpret_cast <SingleStreamDecoder*>(decoder_ptr);
215
238
std::unique_ptr<SingleStreamDecoder> uniqueDecoder (decoder);
@@ -441,6 +464,36 @@ at::Tensor encode_audio_to_tensor(
441
464
.encodeToTensor ();
442
465
}
443
466
467
+ void _encode_audio_to_file_like (
468
+ const at::Tensor& samples,
469
+ int64_t sample_rate,
470
+ std::string_view format,
471
+ int64_t file_like_context,
472
+ std::optional<int64_t > bit_rate = std::nullopt ,
473
+ std::optional<int64_t > num_channels = std::nullopt ,
474
+ std::optional<int64_t > desired_sample_rate = std::nullopt ) {
475
+ auto fileLikeContext =
476
+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
477
+ TORCH_CHECK (
478
+ fileLikeContext != nullptr , " file_like_context must be a valid pointer" );
479
+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder (fileLikeContext);
480
+
481
+ AudioStreamOptions audioStreamOptions;
482
+ audioStreamOptions.bitRate = validateOptionalInt64ToInt (bit_rate, " bit_rate" );
483
+ audioStreamOptions.numChannels =
484
+ validateOptionalInt64ToInt (num_channels, " num_channels" );
485
+ audioStreamOptions.sampleRate =
486
+ validateOptionalInt64ToInt (desired_sample_rate, " desired_sample_rate" );
487
+
488
+ AudioEncoder encoder (
489
+ samples,
490
+ validateInt64ToInt (sample_rate, " sample_rate" ),
491
+ format,
492
+ std::move (avioContextHolder),
493
+ audioStreamOptions);
494
+ encoder.encode ();
495
+ }
496
+
444
497
// For testing only. We need to implement this operation as a core library
445
498
// function because what we're testing is round-tripping pts values as
446
499
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -694,6 +747,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
694
747
TORCH_LIBRARY_IMPL (torchcodec_ns, BackendSelect, m) {
695
748
m.impl (" create_from_file" , &create_from_file);
696
749
m.impl (" create_from_tensor" , &create_from_tensor);
750
+ m.impl (" _create_from_file_like" , &_create_from_file_like);
697
751
m.impl (" _convert_to_tensor" , &_convert_to_tensor);
698
752
m.impl (
699
753
" _get_json_ffmpeg_library_versions" , &_get_json_ffmpeg_library_versions);
@@ -702,6 +756,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
702
756
TORCH_LIBRARY_IMPL (torchcodec_ns, CPU, m) {
703
757
m.impl (" encode_audio_to_file" , &encode_audio_to_file);
704
758
m.impl (" encode_audio_to_tensor" , &encode_audio_to_tensor);
759
+ m.impl (" _encode_audio_to_file_like" , &_encode_audio_to_file_like);
705
760
m.impl (" seek_to_pts" , &seek_to_pts);
706
761
m.impl (" add_video_stream" , &add_video_stream);
707
762
m.impl (" _add_video_stream" , &_add_video_stream);
0 commit comments