Skip to content

Commit e7a3922

Browse files
committed
Add custom class tutorial back
1 parent 3cbb0f7 commit e7a3922

File tree

12 files changed

+536
-8
lines changed

12 files changed

+536
-8
lines changed

advanced_source/custom_class_pt2.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Supporting Custom C++ Classes in torch.compile/torch.export
33

44

55
This tutorial is a follow-on to the
6-
:doc:`custom C++ classes <torch_script_custom_classes>` tutorial, and
6+
:doc:`custom C++ classes <custom_classes>` tutorial, and
77
introduces additional steps that are needed to support custom C++ classes in
88
torch.compile/torch.export.
99

@@ -30,7 +30,7 @@ Concretely, there are a few steps:
3030
states returned by ``__obj_flatten__``.
3131

3232
Here is a breakdown of the diff. Following the guide in
33-
:doc:`Extending TorchScript with Custom C++ Classes <torch_script_custom_classes>`,
33+
:doc:`Extending TorchScript with Custom C++ Classes <custom_classes>`,
3434
we can create a thread-safe tensor queue and build it.
3535

3636
.. code-block:: cpp

advanced_source/custom_classes.rst

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
Extending PyTorch with Custom C++ Classes
2+
===============================================
3+
4+
5+
This tutorial introduces an API for binding C++ classes into PyTorch.
6+
The API is very similar to
7+
`pybind11 <https://github.com/pybind/pybind11>`_, and most of the concepts will transfer
8+
over if you're familiar with that system.
9+
10+
Implementing and Binding the Class in C++
11+
-----------------------------------------
12+
13+
For this tutorial, we are going to define a simple C++ class that maintains persistent
14+
state in a member variable.
15+
16+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
17+
:language: cpp
18+
:start-after: BEGIN class
19+
:end-before: END class
20+
21+
There are several things to note:
22+
23+
- ``torch/custom_class.h`` is the header you need to include to extend PyTorch
24+
with your custom class.
25+
- Notice that whenever we are working with instances of the custom
26+
class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr``
27+
as a smart pointer like ``std::shared_ptr``, but the reference count is stored
28+
directly in the object, as opposed to a separate metadata block (as is done in
29+
``std::shared_ptr``. ``torch::Tensor`` internally uses the same pointer type;
30+
and custom classes have to also use this pointer type so that we can
31+
consistently manage different object types.
32+
- The second thing to notice is that the user-defined class must inherit from
33+
``torch::CustomClassHolder``. This ensures that the custom class has space to
34+
store the reference count.
35+
36+
Now let's take a look at how we will make this class visible to PyTorch, a process called
37+
*binding* the class:
38+
39+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
40+
:language: cpp
41+
:start-after: BEGIN binding
42+
:end-before: END binding
43+
:append:
44+
;
45+
}
46+
47+
48+
49+
Building the Example as a C++ Project With CMake
50+
------------------------------------------------
51+
52+
Now, we're going to build the above C++ code with the `CMake
53+
<https://cmake.org>`_ build system. First, take all the C++ code
54+
we've covered so far and place it in a file called ``class.cpp``.
55+
Then, write a simple ``CMakeLists.txt`` file and place it in the
56+
same directory. Here is what ``CMakeLists.txt`` should look like:
57+
58+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/CMakeLists.txt
59+
:language: cmake
60+
61+
Also, create a ``build`` directory. Your file tree should look like this::
62+
63+
custom_class_project/
64+
class.cpp
65+
CMakeLists.txt
66+
build/
67+
68+
Go ahead and invoke cmake and then make to build the project:
69+
70+
.. code-block:: shell
71+
72+
$ cd build
73+
$ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
74+
-- The C compiler identification is GNU 7.3.1
75+
-- The CXX compiler identification is GNU 7.3.1
76+
-- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
77+
-- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
78+
-- Detecting C compiler ABI info
79+
-- Detecting C compiler ABI info - done
80+
-- Detecting C compile features
81+
-- Detecting C compile features - done
82+
-- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
83+
-- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
84+
-- Detecting CXX compiler ABI info
85+
-- Detecting CXX compiler ABI info - done
86+
-- Detecting CXX compile features
87+
-- Detecting CXX compile features - done
88+
-- Looking for pthread.h
89+
-- Looking for pthread.h - found
90+
-- Looking for pthread_create
91+
-- Looking for pthread_create - not found
92+
-- Looking for pthread_create in pthreads
93+
-- Looking for pthread_create in pthreads - not found
94+
-- Looking for pthread_create in pthread
95+
-- Looking for pthread_create in pthread - found
96+
-- Found Threads: TRUE
97+
-- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so
98+
-- Configuring done
99+
-- Generating done
100+
-- Build files have been written to: /torchbind_tutorial/build
101+
$ make -j
102+
Scanning dependencies of target custom_class
103+
[ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o
104+
[100%] Linking CXX shared library libcustom_class.so
105+
[100%] Built target custom_class
106+
107+
What you'll find is there is now (among other things) a dynamic library
108+
file present in the build directory. On Linux, this is probably named
109+
``libcustom_class.so``. So the file tree should look like::
110+
111+
custom_class_project/
112+
class.cpp
113+
CMakeLists.txt
114+
build/
115+
libcustom_class.so
116+
117+
Using the C++ Class from Python
118+
-----------------------------------------------
119+
120+
Now that we have our class and its registration compiled into an ``.so`` file,
121+
we can load that `.so` into Python and try it out. Here's a script that
122+
demonstrates that:
123+
124+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/custom_test.py
125+
:language: python
126+
127+
128+
Defining Serialization/Deserialization Methods for Custom C++ Classes
129+
---------------------------------------------------------------------
130+
131+
If you try to save a ``ScriptModule`` with a custom-bound C++ class as
132+
an attribute, you'll get the following error:
133+
134+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/export_attr.py
135+
:language: python
136+
137+
.. code-block:: shell
138+
139+
$ python export_attr.py
140+
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
141+
142+
This is because PyTorch cannot automatically figure out what information
143+
save from your C++ class. You must specify that manually. The way to do that
144+
is to define ``__getstate__`` and ``__setstate__`` methods on the class using
145+
the special ``def_pickle`` method on ``class_``.
146+
147+
.. note::
148+
The semantics of ``__getstate__`` and ``__setstate__`` are
149+
equivalent to that of the Python pickle module. You can
150+
`read more <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md#getstate-and-setstate>`_
151+
about how we use these methods.
152+
153+
Here is an example of the ``def_pickle`` call we can add to the registration of
154+
``MyStackClass`` to include serialization methods:
155+
156+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
157+
:language: cpp
158+
:start-after: BEGIN def_pickle
159+
:end-before: END def_pickle
160+
161+
.. note::
162+
We take a different approach from pybind11 in the pickle API. Whereas pybind11
163+
as a special function ``pybind11::pickle()`` which you pass into ``class_::def()``,
164+
we have a separate method ``def_pickle`` for this purpose. This is because the
165+
name ``torch::jit::pickle`` was already taken, and we didn't want to cause confusion.
166+
167+
Once we have defined the (de)serialization behavior in this way, our script can
168+
now run successfully:
169+
170+
.. code-block:: shell
171+
172+
$ python ../export_attr.py
173+
testing
174+
175+
Defining Custom Operators that Take or Return Bound C++ Classes
176+
---------------------------------------------------------------
177+
178+
Once you've defined a custom C++ class, you can also use that class
179+
as an argument or return from a custom operator (i.e. free functions). Suppose
180+
you have the following free function:
181+
182+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
183+
:language: cpp
184+
:start-after: BEGIN free_function
185+
:end-before: END free_function
186+
187+
You can register it running the following code inside your ``TORCH_LIBRARY``
188+
block:
189+
190+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
191+
:language: cpp
192+
:start-after: BEGIN def_free
193+
:end-before: END def_free
194+
195+
Once this is done, you can use the op like the following example:
196+
197+
.. code-block:: python
198+
199+
class TryCustomOp(torch.nn.Module):
200+
def __init__(self):
201+
super(TryCustomOp, self).__init__()
202+
self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"])
203+
204+
def forward(self):
205+
return torch.ops.my_classes.manipulate_instance(self.f)
206+
207+
.. note::
208+
209+
Registration of an operator that takes a C++ class as an argument requires that
210+
the custom class has already been registered. You can enforce this by
211+
making sure the custom class registration and your free function definitions
212+
are in the same ``TORCH_LIBRARY`` block, and that the custom class
213+
registration comes first. In the future, we may relax this requirement,
214+
so that these can be registered in any order.
215+
216+
217+
Conclusion
218+
----------
219+
220+
This tutorial walked you through how to expose a C++ class to PyTorch, how to
221+
register its methods, how to use that class from Python, and how to save and
222+
load code using the class and run that code in a standalone C++ process. You
223+
are now ready to extend your PyTorch models with C++ classes that interface
224+
with third party C++ libraries or implement any other use case that requires
225+
the lines between Python and C++ to blend smoothly.
226+
227+
As always, if you run into any problems or have questions, you can use our
228+
`forum <https://discuss.pytorch.org/>`_ or `GitHub issues
229+
<https://github.com/pytorch/pytorch/issues>`_ to get in touch. Also, our
230+
`frequently asked questions (FAQ) page
231+
<https://pytorch.org/cppdocs/notes/faq.html>`_ may have helpful information.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
2+
project(infer)
3+
4+
find_package(Torch REQUIRED)
5+
6+
add_subdirectory(custom_class_project)
7+
8+
# Define our library target
9+
add_executable(infer infer.cpp)
10+
set(CMAKE_CXX_STANDARD 14)
11+
# Link against LibTorch
12+
target_link_libraries(infer "${TORCH_LIBRARIES}")
13+
# This is where we link in our libcustom_class code, making our
14+
# custom class available in our binary.
15+
target_link_libraries(infer -Wl,--no-as-needed custom_class)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
2+
project(custom_class)
3+
4+
find_package(Torch REQUIRED)
5+
6+
# Define our library target
7+
add_library(custom_class SHARED class.cpp)
8+
set(CMAKE_CXX_STANDARD 14)
9+
# Link against LibTorch
10+
target_link_libraries(custom_class "${TORCH_LIBRARIES}")

0 commit comments

Comments
 (0)