```{eval-rst}
.. role:: hidden
   :class: hidden-section
```

# Extending PyTorch with New Accelerators

## Background

The PrivateUse1-based third-party device integration mechanism has become the official path for integrating new devices into PyTorch. Keeping this mechanism usable is crucial for enriching the PyTorch hardware ecosystem.

To help third-party device developers integrate new backends efficiently, this article walks through the integration of typical PyTorch modules in a modular fashion. It is accompanied by a streamlined reference implementation from the official [torch_openreg][OpenReg URL] backend, so developers can get started quickly while avoiding common pitfalls.

This document is intended for the following readers:

* Developers who wish to integrate an accelerator backend into PyTorch;
* Developers interested in the inner workings of typical PyTorch modules.

---

## Operator Registration

PyTorch provides multiple methods for operator registration and usage, at both the Python and C++ levels, along with a set of supporting tools to quickly locate issues and query information. The following sections detail these operator registration capabilities.

### Tools

#### Commands

PyTorch provides a set of commands prefixed with `torch._C._dispatch_` around its dispatch feature. You can query all related interfaces using the following command:

```Shell
python -c 'import torch; print("\n".join([x for x in dir(torch._C) if x.startswith("_dispatch_")]))'

...
_dispatch_dump
_dispatch_dump_table
_dispatch_has_kernel
_dispatch_has_kernel_for_any_dispatch_key
_dispatch_has_kernel_for_dispatch_key
_dispatch_isTensorSubclassLike
_dispatch_is_alias_key
_dispatch_is_included_in_alias
_dispatch_is_main_interpreter
_dispatch_kernel_for_dispatch_key_is_fallthrough
_dispatch_key_for_device
_dispatch_key_name
_dispatch_key_parse
_dispatch_key_set
...
```

Here are explanations for several commonly used commands:

* `torch._C._dispatch_key_set`:

  Displays the DispatchKeySet of a given Tensor, with priority increasing from left to right.

  ```Python
  >>> import torch
  >>> a = torch.randn(3,3,device="cuda")
  >>> torch._C._dispatch_key_set(a)
  'DispatchKeySet(CUDA, ADInplaceOrView, AutogradCUDA, AutocastCUDA)'
  ```

* `torch._C._dispatch_dump_table`:

  Queries the support status of a given operator across different Dispatch Keys, making it easy to locate the corresponding implementation code.

  ```Python
  >>> import torch
  >>> print(torch._C._dispatch_dump_table("aten::add.Tensor"))
  ...
  CPU: registered at ./build/aten/src/ATen/RegisterCPU_0.cpp:1309 [kernel]
  CUDA: registered at ./build/aten/src/ATen/RegisterCUDA_0.cpp:2420 [kernel]
  HIP: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  MPS: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  IPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  XPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  HPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  VE: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  MTIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  MAIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  PrivateUse1: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  ...
  ```

#### Environment Variables

PyTorch also provides some Dispatcher-related environment variables that can help with learning and quickly locating issues.

* `TORCH_SHOW_DISPATCH_TRACE`

  Displays detailed internal dispatch key scheduling during PyTorch execution.

  ```Bash
  export TORCH_SHOW_DISPATCH_TRACE=1
  ```

  ```Python
  >>> import torch
  >>> a = torch.randn(3,3)
  [call] op=[aten::randn], key=[BackendSelect]
  [redispatch] op=[aten::randn], key=[CPU]
  [call] op=[aten::empty.memory_format], key=[BackendSelect]
  [redispatch] op=[aten::empty.memory_format], key=[CPU]
  [call] op=[aten::normal_], key=[CPU]
  ```

### Registration

::::{tabs}

:::{tab} C++

1. Scenario One

   This is the most common operator implementation scenario. PyTorch ships with many built-in operators, defining their namespaces (mainly `aten` and `c10d`), schemas, and concrete implementations for backends such as CPU and CUDA. Our task is to provide implementations of these built-in operators for the new device.

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: EMPTY.MEMORY_FORMAT
      :end-before: LITERALINCLUDE END: EMPTY.MEMORY_FORMAT
   ```

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH_LIBRARY_IMPL
      :end-before: LITERALINCLUDE END: TORCH_LIBRARY_IMPL
      :emphasize-lines: 2
      :linenos:
   ```

   This registers the `wrapper_empty_memory_format` implementation for the new device to the `aten::empty.memory_format` operator on the `PrivateUse1` DispatchKey. A simplified, hand-written sketch of this registration pattern is shown after this list.

2. Scenario Two

   For built-in PyTorch operators, besides the registration method in Scenario One, a `STUB` registration method is also supported. Essentially, this approach builds on Scenario One but adds flexibility to improve code reuse across devices or to enable further dispatching at other granularities (e.g., CPU feature capabilities). A simplified sketch is shown after this list.

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: STUB
      :end-before: LITERALINCLUDE END: STUB
      :linenos:
   ```

   ```{todo}
   List of operators that can be registered via `STUB`
   ```

3. Scenario Three

   Besides providing built-in operator definitions, PyTorch also supports user-defined operators, generally in two forms (a sketch of the first form appears after this list):

   * Adding custom operators to a new namespace:

     ```{todo}
     TODO (including forward and backward)
     ```

   * Extending existing namespaces with custom operators:

     ```{todo}
     TODO (including forward and backward)
     ```

4. Scenario Four

   In addition to registering the forward and backward functions separately to the `PrivateUse1` and `AutogradPrivateUse1` DispatchKeys, PyTorch also supports a more convenient option based on `torch.autograd.Function`; a simplified sketch of the C++ counterpart appears after this list.

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part1
      :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part1
      :linenos:
   ```

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part2
      :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part2
      :linenos:
   ```

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION
      :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION
      :emphasize-lines: 2,7
      :linenos:
   ```

5. Scenario Five

   PyTorch provides a fallback mechanism that allows unsupported operators to fall back to CPU execution. This is crucial for accelerator backends still under development, ensuring functional correctness at the cost of performance. A sketch of the per-operator form appears after this list.

   * Per-operator fallback

     ```{todo}
     TODO
     ```

   * Global fallback

     ```{eval-rst}
     .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
        :language: c++
        :start-after: LITERALINCLUDE START: FALLBACK GLOBAL
        :end-before: LITERALINCLUDE END: FALLBACK GLOBAL
        :linenos:
     ```

   This enables global fallback, so all operators not implemented by the new backend will fall back to CPU execution.

6. Scenario Six

   ```{todo}
   * Meta registration
   * Overriding default implementations
   * Fallthrough
   * ATen operator set
   * ...
   ```
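
The sketch below illustrates Scenario One with hand-written code rather than the torch_openreg sources: it registers a hypothetical `openreg_abs_out` kernel for the existing `aten::abs.out` operator on the `PrivateUse1` dispatch key.

```c++
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical kernel for aten::abs.out. Registered kernels follow the schema
// argument order, with out-arguments at the end.
at::Tensor& openreg_abs_out(const at::Tensor& self, at::Tensor& out) {
  // A real backend would launch a device kernel here that writes abs(self) into out.
  return out;
}

// Bind the kernel to the PrivateUse1 dispatch key of the existing aten operator.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("abs.out", TORCH_FN(openreg_abs_out));
}
```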
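
The next sketch illustrates the `STUB` route from Scenario Two. It assumes `abs_stub` is the dispatch stub ATen declares for `abs` and that `REGISTER_PRIVATEUSE1_DISPATCH` is available from `ATen/native/DispatchStub.h`; exact stub names and header paths may vary between releases, and `abs_kernel_openreg` is a hypothetical kernel.

```c++
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>

namespace at::native {

// Hypothetical kernel; a real backend would iterate over `iter` and launch
// device code for each element.
void abs_kernel_openreg(TensorIteratorBase& iter) {
  // device-specific implementation goes here
}

// Route the abs stub to the PrivateUse1 backend.
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_openreg);

} // namespace at::native
```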
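
For Scenario Three, the sketch below defines a custom operator in a new namespace and supplies its `PrivateUse1` implementation. The `openreg` namespace, the `custom_add` schema, and `custom_add_impl` are illustrative assumptions, not part of the torch_openreg sources.

```c++
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical device implementation of the custom operator.
at::Tensor custom_add_impl(const at::Tensor& a, const at::Tensor& b) {
  // A real backend would launch a device kernel here; the placeholder just
  // allocates an output of the right shape.
  return at::empty_like(a);
}

// Define the operator schema in a new namespace.
TORCH_LIBRARY(openreg, m) {
  m.def("custom_add(Tensor a, Tensor b) -> Tensor");
}

// Provide the backend implementation for the PrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
  m.impl("custom_add", TORCH_FN(custom_add_impl));
}
```

Once such a library is loaded, the operator is reachable from Python as `torch.ops.openreg.custom_add`.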
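
For Scenario Four, the sketch below attaches autograd to the hypothetical `openreg::custom_add` operator from the previous sketch through `torch::autograd::Function` (the C++ counterpart of `torch.autograd.Function`) and registers the wrapper on the `AutogradPrivateUse1` key.

```c++
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/torch.h>

// Autograd wrapper: forward redispatches below autograd so the call reaches the
// PrivateUse1 kernel; backward returns one gradient per input.
class CustomAddFunction : public torch::autograd::Function<CustomAddFunction> {
 public:
  static at::Tensor forward(torch::autograd::AutogradContext* ctx,
                            const at::Tensor& a, const at::Tensor& b) {
    at::AutoDispatchBelowADInplaceOrView guard;
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("openreg::custom_add", "")
            .typed<at::Tensor(const at::Tensor&, const at::Tensor&)>();
    return op.call(a, b);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    // d(a + b)/da = d(a + b)/db = 1
    return {grad_outputs[0], grad_outputs[0]};
  }
};

at::Tensor custom_add_autograd(const at::Tensor& a, const at::Tensor& b) {
  return CustomAddFunction::apply(a, b);
}

// Register the autograd wrapper on the AutogradPrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
  m.impl("custom_add", TORCH_FN(custom_add_autograd));
}
```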
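
For Scenario Five, the per-operator form of the fallback can be sketched as follows: the boxed CPU fallback kernel is registered as the `PrivateUse1` kernel of a single operator (here `aten::sub.Tensor`, chosen arbitrarily for illustration), so only that operator falls back to the CPU.

```c++
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

// Register the boxed CPU fallback for one specific operator instead of
// enabling the global fallback.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("sub.Tensor",
         torch::CppFunction::makeFromBoxedFunction<&at::native::cpu_fallback>());
}
```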

:::

:::{tab} Python

TODO

:::

::::

### Minimum set of operators to support

To help developers prioritize their work, we provide a minimal operator set below. Implementing these operators is enough to bring up basic operator functionality on a new backend.

| Operator Name | Dispatch Key | Description |
| :---: | :---: | :---: |
| empty.memory_format | PrivateUse1 | |
| empty_strided | PrivateUse1 | |
| as_strided | PrivateUse1 | |
| resize_ | PrivateUse1 | |
| _reshape_alias | PrivateUse1 | |
| _copy_from | PrivateUse1 | |
| _copy_from_and_resize | PrivateUse1 | |
| _local_scalar_dense | PrivateUse1 | |
| set_.source_Tensor | PrivateUse1 | |
| set_.source_Storage | PrivateUse1 | |
| set_.source_Storage_storage_offset | PrivateUse1 | |
| view | PrivateUse1 | |
| fallback | PrivateUse1 | |

```{todo}
Add/remove operators above to ensure the minimal set list is reliable and accurate
```

## Autocast

## Autoload

## Memory Management

## Custom Storage

## ...

[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"