Commit 931d651
[OpenReg] Add Develop Notes for Integrating New Backend into PyTorch
As the title states. To facilitate the integration of new backends, we plan to publish a development note that details all the key components, hoping to speed up the development of other accelerators. This PR is the beginning of this note, which we will gradually improve and keep in sync with OpenReg's code.

ghstack-source-id: 6b8c629
Pull-Request-resolved: #158644
1 parent 94b45eb commit 931d651

6 files changed: +298 −0 lines changed

.ci/docker/requirements-docs.txt

Lines changed: 4 additions & 0 deletions

@@ -15,6 +15,10 @@ sphinxext-opengraph==0.9.1
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 0.9.1

+sphinx-tabs==3.4.7
+#Description: This is used to generate PyTorch docs
+#Pinned versions: 3.4.7
+
 sphinx_sitemap==2.6.0
 #Description: This is used to generate sitemap for PyTorch docs
 #Pinned versions: 2.6.0

docs/source/conf.py

Lines changed: 4 additions & 0 deletions

@@ -66,6 +66,7 @@
     "sphinx.ext.linkcode",
     "sphinxcontrib.mermaid",
     "sphinx_sitemap",
+    'sphinx_tabs.tabs',
 ]

 myst_enable_extensions = [
@@ -82,6 +83,9 @@
 ]
 sitemap_url_scheme = "{link}"

+# todo options
+todo_include_todos = True
+
 # build the templated autosummary files
 autosummary_generate = True
 numpydoc_show_class_members = False

docs/source/notes/accelerator.md

Lines changed: 274 additions & 0 deletions

@@ -0,0 +1,274 @@

```{eval-rst}
.. role:: hidden
   :class: hidden-section
```

# Extending PyTorch with New Accelerators

## Background

The PrivateUse1-based third-party device integration mechanism has become the official path for integrating new devices into PyTorch. Ensuring the usability of this mechanism is crucial for enriching the hardware ecosystem of PyTorch.

To assist third-party device developers in efficiently integrating new backends, this article introduces in detail the integration methods for typical PyTorch modules using a modular approach. It is accompanied by a streamlined code implementation from the official [torch_openreg][OpenReg URL] backend to help developers get started quickly while avoiding common pitfalls.

This document is suitable for the following readers:

* Developers who wish to integrate accelerator backends into PyTorch;
* Developers interested in the principles of typical PyTorch modules.

---

## Operator Registration

PyTorch provides multiple methods for operator registration and usage, at both the Python and C++ levels, along with a set of supporting tools to quickly locate issues and query information. The following sections detail these operator registration capabilities.

### Tools

#### Commands

PyTorch provides a set of commands prefixed with `torch._C._dispatch_` around its Dispatch feature. You can query all related interfaces using the following command:

```Shell
python -c 'import torch; print("\n".join([x for x in dir(torch._C) if x.startswith("_dispatch_")]))'

...
_dispatch_dump
_dispatch_dump_table
_dispatch_has_kernel
_dispatch_has_kernel_for_any_dispatch_key
_dispatch_has_kernel_for_dispatch_key
_dispatch_isTensorSubclassLike
_dispatch_is_alias_key
_dispatch_is_included_in_alias
_dispatch_is_main_interpreter
_dispatch_kernel_for_dispatch_key_is_fallthrough
_dispatch_key_for_device
_dispatch_key_name
_dispatch_key_parse
_dispatch_key_set
...
```

Here are explanations for several commonly used commands:

* `torch._C._dispatch_key_set`:

  Displays the DispatchKeySet of the current Tensor, with priority increasing from left to right.

  ```Python
  >>> import torch
  >>> a = torch.randn(3,3,device="cuda")
  >>> torch._C._dispatch_key_set(a)
  'DispatchKeySet(CUDA, ADInplaceOrView, AutogradCUDA, AutocastCUDA)'
  ```

* `torch._C._dispatch_dump_table`:

  Queries the support status of a given operator across different Dispatch Keys, making it easy to locate the corresponding implementation code.

  ```Python
  >>> import torch
  >>> print(torch._C._dispatch_dump_table("aten::add.Tensor"))
  ...
  CPU: registered at ./build/aten/src/ATen/RegisterCPU_0.cpp:1309 [kernel]
  CUDA: registered at ./build/aten/src/ATen/RegisterCUDA_0.cpp:2420 [kernel]
  HIP: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  MPS: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  IPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  XPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  HPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  VE: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  MTIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  MAIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  PrivateUse1: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
  ...
  ```

#### Environment Variables

PyTorch also provides some Dispatcher-related environment variables that can help with learning and quickly locating issues.

* `TORCH_SHOW_DISPATCH_TRACE`

  Displays detailed internal dispatch key scheduling during PyTorch execution.

  ```Bash
  export TORCH_SHOW_DISPATCH_TRACE=1
  ```

  ```Python
  >>> import torch
  >>> a = torch.randn(3,3)
  [call] op=[aten::randn], key=[BackendSelect]
  [redispatch] op=[aten::randn], key=[CPU]
  [call] op=[aten::empty.memory_format], key=[BackendSelect]
  [redispatch] op=[aten::empty.memory_format], key=[CPU]
  [call] op=[aten::normal_], key=[CPU]
  ```

### Registration

::::{tabs}

:::{tab} C++

1. Scenario One

   This is the most common operator implementation scenario. PyTorch comes with many built-in operators, defining their namespaces (mainly `aten` and `c10d`), schemas, and concrete implementations for backends such as CPU and CUDA. Our task is to provide implementations of these built-in operators for the new device.

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: EMPTY.MEMORY_FORMAT
      :end-before: LITERALINCLUDE END: EMPTY.MEMORY_FORMAT
   ```

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH_LIBRARY_IMPL
      :end-before: LITERALINCLUDE END: TORCH_LIBRARY_IMPL
      :emphasize-lines: 2
      :linenos:
   ```

   This registers the `wrapper_empty_memory_format` implementation for the new device to the `aten::empty.memory_format` operator on the `PrivateUse1` DispatchKey.

2. Scenario Two

   For built-in PyTorch operators, besides the registration method in Scenario One, a `STUB` registration method is also supported. Essentially, this approach is based on Scenario One but with added flexibility to enhance code reuse across devices or to enable further dispatching at other granularities (e.g., CPU feature capabilities).

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: STUB
      :end-before: LITERALINCLUDE END: STUB
      :linenos:
   ```

   ```{todo}
   List of operators that can be registered via `STUB`
   ```

3. Scenario Three

   Besides providing built-in operator definitions, PyTorch also supports user-defined operators, generally in two forms (a hedged sketch of the first form appears right after this scenario list):

   * Adding custom operators to a new namespace:

     ```{todo}
     TODO (including forward and backward)
     ```

   * Extending existing namespaces with custom operators:

     ```{todo}
     TODO (including forward and backward)
     ```

4. Scenario Four

   In addition to separately registering forward and backward functions to the `PrivateUse1` and `AutogradPrivateUse1` DispatchKeys, PyTorch also supports a more convenient option using `torch.autograd.Function`.

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part1
      :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part1
      :linenos:
   ```

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION Part2
      :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION Part2
      :linenos:
   ```

   ```{eval-rst}
   .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
      :language: c++
      :start-after: LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION
      :end-before: LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION
      :emphasize-lines: 2,7
      :linenos:
   ```

5. Scenario Five

   PyTorch provides a fallback mechanism that allows unsupported operators to fall back to CPU execution. This is crucial for accelerator backends that are still in development, ensuring functional correctness at the cost of performance.

   * Per-operator fallback (a hedged sketch is given right after the tabs below):

     ```{todo}
     TODO
     ```

   * Global fallback

     ```{eval-rst}
     .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp
        :language: c++
        :start-after: LITERALINCLUDE START: FALLBACK GLOBAL
        :end-before: LITERALINCLUDE END: FALLBACK GLOBAL
        :linenos:
     ```

     This enables a global fallback so that all operators not implemented on the new backend default to CPU execution.

6. Scenario Six

   ```{todo}
   * Meta registration
   * Overriding default implementations
   * Fallthrough
   * ATen operator set
   * ...
   ```
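
Since Scenario Three above is still marked TODO, here is a minimal, hedged sketch of the first form (a custom operator in a new namespace). The namespace `my_ops`, the operator `custom_scale`, and its implementation are hypothetical illustrations of the `TORCH_LIBRARY` / `TORCH_LIBRARY_IMPL` pattern and are not part of the OpenReg sources; backward support is omitted here.

```c++
#include <ATen/ATen.h>
#include <torch/library.h>

namespace my_backend {

// Hypothetical forward-only implementation; a real backend would launch
// its own device kernel here instead of reusing existing aten ops.
at::Tensor custom_scale_impl(const at::Tensor& input, double factor) {
  return input * factor;
}

// Define the operator schema in a new namespace...
TORCH_LIBRARY(my_ops, m) {
  m.def("custom_scale(Tensor input, float factor) -> Tensor");
}

// ...and register its implementation for the PrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(my_ops, PrivateUse1, m) {
  m.impl("custom_scale", &custom_scale_impl);
}

} // namespace my_backend
```

Once such an extension is loaded, the operator becomes reachable from Python as `torch.ops.my_ops.custom_scale(tensor, 2.0)`.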

:::

:::{tab} Python

TODO

:::

::::
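
Complementing the per-operator fallback item in Scenario Five of the C++ tab above (still marked TODO), the following is a minimal, hedged sketch of forcing a single operator to fall back to CPU. It mirrors the boxed `wrapper_cpu_fallback` used for the global fallback in `OpenRegMinimal.cpp`; the namespace `my_backend` and the choice of `aten::triu` are purely illustrative.

```c++
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

namespace my_backend {

// Boxed function that simply delegates to ATen's generic CPU fallback.
void per_op_cpu_fallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  // Only this operator is routed to CPU; other operators keep their
  // PrivateUse1 kernels (or hit the global fallback, if one is registered).
  m.impl(
      "triu",
      torch::CppFunction::makeFromBoxedFunction<&per_op_cpu_fallback>());
}

} // namespace my_backend
```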

### Minimum set of operators to support

To help developers better prioritize their work, we provide a minimal set of operators below. Implementing these operators ensures that basic operator functionality is available.

| Operator Name | Dispatch Key | Description |
| :---: | :---: | :---: |
| empty.memory_format | PrivateUse1 | |
| empty_strided | PrivateUse1 | |
| as_strided | PrivateUse1 | |
| resize_ | PrivateUse1 | |
| _reshape_alias | PrivateUse1 | |
| _copy_from | PrivateUse1 | |
| _copy_from_and_resize | PrivateUse1 | |
| _local_scalar_dense | PrivateUse1 | |
| set_.source_Tensor | PrivateUse1 | |
| set_.source_Storage | PrivateUse1 | |
| set_.source_Storage_storage_offset | PrivateUse1 | |
| view | PrivateUse1 | |
| fallback | PrivateUse1 | |

```{todo}
Add/remove operators above to ensure the minimal set list is reliable and accurate
```
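
One way to track progress against this table is to ask the dispatcher which of these operators already have a `PrivateUse1` kernel, using the `torch._C._dispatch_has_kernel_for_dispatch_key` interface listed in the Tools section. The snippet below is a hedged sketch: the fully qualified `aten::` names (including overload names) are spelled out from the abbreviated table entries above, and the `fallback` row is skipped since it is not an operator.

```Python
import torch

# Fully qualified operator names assumed from the abbreviated table above.
MINIMAL_OPS = [
    "aten::empty.memory_format",
    "aten::empty_strided",
    "aten::as_strided",
    "aten::resize_",
    "aten::_reshape_alias",
    "aten::_copy_from",
    "aten::_copy_from_and_resize",
    "aten::_local_scalar_dense",
    "aten::set_.source_Tensor",
    "aten::set_.source_Storage",
    "aten::set_.source_Storage_storage_offset",
    "aten::view",
]

for op in MINIMAL_OPS:
    # Returns True if a kernel is registered for the PrivateUse1 dispatch key.
    has_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(op, "PrivateUse1")
    print(f"{op}: {'registered' if has_kernel else 'missing'}")
```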

## Autocast

## Autoload

## Memory Management

## Custom Storage

## ...

[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp

Lines changed: 6 additions & 0 deletions

@@ -7,6 +7,7 @@

 namespace at::openreg {

+// START my snippet
 at::Tensor wrapper_quantize_per_tensor(
     const at::Tensor& self,
     double scale,
@@ -15,6 +16,7 @@ at::Tensor wrapper_quantize_per_tensor(
   return at::native::quantize_per_tensor_openreg(
       self, scale, zero_point, dtype);
 }
+// END my snippet

 int64_t wrapper__fused_sdp_choice(
     const at::Tensor& query,
@@ -112,6 +114,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {

 } // namespace at::openreg

+// LITERALINCLUDE START: TORCH.AUTOGRAD.FUNCTION
 namespace at::openreg {
 TORCH_LIBRARY(openreg, m) {
   m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor");
@@ -126,7 +129,9 @@ TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
       "custom_autograd_fn_aliasing", &at::native::custom_autograd_fn_aliasing);
 }
 } // namespace at::openreg
+// LITERALINCLUDE END: TORCH.AUTOGRAD.FUNCTION

+// LITERALINCLUDE START: STUB
 namespace at::native {
 REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_openreg);
 REGISTER_PRIVATEUSE1_DISPATCH(
@@ -136,3 +141,4 @@ REGISTER_PRIVATEUSE1_DISPATCH(
     _fused_sdp_choice_stub,
     &_fused_sdp_choice_openreg);
 } // namespace at::native
+// LITERALINCLUDE END: STUB

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp

Lines changed: 6 additions & 0 deletions

@@ -7,6 +7,7 @@

 namespace at::openreg {

+// LITERALINCLUDE START: EMPTY.MEMORY_FORMAT
 at::Tensor wrapper_empty_memory_format(
     c10::IntArrayRef size,
     std::optional<c10::ScalarType> dtype_opt,
@@ -22,6 +23,7 @@ at::Tensor wrapper_empty_memory_format(
       pin_memory_opt,
       memory_format_opt);
 }
+// LITERALINCLUDE END: EMPTY.MEMORY_FORMAT

 at::Tensor wrapper_empty_strided(
     c10::IntArrayRef size,
@@ -97,6 +99,7 @@ at::Tensor wrapper_view(const at::Tensor& self, c10::SymIntArrayRef size) {
   return at::native::view_openreg(self, size);
 }

+// LITERALINCLUDE START: TORCH_LIBRARY_IMPL
 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("empty.memory_format", wrapper_empty_memory_format);
   m.impl("empty_strided", wrapper_empty_strided);
@@ -113,7 +116,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
       wrapper_set_source_Storage_storage_offsetset_);
   m.impl("view", wrapper_view);
 }
+// LITERALINCLUDE END: TORCH_LIBRARY_IMPL

+// LITERALINCLUDE START: FALLBACK GLOBAL
 void wrapper_cpu_fallback(
     const c10::OperatorHandle& op,
     torch::jit::Stack* stack) {
@@ -124,5 +129,6 @@ TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
   m.fallback(
       torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>());
 }
+// LITERALINCLUDE END: FALLBACK GLOBAL

 } // namespace at::openreg
