|
| 1 | +Extending PyTorch with Custom C++ Classes |
| 2 | +========================================= |
| 3 | + |
| 4 | + |
| 5 | +This tutorial introduces an API for binding C++ classes into PyTorch. |
| 6 | +The API is very similar to |
| 7 | +`pybind11 <https://github.com/pybind/pybind11>`_, and most of the concepts will transfer |
| 8 | +over if you're familiar with that system. |
| 9 | + |
| 10 | +Implementing and Binding the Class in C++ |
| 11 | +----------------------------------------- |
| 12 | + |
| 13 | +For this tutorial, we are going to define a simple C++ class that maintains persistent |
| 14 | +state in a member variable. |
| 15 | + |
| 16 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp |
| 17 | + :language: cpp |
| 18 | + :start-after: BEGIN class |
| 19 | + :end-before: END class |
| 20 | + |
| 21 | +There are several things to note: |
| 22 | + |
| 23 | +- ``torch/custom_class.h`` is the header you need to include to extend PyTorch |
| 24 | + with your custom class. |
| 25 | +- Notice that whenever we are working with instances of the custom |
| 26 | + class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr`` |
| 27 | + as a smart pointer like ``std::shared_ptr``, but the reference count is stored |
| 28 | + directly in the object, as opposed to a separate metadata block (as is done in |
| 29 | + ``std::shared_ptr``). ``torch::Tensor`` internally uses the same pointer type, |
| 30 | + and custom classes must also use this pointer type so that we can |
| 31 | + consistently manage different object types. |
| 32 | +- Also notice that the user-defined class must inherit from |
| 33 | + ``torch::CustomClassHolder``. This ensures that the custom class has space to |
| 34 | + store the reference count. |
| 35 | + |
| 36 | +Now let's take a look at how we will make this class visible to PyTorch, a process called |
| 37 | +*binding* the class: |
| 38 | + |
| 39 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp |
| 40 | + :language: cpp |
| 41 | + :start-after: BEGIN binding |
| 42 | + :end-before: END binding |
| 43 | + :append: |
| 44 | + ; |
| 45 | + } |
| 46 | + |
| 47 | + |
| 48 | + |
| 49 | +Building the Example as a C++ Project With CMake |
| 50 | +------------------------------------------------ |
| 51 | + |
| 52 | +Now, we're going to build the above C++ code with the `CMake |
| 53 | +<https://cmake.org>`_ build system. First, take all the C++ code |
| 54 | +we've covered so far and place it in a file called ``class.cpp``. |
| 55 | +Then, write a simple ``CMakeLists.txt`` file and place it in the |
| 56 | +same directory. Here is what ``CMakeLists.txt`` should look like: |
| 57 | + |
| 58 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/CMakeLists.txt |
| 59 | + :language: cmake |
| 60 | + |
| 61 | +Also, create a ``build`` directory. Your file tree should look like this:: |
| 62 | + |
| 63 | + custom_class_project/ |
| 64 | + class.cpp |
| 65 | + CMakeLists.txt |
| 66 | + build/ |
| 67 | + |
| 68 | +Go ahead and invoke ``cmake`` and then ``make`` to build the project: |
| 69 | + |
| 70 | +.. code-block:: shell |
| 71 | +
|
| 72 | + $ cd build |
| 73 | + $ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" .. |
| 74 | + -- The C compiler identification is GNU 7.3.1 |
| 75 | + -- The CXX compiler identification is GNU 7.3.1 |
| 76 | + -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc |
| 77 | + -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works |
| 78 | + -- Detecting C compiler ABI info |
| 79 | + -- Detecting C compiler ABI info - done |
| 80 | + -- Detecting C compile features |
| 81 | + -- Detecting C compile features - done |
| 82 | + -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ |
| 83 | + -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works |
| 84 | + -- Detecting CXX compiler ABI info |
| 85 | + -- Detecting CXX compiler ABI info - done |
| 86 | + -- Detecting CXX compile features |
| 87 | + -- Detecting CXX compile features - done |
| 88 | + -- Looking for pthread.h |
| 89 | + -- Looking for pthread.h - found |
| 90 | + -- Looking for pthread_create |
| 91 | + -- Looking for pthread_create - not found |
| 92 | + -- Looking for pthread_create in pthreads |
| 93 | + -- Looking for pthread_create in pthreads - not found |
| 94 | + -- Looking for pthread_create in pthread |
| 95 | + -- Looking for pthread_create in pthread - found |
| 96 | + -- Found Threads: TRUE |
| 97 | + -- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so |
| 98 | + -- Configuring done |
| 99 | + -- Generating done |
| 100 | + -- Build files have been written to: /torchbind_tutorial/build |
| 101 | + $ make -j |
| 102 | + Scanning dependencies of target custom_class |
| 103 | + [ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o |
| 104 | + [100%] Linking CXX shared library libcustom_class.so |
| 105 | + [100%] Built target custom_class |
| 106 | +
|
| 107 | +The build directory now contains (among other things) a dynamic library |
| 108 | +file. On Linux, this is probably named |
| 109 | +``libcustom_class.so``, so the file tree should look like:: |
| 110 | +
|
| 111 | + custom_class_project/ |
| 112 | + class.cpp |
| 113 | + CMakeLists.txt |
| 114 | + build/ |
| 115 | + libcustom_class.so |
| 116 | +
|
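The library's file name depends on the platform, so a script that loads it may want to compute the path rather than hard-code it. The following is a minimal sketch, not part of the tutorial's sources: the helper names ``custom_class_library_path`` and ``load_custom_class`` are hypothetical, and the naming scheme assumes CMake's default conventions for a shared library target named ``custom_class`` built with a single-configuration generator.

```python
import sys
from pathlib import Path


def custom_class_library_path(build_dir, target="custom_class", platform=sys.platform):
    """Return the path where CMake is assumed to place the shared library.

    Assumption: CMake's default naming for a SHARED target, i.e.
    lib<target>.so on Linux, lib<target>.dylib on macOS, <target>.dll on Windows.
    """
    if platform.startswith("linux"):
        name = f"lib{target}.so"
    elif platform == "darwin":
        name = f"lib{target}.dylib"
    elif platform == "win32":
        name = f"{target}.dll"
    else:
        raise RuntimeError(f"unrecognized platform: {platform}")
    return Path(build_dir) / name


def load_custom_class(build_dir="build"):
    """Load the compiled library so its classes appear under torch.classes."""
    import torch  # imported lazily so the path helper works without PyTorch

    path = custom_class_library_path(build_dir)
    if not path.exists():
        raise FileNotFoundError(f"expected a built library at {path}")
    torch.classes.load_library(str(path))
    return path


if __name__ == "__main__":
    # Only computes and prints the expected path; nothing is loaded here.
    print(custom_class_library_path("build"))
```

Calling ``load_custom_class("build")`` from the ``custom_class_project`` directory would then make the bound class available, provided the build succeeded and the library sits directly in ``build/``.
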
| 117 | +Using the C++ Class from Python |
| 118 | +------------------------------- |
| 119 | +
|
| 120 | +Now that we have our class and its registration compiled into an ``.so`` file, |
| 121 | +we can load that ``.so`` into Python and try it out. Here's a script that |
| 122 | +demonstrates that: |
| 123 | +
|
| 124 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/custom_test.py |
| 125 | + :language: python |
| 126 | +
|
| 127 | +
|
| 128 | +Defining Serialization/Deserialization Methods for Custom C++ Classes |
| 129 | +--------------------------------------------------------------------- |
| 130 | +
|
| 131 | +If you try to save a ``ScriptModule`` with a custom-bound C++ class as |
| 132 | +an attribute, you'll get the following error: |
| 133 | +
|
| 134 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/export_attr.py |
| 135 | + :language: python |
| 136 | +
|
| 137 | +.. code-block:: shell |
| 138 | +
|
| 139 | + $ python export_attr.py |
| 140 | + RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) |
| 141 | +
|
| 142 | +This is because PyTorch cannot automatically figure out what information |
| 143 | +to save from your C++ class. You must specify that manually. The way to do that |
| 144 | +is to define ``__getstate__`` and ``__setstate__`` methods on the class using |
| 145 | +the special ``def_pickle`` method on ``class_``. |
| 146 | +
|
| 147 | +.. note:: |
| 148 | + The semantics of ``__getstate__`` and ``__setstate__`` are |
| 149 | + equivalent to those of the Python pickle module. You can |
| 150 | + `read more <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md#getstate-and-setstate>`_ |
| 151 | + about how we use these methods. |
| 152 | +
|
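For readers less familiar with the pickle protocol, the sketch below shows the same pair of hooks on an ordinary Python class. ``PyStack`` is a hypothetical pure-Python stand-in for the bound C++ class, used only to illustrate the semantics; it does not involve PyTorch or ``def_pickle``.

```python
import pickle


class PyStack:
    """Hypothetical pure-Python analogue of a stack class with pickle hooks."""

    def __init__(self, init):
        self.stack = list(init)

    def push(self, x):
        self.stack.append(x)

    def pop(self):
        return self.stack.pop()

    def __getstate__(self):
        # Return only the data needed to rebuild the object.
        # A non-empty dict is used so that __setstate__ is always called,
        # even when the stack itself is empty.
        return {"stack": list(self.stack)}

    def __setstate__(self, state):
        # Rebuild the object from whatever __getstate__ produced.
        self.stack = list(state["stack"])


original = PyStack(["foo", "bar"])
restored = pickle.loads(pickle.dumps(original))  # round trip through pickle
restored.push("baz")

print(restored.stack)  # ['foo', 'bar', 'baz']
print(original.stack)  # ['foo', 'bar']
```

The restored object is an independent copy: ``__getstate__`` decides what is written, and ``__setstate__`` decides how an instance is rebuilt from it.
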
| 153 | +Here is an example of the ``def_pickle`` call we can add to the registration of |
| 154 | +``MyStackClass`` to include serialization methods: |
| 155 | +
|
| 156 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp |
| 157 | + :language: cpp |
| 158 | + :start-after: BEGIN def_pickle |
| 159 | + :end-before: END def_pickle |
| 160 | +
|
| 161 | +.. note:: |
| 162 | + We take a different approach from pybind11 in the pickle API. Whereas pybind11 |
| 163 | + has a special function ``pybind11::pickle()`` which you pass into ``class_::def()``, |
| 164 | + we have a separate method ``def_pickle`` for this purpose. This is because the |
| 165 | + name ``torch::jit::pickle`` was already taken, and we didn't want to cause confusion. |
| 166 | +
|
| 167 | +Once we have defined the (de)serialization behavior in this way, our script can |
| 168 | +now run successfully: |
| 169 | +
|
| 170 | +.. code-block:: shell |
| 171 | +
|
| 172 | + $ python ../export_attr.py |
| 173 | + testing |
| 174 | +
|
| 175 | +Defining Custom Operators that Take or Return Bound C++ Classes |
| 176 | +--------------------------------------------------------------- |
| 177 | +
|
| 178 | +Once you've defined a custom C++ class, you can also use that class |
| 179 | +as an argument or return type of a custom operator (that is, a free function). Suppose |
| 180 | +you have the following free function: |
| 181 | +
|
| 182 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp |
| 183 | + :language: cpp |
| 184 | + :start-after: BEGIN free_function |
| 185 | + :end-before: END free_function |
| 186 | +
|
| 187 | +You can register it by running the following code inside your ``TORCH_LIBRARY`` |
| 188 | +block: |
| 189 | +
|
| 190 | +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp |
| 191 | + :language: cpp |
| 192 | + :start-after: BEGIN def_free |
| 193 | + :end-before: END def_free |
| 194 | +
|
| 195 | +Once this is done, you can use the op like the following example: |
| 196 | +
|
| 197 | +.. code-block:: python |
| 198 | +
|
| 199 | +    class TryCustomOp(torch.nn.Module): |
| 200 | +        def __init__(self): |
| 201 | +            super(TryCustomOp, self).__init__() |
| 202 | +            self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"]) |
| 203 | +
|
| 204 | +        def forward(self): |
| 205 | +            return torch.ops.my_classes.manipulate_instance(self.f) |
| 206 | +
|
| 207 | +.. note:: |
| 208 | +
|
| 209 | + Registration of an operator that takes a C++ class as an argument requires that |
| 210 | + the custom class has already been registered. You can enforce this by |
| 211 | + making sure the custom class registration and your free function definitions |
| 212 | + are in the same ``TORCH_LIBRARY`` block, and that the custom class |
| 213 | + registration comes first. In the future, we may relax this requirement, |
| 214 | + so that these can be registered in any order. |
| 215 | +
|
| 216 | +
|
| 217 | +Conclusion |
| 218 | +---------- |
| 219 | +
|
| 220 | +This tutorial walked you through how to expose a C++ class to PyTorch, how to |
| 221 | +register its methods, how to use that class from Python, and how to save and |
| 222 | +load code using the class and run that code in a standalone C++ process. You |
| 223 | +are now ready to extend your PyTorch models with C++ classes that interface |
| 224 | +with third-party C++ libraries or implement any other use case that requires |
| 225 | +the lines between Python and C++ to blend smoothly. |
| 226 | +
|
| 227 | +As always, if you run into any problems or have questions, you can use our |
| 228 | +`forum <https://discuss.pytorch.org/>`_ or `GitHub issues |
| 229 | +<https://github.com/pytorch/pytorch/issues>`_ to get in touch. Also, our |
| 230 | +`frequently asked questions (FAQ) page |
| 231 | +<https://pytorch.org/cppdocs/notes/faq.html>`_ may have helpful information. |