
Commit ef09471 (1 parent: f5e2de9)

[OpenReg] Add Develop Notes for Integrating New Backend into PyTorch

As the title states. To facilitate the integration of new backends, we plan to publish a development note that details all the key components, in the hope of speeding up bring-up of other accelerators. This PR is the beginning of that note, which we will gradually improve and keep in sync with OpenReg's code.

ghstack-source-id: 2e10e33
Pull-Request-resolved: #158644

File tree: 6 files changed (+311 −0 lines)

.ci/docker/requirements-docs.txt

Lines changed: 4 additions & 0 deletions

```diff
@@ -15,6 +15,10 @@ sphinxext-opengraph==0.9.1
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 0.9.1
 
+sphinx-tabs==3.4.7
+#Description: This is used to generate PyTorch docs
+#Pinned versions: 3.4.7
+
 sphinx_sitemap==2.6.0
 #Description: This is used to generate sitemap for PyTorch docs
 #Pinned versions: 2.6.0
```

docs/source/conf.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -66,6 +66,7 @@
     "sphinx.ext.linkcode",
     "sphinxcontrib.mermaid",
     "sphinx_sitemap",
+    "sphinx_tabs.tabs",
 ]
 
 myst_enable_extensions = [
@@ -86,6 +87,9 @@
     "404": "404.html",
 }
 
+# todo options
+todo_include_todos = True
+
 # build the templated autosummary files
 autosummary_generate = True
 numpydoc_show_class_members = False
```
Lines changed: 287 additions & 0 deletions (new documentation file)
# Extending PyTorch with New Accelerators

## Background

### Motivation and Overview

Since PyTorch 2.1, the community has made significant progress in simplifying the integration of new accelerators into the PyTorch ecosystem. These improvements include, but are not limited to: refinement of the `PrivateUse1` Dispatch Key, introduction and improvement of core subsystem extension mechanisms, and device-agnostic refactoring of key modules (e.g., `torch.accelerator`, memory management). Taken together, these improvements lay the foundation for a **robust**, **flexible**, and **developer-friendly** accelerator integration path.
### Why Does This Matter?

This integration path has several key advantages:

* **Speed**: Extensibility is built in across all core PyTorch modules. Developers can integrate new accelerators into their own downstream codebases without modifying upstream code and without being blocked on community review of PyTorch core changes.
* **Future-proofing**: This integration path is the default for all future PyTorch features, so new modules and features automatically become available to any device that follows it.
* **Autonomy**: Vendors have full control over their accelerator integration timelines, enabling agile iteration cycles and reducing reliance on upstream coordination.
### About This Document

This guide aims to provide a **comprehensive overview of the modern integration pathway** for new devices in PyTorch. It walks through the full integration surface, from low-level device primitives to higher-level domain modules like compilation and quantization. The structure follows a **modular and scenario-driven approach**, where each topic is paired with corresponding code examples from [torch_openreg][OpenReg URL], an official reference implementation.

The goal is to help developers:

* Understand the full scope of accelerator integration;
* Rapidly bring up a new accelerator following best practices;
* Avoid common pitfalls through clear, targeted examples.
### Target Audience

This document is intended for:

* **Accelerator Developers** who are integrating a new accelerator into PyTorch;
* **Advanced PyTorch Users** interested in the inner workings of key modules.
## Operator Registration

PyTorch provides multiple methods for operator registration and usage, at both the Python and C++ levels, along with a set of supporting tools to quickly locate issues and query information. The following sections detail these operator registration capabilities.
### Tools

#### Commands

PyTorch provides a set of commands prefixed with `torch._C._dispatch_` around its Dispatch feature. You can query all related interfaces using the following command:

```Shell
python -c 'import torch; print("\n".join([x for x in dir(torch._C) if x.startswith("_dispatch_")]))'

...
_dispatch_dump
_dispatch_dump_table
_dispatch_has_kernel
_dispatch_has_kernel_for_any_dispatch_key
_dispatch_has_kernel_for_dispatch_key
_dispatch_isTensorSubclassLike
_dispatch_is_alias_key
_dispatch_is_included_in_alias
_dispatch_is_main_interpreter
_dispatch_kernel_for_dispatch_key_is_fallthrough
_dispatch_key_for_device
_dispatch_key_name
_dispatch_key_parse
_dispatch_key_set
...
```
Here are explanations for several commonly used commands:

* `torch._C._dispatch_key_set`:

Displays the DispatchKeySet of the given Tensor, with priority increasing from left to right.

```Python
>>> import torch
>>> a = torch.randn(3,3,device="cuda")
>>> torch._C._dispatch_key_set(a)
'DispatchKeySet(CUDA, ADInplaceOrView, AutogradCUDA, AutocastCUDA)'
```

* `torch._C._dispatch_dump_table`:

Queries the support status of a given operator across different Dispatch Keys, making it easy to locate the corresponding implementation code.

```Python
>>> import torch
>>> print(torch._C._dispatch_dump_table("aten::add.Tensor"))
>>> ...
CPU: registered at ./build/aten/src/ATen/RegisterCPU_0.cpp:1309 [kernel]
CUDA: registered at ./build/aten/src/ATen/RegisterCUDA_0.cpp:2420 [kernel]
HIP: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
MPS: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
IPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
XPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
HPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
VE: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
MTIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
MAIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
PrivateUse1: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
...
```
#### Environment Variables

PyTorch also provides some Dispatcher-related environment variables that can help with learning and quickly locating issues.

* `TORCH_SHOW_DISPATCH_TRACE`

Displays detailed internal dispatch key scheduling during PyTorch execution.

```Bash
export TORCH_SHOW_DISPATCH_TRACE=1
```

```Python
>>> import torch
>>> a = torch.randn(3,3)
[call] op=[aten::randn], key=[BackendSelect]
[redispatch] op=[aten::randn], key=[CPU]
[call] op=[aten::empty.memory_format], key=[BackendSelect]
[redispatch] op=[aten::empty.memory_format], key=[CPU]
[call] op=[aten::normal_], key=[CPU]
```
### Registration

::::{tabs}

:::{tab} C++

1. Scenario One

This is the most common operator implementation scenario. PyTorch ships with many built-in operators, defining their namespaces (mainly `aten` and `c10d`), schemas, and concrete implementations for backends such as CPU and CUDA. Our task is to provide implementations of these built-in operators for the new device.

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: EMPTY.MEMORY_FORMAT
   :end-before: LITERALINCLUDE END: EMPTY.MEMORY_FORMAT
```

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: TORCH_LIBRARY_IMPL
   :end-before: LITERALINCLUDE END: TORCH_LIBRARY_IMPL
   :emphasize-lines: 2
   :linenos:
```

This registers the `wrapper_empty_memory_format` implementation for the new device to the `aten::empty.memory_format` operator on the `PrivateUse1` DispatchKey.
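For orientation, here is a condensed, self-contained sketch of the same pattern, modeled on the OpenRegMinimal.cpp snippet referenced above. The helper `at::native::empty_memory_format_openreg` is a hypothetical stand-in for whatever allocation routine the backend actually provides:

```c++
#include <ATen/ATen.h>
#include <torch/library.h>

namespace at::openreg {

// Kernel for aten::empty.memory_format on the new device. In a real backend
// this would allocate memory through the device's own allocator; the helper
// called here is a hypothetical stand-in.
at::Tensor wrapper_empty_memory_format(
    c10::IntArrayRef size,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  return at::native::empty_memory_format_openreg(
      size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

// Bind the kernel to the built-in schema under the PrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("empty.memory_format", wrapper_empty_memory_format);
}

} // namespace at::openreg
```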
2. Scenario Two

For built-in PyTorch operators, a `STUB`-based registration method is also supported in addition to the approach in Scenario One. Essentially, it builds on Scenario One but adds flexibility that improves code reuse across devices and enables further dispatching at other granularities (e.g., CPU feature capabilities).

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: STUB
   :end-before: LITERALINCLUDE END: STUB
   :linenos:
```

```{todo}
List of operators that can be registered via `STUB`
```
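As a rough sketch of what the `STUB` route looks like (the kernel body below is a placeholder; OpenRegExtra.cpp registers its real kernel, `abs_kernel_openreg`, in the same way):

```c++
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>

namespace at::native {

// Placeholder kernel: a real backend would implement element-wise abs over
// the TensorIterator using its own device primitives.
void abs_kernel_openreg(at::TensorIteratorBase& iter) {
  // ... device-specific implementation ...
}

// Route the shared at::native::abs implementation to this kernel whenever the
// inputs live on the PrivateUse1 backend.
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_openreg);

} // namespace at::native
```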
3. Scenario Three

Besides providing built-in operator definitions, PyTorch also supports user-defined operators, generally in two forms:

* Adding custom operators to a new namespace:

```{todo}
TODO (including forward and backward)
```

* Extending existing namespaces with custom operators:

```{todo}
TODO (including forward and backward)
```
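As an illustrative sketch of the first form (all names below, such as `openreg::custom_add`, are made up for demonstration and are not part of the OpenReg reference implementation):

```c++
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// Hypothetical kernel; a real backend would launch its own device code here.
at::Tensor custom_add_impl(const at::Tensor& a, const at::Tensor& b) {
  return at::add(a, b);
}

} // namespace

// Define the operator schema once in the backend's own namespace ...
TORCH_LIBRARY(openreg, m) {
  m.def("custom_add(Tensor a, Tensor b) -> Tensor");
}

// ... and register the device-specific kernel under PrivateUse1.
TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
  m.impl("custom_add", &custom_add_impl);
}
```

Once registered, the operator is reachable from Python as `torch.ops.openreg.custom_add(a, b)`.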
4. Scenario Four

In addition to separately registering forward and backward functions to the `PrivateUse1` and `AutogradPrivateUse1` DispatchKeys, PyTorch also supports a more convenient option using `torch.autograd.Function`.

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part1
   :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part1
   :linenos:
```

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part2
   :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part2
   :linenos:
```

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION
   :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION
   :emphasize-lines: 2,7
   :linenos:
```
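The shape of this approach, reduced to a minimal sketch (the operator `openreg::custom_scale` and its kernel are hypothetical and only illustrate the structure used by the OpenReg snippets above):

```c++
#include <ATen/ATen.h>
#include <torch/library.h>
#include <torch/csrc/autograd/custom_function.h>

// A custom autograd function: forward runs on the new backend through the
// PrivateUse1 kernels of the ops it calls, backward is defined alongside it.
class CustomScale : public torch::autograd::Function<CustomScale> {
 public:
  static at::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input) {
    return input * 2.0;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    // d(2 * input) / d(input) = 2
    return {grad_outputs[0] * 2.0};
  }
};

at::Tensor custom_scale_autograd(const at::Tensor& input) {
  return CustomScale::apply(input);
}

// Declaring the schema and registering the wrapper on AutogradPrivateUse1
// makes a single entry point cover both the forward computation and the
// autograd graph construction.
TORCH_LIBRARY(openreg, m) {
  m.def("custom_scale(Tensor input) -> Tensor");
}

TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
  m.impl("custom_scale", &custom_scale_autograd);
}
```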
5. Scenario Five

PyTorch provides a fallback mechanism that allows unsupported operators to fall back to CPU execution. This is crucial for accelerator backends that are still under development, as it ensures functional correctness at the cost of performance.

* Per-operator fallback

```{todo}
TODO
```

* Global fallback

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: FALLBACK GLOBAL
   :end-before: LITERALINCLUDE END: FALLBACK GLOBAL
   :linenos:
```

This enables a global fallback, so any operator not yet implemented for the new backend will default to CPU execution.
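A minimal self-contained sketch of the fallback wiring, following the FALLBACK GLOBAL snippet from OpenRegMinimal.cpp. ATen's generic boxed CPU fallback does the actual tensor copying; the per-operator variant at the end is a hypothetical illustration of the first bullet above:

```c++
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

// Boxed fallback: copies PrivateUse1 tensors to CPU, runs the CPU kernel,
// then copies the results back to the device.
void wrapper_cpu_fallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

// Global fallback: every aten operator without a PrivateUse1 kernel goes
// through the boxed fallback above.
TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>());
}

// Per-operator fallback (hypothetical example): route only a single schema
// through the same boxed function instead of the whole namespace.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl(
      "sub.Tensor",
      torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>());
}
```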
6. Scenario Six

```{todo}
* Meta registration
* Overriding default implementations
* Fallthrough
* ATen operator set
* ...
```

:::

:::{tab} Python

TODO

:::

::::
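The Python tab above is still TODO; for a rough sense of what Python-level registration can look like, here is a hedged sketch using `torch.library`. The namespace `openreg` and operator `custom_mul` are made-up names, and the backend extension is assumed to be imported already:

```Python
import torch

# Define a new operator schema in a custom namespace.
lib = torch.library.Library("openreg", "DEF")
lib.define("custom_mul(Tensor a, Tensor b) -> Tensor")

# Kernel for the PrivateUse1 dispatch key; purely illustrative, it composes
# ops the backend is assumed to support already.
def custom_mul_privateuse1(a, b):
    return a * b

lib.impl("custom_mul", custom_mul_privateuse1, "PrivateUse1")

# The operator is then callable as torch.ops.openreg.custom_mul(a, b).
```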
### Minimum set of operators to support

To help developers better prioritize their work, we provide a minimal set of operators. Implementing these operators ensures basic operator functionality is available.

| Operator Name | Dispatch Key | Description |
| :---: | :---: | :---: |
| empty.memory_format | PrivateUse1 | |
| empty_strided | PrivateUse1 | |
| as_strided | PrivateUse1 | |
| resize_ | PrivateUse1 | |
| _reshape_alias | PrivateUse1 | |
| _copy_from | PrivateUse1 | |
| _copy_from_and_resize | PrivateUse1 | |
| _local_scalar_dense | PrivateUse1 | |
| set_.source_Tensor | PrivateUse1 | |
| set_.source_Storage | PrivateUse1 | |
| set_.source_Storage_storage_offset | PrivateUse1 | |
| view | PrivateUse1 | |
| fallback | PrivateUse1 | |

```{todo}
Add/remove operators above to ensure the minimal set list is reliable and accurate
```
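To check progress against this list, the dispatcher introspection commands from the Tools section can be reused. A small sketch (the operator names follow the table above, and the backend extension is assumed to be imported):

```Python
import torch
# import torch_openreg  # hypothetical: load the backend extension first

ops = [
    "aten::empty.memory_format",
    "aten::empty_strided",
    "aten::as_strided",
    "aten::_copy_from",
    "aten::view",
]
for op in ops:
    has_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(op, "PrivateUse1")
    print(f"{op}: {'registered' if has_kernel else 'missing'}")
```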
## Autocast

## Autoload

## Memory Management

## Custom Storage

## ...

[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp

Lines changed: 6 additions & 0 deletions

```diff
@@ -7,6 +7,7 @@
 
 namespace at::openreg {
 
+// START my snippet
 at::Tensor wrapper_quantize_per_tensor(
     const at::Tensor& self,
     double scale,
@@ -15,6 +16,7 @@ at::Tensor wrapper_quantize_per_tensor(
   return at::native::quantize_per_tensor_openreg(
       self, scale, zero_point, dtype);
 }
+// END my snippet
 
 int64_t wrapper__fused_sdp_choice(
     const at::Tensor& query,
@@ -112,6 +114,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
 
 } // namespace at::openreg
 
+// LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION
 namespace at::openreg {
 TORCH_LIBRARY(openreg, m) {
   m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor");
@@ -126,7 +129,9 @@ TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
       "custom_autograd_fn_aliasing", &at::native::custom_autograd_fn_aliasing);
 }
 } // namespace at::openreg
+// LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION
 
+// LITERALINCLUDE START: STUB
 namespace at::native {
 REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_openreg);
 REGISTER_PRIVATEUSE1_DISPATCH(
@@ -136,3 +141,4 @@ REGISTER_PRIVATEUSE1_DISPATCH(
     _fused_sdp_choice_stub,
     &_fused_sdp_choice_openreg);
 } // namespace at::native
+// LITERALINCLUDE END: STUB
```

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp

Lines changed: 6 additions & 0 deletions

```diff
@@ -7,6 +7,7 @@
 
 namespace at::openreg {
 
+// LITERALINCLUDE START: EMPTY.MEMORY_FORMAT
 at::Tensor wrapper_empty_memory_format(
     c10::IntArrayRef size,
     std::optional<c10::ScalarType> dtype_opt,
@@ -22,6 +23,7 @@ at::Tensor wrapper_empty_memory_format(
       pin_memory_opt,
       memory_format_opt);
 }
+// LITERALINCLUDE END: EMPTY.MEMORY_FORMAT
 
 at::Tensor wrapper_empty_strided(
     c10::IntArrayRef size,
@@ -97,6 +99,7 @@ at::Tensor wrapper_view(const at::Tensor& self, c10::SymIntArrayRef size) {
   return at::native::view_openreg(self, size);
 }
 
+// LITERALINCLUDE START: TORCH_LIBRARY_IMPL
 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("empty.memory_format", wrapper_empty_memory_format);
   m.impl("empty_strided", wrapper_empty_strided);
@@ -113,7 +116,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
       wrapper_set_source_Storage_storage_offsetset_);
   m.impl("view", wrapper_view);
 }
+// LITERALINCLUDE END: TORCH_LIBRARY_IMPL
 
+// LITERALINCLUDE START: FALLBACK GLOBAL
 void wrapper_cpu_fallback(
     const c10::OperatorHandle& op,
     torch::jit::Stack* stack) {
@@ -124,5 +129,6 @@ TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
   m.fallback(
       torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>());
 }
+// LITERALINCLUDE END: FALLBACK GLOBAL
 
 } // namespace at::openreg
```
