|
| 1 | +# Extending PyTorch with New Accelerators |
| 2 | + |
| 3 | +## Background |
| 4 | + |
| 5 | +### Motivation and Overview |
| 6 | + |
| 7 | +Since PyTorch 2.1, the community has made significant progress in simplifying the integration of new accelerators into the PyTorch ecosystem. These improvements include, but are not limited to: refinement of the `PrivateUse1` Dispatch Key, introduction and improvement of core subsystem extension mechanisms, and device-agnostic refactoring of key modules (e.g., `torch.accelerator`, `memory management`). Taken together, these improvements lay the foundation for a **robust**, **flexible** and developer-friendly accelerator integration path. |
| 8 | + |
| 9 | +### Why Does This Matter? |
| 10 | + |
| 11 | +This integration path has several key advantages: |
| 12 | + |
| 13 | +* **Speed**: Extensibility is built-in across all core PyTorch modules. Without modifying upstream code, developers can independently integrate new accelerator into their own downstream codebases without being hindered by community review or modifications to PyTorch core code. |
| 14 | +* **Future-proofing**: This integration path is the default for all future PyTorch features, which means that new modules and features will automatically support scaling to new devices as long as this path is followed. |
| 15 | +* **Autonomy**: Vendors have full control over their accelerator integration timelines, enabling agile iteration cycles and reducing reliance on upstream coordination. |
| 16 | + |
| 17 | +### About This Document |
| 18 | + |
| 19 | +This guide aims to provide a **comprehensive overview of the modern integration pathway** for new devices in PyTorch. It walks through the full integration surface, from low-level device primitives to higher-level domain modules like compilation and quantization. The structure follows a **modular and scenario-driven approach**, where each topic is paired with corresponding code examples from [torch_openreg] [OpenReg URL], an official reference implementation. |
| 20 | + |
| 21 | +The goal is to help developers: |
| 22 | + |
| 23 | +* Understand the full scope of accelerator integration; |
| 24 | +* Rapidly bring up a new accelerator following best practices; |
| 25 | +* Avoid common pitfalls through clear, targeted examples. |
| 26 | + |
| 27 | +### Target Audience |
| 28 | + |
| 29 | +This document is intended for: |
| 30 | + |
| 31 | +* **Accelerator Developers** who are integrating accelerator into PyTorch; |
| 32 | +* **Advanced PyTorch Users** interested in the inner workings of key modules; |
| 33 | + |
| 34 | +## Operator Registration |
| 35 | + |
| 36 | +PyTorch provides multiple methods for operator registration and usage, both at the Python and C++ levels, along with a set of supporting tools to quickly locate issues and query information. The following sections detail the operator registration capabilities. |
| 37 | + |
| 38 | +### Tools |
| 39 | + |
| 40 | +#### Commands |
| 41 | + |
| 42 | +PyTorch provides a set of commands prefixed with `torch._C._dispatch_` around its Dispatch feature. You can query all related interfaces using the following command: |
| 43 | + |
| 44 | +```Shell |
| 45 | +python -c 'import torch; print("\n".join([x for x in dir(torch._C) if x.startswith("_dispatch_")]))' |
| 46 | + |
| 47 | +... |
| 48 | +_dispatch_dump |
| 49 | +_dispatch_dump_table |
| 50 | +_dispatch_has_kernel |
| 51 | +_dispatch_has_kernel_for_any_dispatch_key |
| 52 | +_dispatch_has_kernel_for_dispatch_key |
| 53 | +_dispatch_isTensorSubclassLike |
| 54 | +_dispatch_is_alias_key |
| 55 | +_dispatch_is_included_in_alias |
| 56 | +_dispatch_is_main_interpreter |
| 57 | +_dispatch_kernel_for_dispatch_key_is_fallthrough |
| 58 | +_dispatch_key_for_device |
| 59 | +_dispatch_key_name |
| 60 | +_dispatch_key_parse |
| 61 | +_dispatch_key_set |
| 62 | +... |
| 63 | +``` |
| 64 | + |
| 65 | +Here are explanations for several commonly used commands: |
| 66 | + |
| 67 | +* `torch._C._dispatch_key_set`: |
| 68 | + |
| 69 | + Displays the DispatchKey of the current Tensor, with priority increasing from left to right. |
| 70 | + |
| 71 | + ```Python |
| 72 | + >>> import torch |
| 73 | + >>> a = torch.randn(3,3,device="cuda") |
| 74 | + >>> torch._C._dispatch_key_set(a) |
| 75 | + 'DispatchKeySet(CUDA, ADInplaceOrView, AutogradCUDA, AutocastCUDA)' |
| 76 | + ``` |
| 77 | + |
| 78 | +* `torch._C._dispatch_dump_table`: |
| 79 | + |
| 80 | + Queries the support status of a given operator across different Dispatch Keys, making it easy to locate the corresponding implementation code. |
| 81 | + |
| 82 | + ```Python |
| 83 | + >>> import torch |
| 84 | + >>> print(torch._C._dispatch_dump_table("aten::add.Tensor")) |
| 85 | + >>> ... |
| 86 | + CPU: registered at ./build/aten/src/ATen/RegisterCPU_0.cpp:1309 [kernel] |
| 87 | + CUDA: registered at ./build/aten/src/ATen/RegisterCUDA_0.cpp:2420 [kernel] |
| 88 | + HIP: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 89 | + MPS: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 90 | + IPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 91 | + XPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 92 | + HPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 93 | + VE: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 94 | + MTIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 95 | + MAIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 96 | + PrivateUse1: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel] |
| 97 | + ... |
| 98 | + ``` |
| 99 | + |
| 100 | +#### Environment Variables |
| 101 | + |
| 102 | +PyTorch also provides some Dispatcher-related environment variables that can help with learning and quickly locating issues. |
| 103 | + |
| 104 | +* TORCH_SHOW_DISPATCH_TRACE |
| 105 | + |
| 106 | + Displays detailed internal dispatch key scheduling during PyTorch execution. |
| 107 | + |
| 108 | + ```Bash |
| 109 | + export TORCH_SHOW_DISPATCH_TRACE=1 |
| 110 | + ``` |
| 111 | + |
| 112 | + ```Python |
| 113 | + >>> import torch |
| 114 | + >>> a = torch.randn(3,3) |
| 115 | + [call] op=[aten::randn], key=[BackendSelect] |
| 116 | + [redispatch] op=[aten::randn], key=[CPU] |
| 117 | + [call] op=[aten::empty.memory_format], key=[BackendSelect] |
| 118 | + [redispatch] op=[aten::empty.memory_format], key=[CPU] |
| 119 | + [call] op=[aten::normal_], key=[CPU] |
| 120 | + ``` |
| 121 | + |
| 122 | +### Registration |
| 123 | + |
| 124 | +::::{tabs} |
| 125 | + |
| 126 | +:::{tab} C++ |
| 127 | + |
| 128 | +1. Scenario One |
| 129 | + |
| 130 | + This is the most common operator implementation scenario. PyTorch comes with many built-in operators, defining their namespace (mainly in `aten` and `c10d`), schema, and concrete implementations for backends like CPU and CUDA. Our task is to provide the corresponding implementations for new devices for these built-in operators. |
| 131 | + |
| 132 | + ```{eval-rst} |
| 133 | + .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp |
| 134 | + :language: c++ |
| 135 | + :start-after: LITERALINCLUDE START: EMPTY.MEMORY_FORMAT |
| 136 | + :end-before: LITERALINCLUDE END: EMPTY.MEMORY_FORMAT |
| 137 | + ``` |
| 138 | + |
| 139 | + ```{eval-rst} |
| 140 | + .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp |
| 141 | + :language: c++ |
| 142 | + :start-after: LITERALINCLUDE START: TORCH_LIBRARY_IMPL |
| 143 | + :end-before: LITERALINCLUDE END: TORCH_LIBRARY_IMPL |
| 144 | + :emphasize-lines: 2 |
| 145 | + :linenos: |
| 146 | + ``` |
| 147 | + |
| 148 | + This registers the `wrapper_empty_memory_format` implementation for the new device to the `aten::empty.memory_format` operator on the `PrivateUse1 DispatchKey`. |
| 149 | + |
| 150 | +2. Scenario Two |
| 151 | + |
| 152 | + For built-in PyTorch operators, besides the registration method in Scenario One, a `STUB` registration method is also supported. Essentially, this approach is based on Scenario One but with added flexibility to enhance code reuse across devices or to enable further dispatching at other granularities (e.g., CPU feature capabilities). |
| 153 | + |
| 154 | + ```{eval-rst} |
| 155 | + .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp |
| 156 | + :language: c++ |
| 157 | + :start-after: LITERALINCLUDE START: STUB |
| 158 | + :end-before: LITERALINCLUDE END: STUB |
| 159 | + :linenos: |
| 160 | + ``` |
| 161 | + |
| 162 | + ```{todo} |
| 163 | + List of operators that can be registered via `STUB` |
| 164 | + ``` |
| 165 | + |
| 166 | +3. Scenario Three |
| 167 | + |
| 168 | + Besides providing built-in operator definitions, PyTorch also supports user-defined operators, generally in two forms: |
| 169 | + |
| 170 | + * Adding custom operators to a new namespace: |
| 171 | + |
| 172 | + ```{todo} |
| 173 | + TODO(including forward and backward) |
| 174 | + ``` |
| 175 | + |
| 176 | + * Extending existing namespaces with custom operators: |
| 177 | + |
| 178 | + ```{todo} |
| 179 | + TODO(including forward and backward) |
| 180 | + ``` |
| 181 | + |
| 182 | +4. Scenario Four |
| 183 | + |
| 184 | + In addition to separately registering forward and backward functions to `PrivateUse1` and `AutogradPrivateUse1` DispatchKeys, PyTorch also supports a more convenient option using `torch.autograd.Function`. |
| 185 | + |
| 186 | + ```{eval-rst} |
| 187 | + .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp |
| 188 | + :language: c++ |
| 189 | + :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part1 |
| 190 | + :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part1 |
| 191 | + :linenos: |
| 192 | + ``` |
| 193 | + |
| 194 | + ```{eval-rst} |
| 195 | + .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp |
| 196 | + :language: c++ |
| 197 | + :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part2 |
| 198 | + :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part2 |
| 199 | + :linenos: |
| 200 | + ``` |
| 201 | + |
| 202 | + ```{eval-rst} |
| 203 | + .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp |
| 204 | + :language: c++ |
| 205 | + :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION |
| 206 | + :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION |
| 207 | + :emphasize-lines: 2,7 |
| 208 | + :linenos: |
| 209 | + ``` |
| 210 | + |
| 211 | +5. Scenario Five |
| 212 | + |
| 213 | + PyTorch provides a fallback mechanism that allows unsupported operators to fall back to CPU execution. This is crucial for in-development accelerator backends to ensure functional correctness at the cost of performance. |
| 214 | + |
| 215 | + * Per-operator fallback |
| 216 | + |
| 217 | + ```{todo} |
| 218 | + TODO |
| 219 | + ``` |
| 220 | + |
| 221 | + * Global fallback |
| 222 | + |
| 223 | + ```{eval-rst} |
| 224 | + .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp |
| 225 | + :language: c++ |
| 226 | + :start-after: LITERALINCLUDE START: FALLBACK GLOBAL |
| 227 | + :end-before: LITERALINCLUDE END: FALLBACK GLOBAL |
| 228 | + :linenos: |
| 229 | + ``` |
| 230 | + |
| 231 | + This enables global fallback so all unimplemented operators on the new backend will default to CPU execution. |
| 232 | + |
| 233 | +6. Scenario Six |
| 234 | + |
| 235 | + ```{todo} |
| 236 | + * Meta registration |
| 237 | + * Overriding default implementations |
| 238 | + * Fallthrough |
| 239 | + * ATen operator set |
| 240 | + * ... |
| 241 | + ``` |
| 242 | + |
| 243 | +::: |
| 244 | + |
| 245 | +:::{tab} Python |
| 246 | + |
| 247 | +TODO |
| 248 | + |
| 249 | +::: |
| 250 | + |
| 251 | +:::: |
| 252 | + |
| 253 | +### Minimum set of operators to support |
| 254 | + |
| 255 | +To help developers better prioritize their work, we provide a minimal set of operators. Implementing these operators ensures basic operator functionality is available. |
| 256 | + |
| 257 | +| Operator Name | Dispatch Key | Description | |
| 258 | +| :---: | :---: | :---: | |
| 259 | +| empty.memory_format | PrivateUse1 | | |
| 260 | +| empty_strided | PrivateUse1 | | |
| 261 | +| as_strided | PrivateUse1 | | |
| 262 | +| resize_ | PrivateUse1 | | |
| 263 | +| _reshape_alias | PrivateUse1 | | |
| 264 | +| _copy_from | PrivateUse1 | | |
| 265 | +| _copy_from_and_resize | PrivateUse1 | | |
| 266 | +| _local_scalar_dense | PrivateUse1 | | |
| 267 | +| set_.source_Tensor | PrivateUse1 | | |
| 268 | +| set_.source_Storage | PrivateUse1 | | |
| 269 | +| set_.source_Storage_storage_offset | PrivateUse1 | | |
| 270 | +| view | PrivateUse1 | | |
| 271 | +| fallback | PrivateUse1 | | |
| 272 | + |
| 273 | +```{todo} |
| 274 | +Add/remove operators above to ensure the minimal set list is reliable and accurate |
| 275 | +``` |
| 276 | + |
| 277 | +## Autocast |
| 278 | + |
| 279 | +## Autoload |
| 280 | + |
| 281 | +## Memory Management |
| 282 | + |
| 283 | +## Custom Storage |
| 284 | + |
| 285 | +## ... |
| 286 | + |
| 287 | +[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL" |
0 commit comments