Skip to content

Commit 2828dc5

Browse files
committed
Refactor pybind_ops to only deal with file like context holders
1 parent ee77f57 commit 2828dc5

File tree

4 files changed

+97
-79
lines changed

4 files changed

+97
-79
lines changed

src/torchcodec/_core/AVIOFileLikeContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ namespace facebook::torchcodec {
1717

1818
// Enables uers to pass in a Python file-like object. We then forward all read
1919
// and seek calls back up to the methods on the Python object.
20-
class AVIOFileLikeContext : public AVIOContextHolder {
20+
class __attribute__((visibility("hidden"))) AVIOFileLikeContext
21+
: public AVIOContextHolder {
2122
public:
2223
explicit AVIOFileLikeContext(const py::object& fileLike, bool isForWriting);
2324

src/torchcodec/_core/custom_ops.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <string>
1111
#include "c10/core/SymIntArrayRef.h"
1212
#include "c10/util/Exception.h"
13+
#include "src/torchcodec/_core/AVIOFileLikeContext.h"
1314
#include "src/torchcodec/_core/AVIOTensorContext.h"
1415
#include "src/torchcodec/_core/Encoder.h"
1516
#include "src/torchcodec/_core/SingleStreamDecoder.h"
@@ -33,8 +34,12 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3334
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3435
m.def(
3536
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
37+
m.def(
38+
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3639
m.def(
3740
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
41+
m.def(
42+
"_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor");
3843
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
3944
m.def(
4045
"_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
@@ -210,6 +215,24 @@ at::Tensor create_from_tensor(
210215
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
211216
}
212217

218+
at::Tensor _create_from_file_like(
219+
int64_t file_like_context,
220+
std::optional<std::string_view> seek_mode) {
221+
auto fileLikeContext =
222+
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
223+
TORCH_CHECK(fileLikeContext != nullptr, "file_like must be a valid pointer");
224+
std::unique_ptr<AVIOFileLikeContext> contextHolder(fileLikeContext);
225+
226+
SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
227+
if (seek_mode.has_value()) {
228+
realSeek = seekModeFromString(seek_mode.value());
229+
}
230+
231+
std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
232+
std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);
233+
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
234+
}
235+
213236
at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
214237
auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
215238
std::unique_ptr<SingleStreamDecoder> uniqueDecoder(decoder);
@@ -441,6 +464,36 @@ at::Tensor encode_audio_to_tensor(
441464
.encodeToTensor();
442465
}
443466

467+
void _encode_audio_to_file_like(
468+
const at::Tensor& samples,
469+
int64_t sample_rate,
470+
std::string_view format,
471+
int64_t file_like_context,
472+
std::optional<int64_t> bit_rate = std::nullopt,
473+
std::optional<int64_t> num_channels = std::nullopt,
474+
std::optional<int64_t> desired_sample_rate = std::nullopt) {
475+
auto fileLikeContext =
476+
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
477+
TORCH_CHECK(
478+
fileLikeContext != nullptr, "file_like_context must be a valid pointer");
479+
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
480+
481+
AudioStreamOptions audioStreamOptions;
482+
audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
483+
audioStreamOptions.numChannels =
484+
validateOptionalInt64ToInt(num_channels, "num_channels");
485+
audioStreamOptions.sampleRate =
486+
validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
487+
488+
AudioEncoder encoder(
489+
samples,
490+
validateInt64ToInt(sample_rate, "sample_rate"),
491+
format,
492+
std::move(avioContextHolder),
493+
audioStreamOptions);
494+
encoder.encode();
495+
}
496+
444497
// For testing only. We need to implement this operation as a core library
445498
// function because what we're testing is round-tripping pts values as
446499
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -694,6 +747,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
694747
TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
695748
m.impl("create_from_file", &create_from_file);
696749
m.impl("create_from_tensor", &create_from_tensor);
750+
m.impl("_create_from_file_like", &_create_from_file_like);
697751
m.impl("_convert_to_tensor", &_convert_to_tensor);
698752
m.impl(
699753
"_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
@@ -702,6 +756,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
702756
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
703757
m.impl("encode_audio_to_file", &encode_audio_to_file);
704758
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
759+
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
705760
m.impl("seek_to_pts", &seek_to_pts);
706761
m.impl("add_video_stream", &add_video_stream);
707762
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,15 @@ def load_torchcodec_shared_libraries():
9595
encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
9696
torch.ops.torchcodec_ns.encode_audio_to_tensor.default
9797
)
98+
_encode_audio_to_file_like = torch._dynamo.disallow_in_graph(
99+
torch.ops.torchcodec_ns._encode_audio_to_file_like.default
100+
)
98101
create_from_tensor = torch._dynamo.disallow_in_graph(
99102
torch.ops.torchcodec_ns.create_from_tensor.default
100103
)
104+
_create_from_file_like = torch._dynamo.disallow_in_graph(
105+
torch.ops.torchcodec_ns._create_from_file_like.default
106+
)
101107
_convert_to_tensor = torch._dynamo.disallow_in_graph(
102108
torch.ops.torchcodec_ns._convert_to_tensor.default
103109
)
@@ -148,7 +154,12 @@ def create_from_file_like(
148154
file_like: Union[io.RawIOBase, io.BufferedReader], seek_mode: Optional[str] = None
149155
) -> torch.Tensor:
150156
assert _pybind_ops is not None
151-
return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))
157+
return _create_from_file_like(
158+
_pybind_ops.create_file_like_context(
159+
file_like, False # False means not for writing
160+
),
161+
seek_mode,
162+
)
152163

153164

154165
def encode_audio_to_file_like(
@@ -176,36 +187,16 @@ def encode_audio_to_file_like(
176187
if samples.dtype != torch.float32:
177188
raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}")
178189

179-
# We're having the same problem as with the decoder's create_from_file_like:
180-
# We should be able to pass a tensor directly, but this leads to a pybind
181-
# error. In order to work around this, we pass the pointer to the tensor's
182-
# data, and its shape, in order to re-construct it in C++. For this to work:
183-
# - the tensor must be float32
184-
# - the tensor must be contiguous, which is why we call contiguous().
185-
# In theory we could avoid this restriction by also passing the strides?
186-
# - IMPORTANT: the input samples tensor and its underlying data must be
187-
# alive during the call.
188-
#
189-
# A more elegant solution would be to cast the tensor into a py::object, but
190-
# casting the py::object backk to a tensor in C++ seems to lead to the same
191-
# pybing error.
192-
193-
samples = samples.contiguous()
194-
_pybind_ops.encode_audio_to_file_like(
195-
samples.data_ptr(),
196-
list(samples.shape),
190+
_encode_audio_to_file_like(
191+
samples,
197192
sample_rate,
198193
format,
199-
file_like,
194+
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
200195
bit_rate,
201196
num_channels,
202197
desired_sample_rate,
203198
)
204199

205-
# This check is useless but it's critical to keep it to ensures that samples
206-
# is still alive during the call to encode_audio_to_file_like.
207-
assert samples.is_contiguous()
208-
209200

210201
# ==============================
211202
# Abstract impl for the operators. Needed by torch.compile.
@@ -215,6 +206,13 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
215206
return torch.empty([], dtype=torch.long)
216207

217208

209+
@register_fake("torchcodec_ns::_create_from_file_like")
210+
def _create_from_file_like_abstract(
211+
file_like: int, seek_mode: Optional[str]
212+
) -> torch.Tensor:
213+
return torch.empty([], dtype=torch.long)
214+
215+
218216
@register_fake("torchcodec_ns::encode_audio_to_file")
219217
def encode_audio_to_file_abstract(
220218
samples: torch.Tensor,
@@ -239,6 +237,19 @@ def encode_audio_to_tensor_abstract(
239237
return torch.empty([], dtype=torch.long)
240238

241239

240+
@register_fake("torchcodec_ns::_encode_audio_to_file_like")
241+
def _encode_audio_to_file_like_abstract(
242+
samples: torch.Tensor,
243+
sample_rate: int,
244+
format: str,
245+
file_like_context: int,
246+
bit_rate: Optional[int] = None,
247+
num_channels: Optional[int] = None,
248+
desired_sample_rate: Optional[int] = None,
249+
) -> None:
250+
return torch.empty([], dtype=torch.long)
251+
252+
242253
@register_fake("torchcodec_ns::create_from_tensor")
243254
def create_from_tensor_abstract(
244255
video_tensor: torch.Tensor, seek_mode: Optional[str]

src/torchcodec/_core/pybind_ops.cpp

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,8 @@
77
#include <pybind11/pybind11.h>
88
#include <pybind11/stl.h>
99
#include <cstdint>
10-
#include <string>
1110

1211
#include "src/torchcodec/_core/AVIOFileLikeContext.h"
13-
#include "src/torchcodec/_core/Encoder.h"
14-
#include "src/torchcodec/_core/SingleStreamDecoder.h"
15-
#include "src/torchcodec/_core/StreamOptions.h"
16-
#include "src/torchcodec/_core/ValidationUtils.h"
1712

1813
namespace py = pybind11;
1914

@@ -26,62 +21,18 @@ namespace facebook::torchcodec {
2621
//
2722
// So we instead launder the pointer through an int, and then use a conversion
2823
// function on the custom ops side to launder that int into a tensor.
29-
int64_t create_from_file_like(
30-
py::object file_like,
31-
std::optional<std::string_view> seek_mode) {
32-
SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
33-
if (seek_mode.has_value()) {
34-
realSeek = seekModeFromString(seek_mode.value());
35-
}
36-
37-
auto avioContextHolder =
38-
std::make_unique<AVIOFileLikeContext>(file_like, /*isForWriting=*/false);
39-
40-
SingleStreamDecoder* decoder =
41-
new SingleStreamDecoder(std::move(avioContextHolder), realSeek);
42-
return reinterpret_cast<int64_t>(decoder);
43-
}
44-
45-
void encode_audio_to_file_like(
46-
int64_t data_ptr,
47-
const std::vector<int64_t>& shape,
48-
int64_t sample_rate,
49-
std::string_view format,
50-
py::object file_like,
51-
std::optional<int64_t> bit_rate = std::nullopt,
52-
std::optional<int64_t> num_channels = std::nullopt,
53-
std::optional<int64_t> desired_sample_rate = std::nullopt) {
54-
// We assume float32 *and* contiguity, this must be enforced by the caller.
55-
auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32);
56-
auto samples = torch::from_blob(
57-
reinterpret_cast<void*>(data_ptr), shape, tensor_options);
58-
59-
AudioStreamOptions audioStreamOptions;
60-
audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
61-
audioStreamOptions.numChannels =
62-
validateOptionalInt64ToInt(num_channels, "num_channels");
63-
audioStreamOptions.sampleRate =
64-
validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
65-
66-
auto avioContextHolder =
67-
std::make_unique<AVIOFileLikeContext>(file_like, /*isForWriting=*/true);
68-
69-
AudioEncoder encoder(
70-
samples,
71-
validateInt64ToInt(sample_rate, "sample_rate"),
72-
format,
73-
std::move(avioContextHolder),
74-
audioStreamOptions);
75-
encoder.encode();
24+
int64_t create_file_like_context(py::object file_like, bool is_for_writing) {
25+
AVIOFileLikeContext* context =
26+
new AVIOFileLikeContext(file_like, is_for_writing);
27+
return reinterpret_cast<int64_t>(context);
7628
}
7729

7830
#ifndef PYBIND_OPS_MODULE_NAME
7931
#error PYBIND_OPS_MODULE_NAME must be defined!
8032
#endif
8133

8234
PYBIND11_MODULE(PYBIND_OPS_MODULE_NAME, m) {
83-
m.def("create_from_file_like", &create_from_file_like);
84-
m.def("encode_audio_to_file_like", &encode_audio_to_file_like);
35+
m.def("create_file_like_context", &create_file_like_context);
8536
}
8637

8738
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)