---
layout: blog_detail
title: 'Efficient PyTorch: Tensor Memory Format Matters'
author: 'Dhruv Matani, Suraj Subramanian'
featured-img: ''
---

Ensuring the right memory format for your inputs can significantly impact the running time of your PyTorch vision models. When in doubt, choose a Channels Last memory format.

When dealing with vision models in PyTorch that accept multimedia (for example, image Tensors) as input, the Tensor’s memory format can significantly impact **the inference execution speed of your model on mobile platforms when using the CPU backend along with XNNPACK**. This holds true for training and inference on server platforms as well, but latency is particularly critical for mobile devices and users.

<style type="text/css">
article.pytorch-article table tr th, article.pytorch-article table td {line-height: 1.5rem}
</style>

## Outline of this article

1. Deep dive into matrix storage/memory representation in C++. Introduction to [Row and Column major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order).
2. Impact of looping over a matrix in the same or a different order as the storage representation, along with an example.
3. Introduction to Cachegrind, a tool to inspect the cache friendliness of your code.
4. Memory formats supported by PyTorch operators.
5. Best-practices example to ensure efficient model execution with XNNPACK optimizations.

## Matrix Storage Representation in C++

Images are fed into PyTorch ML models as multi-dimensional Tensors. These Tensors have specific memory formats. To understand this concept better, let’s take a look at how a 2-d matrix may be stored in memory.

Broadly speaking, there are 2 main ways of efficiently storing multi-dimensional data in memory.

1. **Row Major Order:** In this format, the matrix is stored row by row, with each row stored before the next row in memory, i.e. row N comes before row N+1.
2. **Column Major Order:** In this format, the matrix is stored column by column, with each column stored before the next column in memory, i.e. column N comes before column N+1.

You can see the differences graphically below.

<p align="center">
<img src="/assets/images/tensor/image1.png" alt="C++ stores multi-dimensional data in row-major format." width="100%">
<br>
C++ stores multi-dimensional data in row-major format.
</p>
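
The same distinction is visible in PyTorch itself: a freshly created tensor is stored row-major, and its strides tell you how far apart consecutive rows and columns are in memory. A minimal illustration (not part of the original example):

```python
import torch

m = torch.arange(12).reshape(3, 4)  # a 3x4 matrix, stored row-major

# Stride (4, 1): stepping to the next row skips 4 elements,
# stepping to the next column skips 1 element.
print(m.stride())          # (4, 1)

# A transposed view walks the same buffer in column-major order,
# so it is no longer "contiguous" in the row-major sense.
mt = m.t()
print(mt.stride())         # (1, 4)
print(mt.is_contiguous())  # False
```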

## Efficiently accessing elements of a 2d matrix

Similar to the storage format, there are 2 ways to access data in a 2d matrix.

1. **Loop Over Rows first:** All elements of a row are processed before any element of the next row.
2. **Loop Over Columns first:** All elements of a column are processed before any element of the next column.

For maximum efficiency, one should always access data in the same order in which it is stored, i.e. if the data is stored in row-major order, then one should try to access it in that order.

The code below (main.cpp) shows [2 ways](https://stackoverflow.com/questions/9936132/why-does-the-order-of-the-loops-affect-performance-when-iterating-over-a-2d-arra) of accessing all the elements of a 2d 4000x4000 matrix.

```cpp
#include <chrono>
#include <cstdlib>
#include <iostream>

// loop1 accesses data in matrix 'a' in row-major order,
// since i is the outer loop variable and j is the
// inner loop variable.
int loop1(int a[4000][4000]) {
  int s = 0;
  for (int i = 0; i < 4000; ++i) {
    for (int j = 0; j < 4000; ++j) {
      s += a[i][j];
    }
  }
  return s;
}

// loop2 accesses data in matrix 'a' in column-major order,
// since j is the outer loop variable and i is the
// inner loop variable.
int loop2(int a[4000][4000]) {
  int s = 0;
  for (int j = 0; j < 4000; ++j) {
    for (int i = 0; i < 4000; ++i) {
      s += a[i][j];
    }
  }
  return s;
}

int main() {
  static int a[4000][4000] = {0};
  for (int i = 0; i < 100; ++i) {
    int x = rand() % 4000;
    int y = rand() % 4000;
    a[x][y] = rand() % 1000;
  }

  auto start = std::chrono::high_resolution_clock::now();
  auto end = start;
  int s = 0;

#if defined RUN_LOOP1
  start = std::chrono::high_resolution_clock::now();

  s = 0;
  for (int i = 0; i < 10; ++i) {
    s += loop1(a);
    s = s % 100;
  }
  end = std::chrono::high_resolution_clock::now();

  std::cout << "s = " << s << std::endl;
  std::cout << "Time for loop1: "
            << std::chrono::duration<double, std::milli>(end - start).count()
            << "ms" << std::endl;
#endif

#if defined RUN_LOOP2
  start = std::chrono::high_resolution_clock::now();
  s = 0;
  for (int i = 0; i < 10; ++i) {
    s += loop2(a);
    s = s % 100;
  }
  end = std::chrono::high_resolution_clock::now();

  std::cout << "s = " << s << std::endl;
  std::cout << "Time for loop2: "
            << std::chrono::duration<double, std::milli>(end - start).count()
            << "ms" << std::endl;
#endif
}
```

Let’s build and run this program and see what it prints.

```bash
g++ -O2 main.cpp -DRUN_LOOP1 -DRUN_LOOP2
./a.out
```

This prints the following:

```
s = 70
Time for loop1: 77.0687ms
s = 70
Time for loop2: 1219.49ms
```

`loop1()` is **15x faster** than `loop2()`. Why is that? Let’s find out below!

## Measure cache misses using Cachegrind

[Cachegrind](https://courses.cs.washington.edu/courses/cse326/05wi/valgrind-doc/cg_main.html) is a cache profiling tool used to see how many I1 (first level instruction), D1 (first level data), and LL (last level) cache misses your program caused.

Let’s build our program with just `loop1()` and just `loop2()` to see how cache friendly each of these functions is.

### Build and run/profile just loop1()

```bash
g++ -O2 main.cpp -DRUN_LOOP1
valgrind --tool=cachegrind ./a.out
```

#### Prints:

```
==3299700==
==3299700== I refs: 643,156,721
==3299700== I1 misses: 2,077
==3299700== LLi misses: 2,021
==3299700== I1 miss rate: 0.00%
==3299700== LLi miss rate: 0.00%
==3299700==
==3299700== D refs: 160,952,192 (160,695,444 rd + 256,748 wr)
==3299700== D1 misses: 10,021,300 ( 10,018,723 rd + 2,577 wr)
==3299700== LLd misses: 10,010,916 ( 10,009,147 rd + 1,769 wr)
==3299700== D1 miss rate: 6.2% ( 6.2% + 1.0% )
==3299700== LLd miss rate: 6.2% ( 6.2% + 0.7% )
==3299700==
==3299700== LL refs: 10,023,377 ( 10,020,800 rd + 2,577 wr)
==3299700== LL misses: 10,012,937 ( 10,011,168 rd + 1,769 wr)
==3299700== LL miss rate: 1.2% ( 1.2% + 0.7% )
```

### Build and run/profile just loop2()

```bash
g++ -O2 main.cpp -DRUN_LOOP2
valgrind --tool=cachegrind ./a.out
```

#### Prints:

```
==3300389==
==3300389== I refs: 643,156,726
==3300389== I1 misses: 2,075
==3300389== LLi misses: 2,018
==3300389== I1 miss rate: 0.00%
==3300389== LLi miss rate: 0.00%
==3300389==
==3300389== D refs: 160,952,196 (160,695,447 rd + 256,749 wr)
==3300389== D1 misses: 160,021,290 (160,018,713 rd + 2,577 wr)
==3300389== LLd misses: 10,014,907 ( 10,013,138 rd + 1,769 wr)
==3300389== D1 miss rate: 99.4% ( 99.6% + 1.0% )
==3300389== LLd miss rate: 6.2% ( 6.2% + 0.7% )
==3300389==
==3300389== LL refs: 160,023,365 (160,020,788 rd + 2,577 wr)
==3300389== LL misses: 10,016,925 ( 10,015,156 rd + 1,769 wr)
==3300389== LL miss rate: 1.2% ( 1.2% + 0.7% )
```

The main differences between the 2 runs are:

1. **D1 misses:** 10M vs. 160M
2. **D1 miss rate:** 6.2% vs. 99.4%

As you can see, `loop2()` causes far more (**~16x more**) L1 data cache misses than `loop1()`. This is why `loop1()` is ~15x faster than `loop2()`.
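
A back-of-the-envelope calculation (assuming 4-byte ints and 64-byte cache lines, which are typical but not stated in the measurements above) matches these numbers. The matrix holds 4000 x 4000 = 16M ints. A row-major pass pulls in 16 consecutive ints per cache line, so it misses roughly 16M / 16 = 1M times; 10 passes give the ~10M D1 misses seen for `loop1()`. A column-major pass touches a different cache line on essentially every access (the 64 MB matrix is far larger than the D1 cache, so lines are evicted before they can be reused), so 10 passes x 16M accesses give the ~160M D1 misses seen for `loop2()`.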

## Memory Formats supported by PyTorch Operators

While PyTorch operators expect all tensors to be in the [Channels First (NCHW) dimension format](https://discuss.pytorch.org/t/why-does-pytorch-prefer-using-nchw/83637/4), they support 3 output [memory formats](https://github.com/pytorch/pytorch/blob/master/c10/core/MemoryFormat.h):

1. **Contiguous:** Tensor memory is in the same order as the tensor’s dimensions.
2. **ChannelsLast:** Irrespective of the dimension order, the 2d (image) tensor is laid out as an HWC or [NHWC](https://oneapi-src.github.io/oneDNN/dev_guide_understanding_memory_formats.html) (N: batch, H: height, W: width, C: channels) tensor in memory. The dimensions could be permuted in any order.
3. **ChannelsLast3d:** For 3d (video) tensors, the memory is laid out in THWC (T: time, H: height, W: width, C: channels) or NTHWC (N: batch, T: time, H: height, W: width, C: channels) format. The dimensions could be permuted in any order.
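
To make this concrete, here is a small sketch (not from the original post) showing how to convert a tensor to Channels Last and check which memory format it is in:

```python
import torch

x = torch.rand(1, 3, 4, 4)                     # NCHW, Contiguous memory format
xcl = x.to(memory_format=torch.channels_last)  # same shape, NHWC layout in memory

print(x.is_contiguous())                                     # True
print(xcl.is_contiguous(memory_format=torch.channels_last))  # True

# The shapes match, but the strides reveal the different layouts:
print(x.stride())    # (48, 16, 4, 1)
print(xcl.stride())  # (48, 1, 12, 3)
```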

ChannelsLast is preferred for vision models because [XNNPACK](https://github.com/google/XNNPACK) (the kernel acceleration library used by PyTorch) expects all inputs to be in **Channels Last** format. If the input to the model isn’t Channels Last, it must first be converted, which is an additional operation.

Additionally, most PyTorch operators preserve the input tensor’s memory format. So if an XNNPACK-accelerated operator receives a Channels First input, it needs to first convert it to Channels Last, perform the operation, and then convert the result back to Channels First.

Combine this with the fact that accelerated operators work better with a Channels Last memory format, and it becomes clear that having operators return Channels Last outputs is better for subsequent operator calls; otherwise every operator ends up converting its input to Channels Last on its own (whenever that is more efficient for that specific operator).
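
For example, here is a minimal sketch (my own illustration, not from the original post) of memory format propagation through a convolution; on recent PyTorch releases, a Channels Last input should produce a Channels Last output:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.rand(1, 3, 56, 56).to(memory_format=torch.channels_last)

with torch.inference_mode():
    y = conv(x)

# The convolution preserves the input's memory format, so the next
# operator can consume `y` without an extra layout conversion.
print(y.is_contiguous(memory_format=torch.channels_last))  # expected: True
```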

From the XNNPACK home page:

> “All operators in XNNPACK support NHWC layout, but additionally allow custom stride along the Channel dimension.”

## PyTorch Best Practice

The best way to get the most performance from your PyTorch vision models is to ensure that your input tensor is in a **Channels Last** [memory format](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) before it is fed into the model.

You can get even more speedups by optimizing your model to use the XNNPACK backend (by simply calling `optimize_for_mobile()` on your torchscripted model). Note that XNNPACK models will run slower if the inputs are contiguous, so make sure your input is in the Channels Last format.
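
Putting the two recommendations together, the recipe looks roughly like this (a condensed sketch using torchvision’s `mobilenet_v2` purely as a placeholder model; the next section walks through a full timed comparison):

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchvision.models import mobilenet_v2

model = mobilenet_v2(pretrained=False).eval()

# 1. Convert the input to Channels Last before feeding it to the model.
x = torch.rand(1, 3, 224, 224).to(memory_format=torch.channels_last)

# 2. Script the model and apply the mobile/XNNPACK rewrite.
scripted = torch.jit.script(model)
optimized = optimize_for_mobile(scripted)

with torch.inference_mode():
    y = optimized(x)
```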

## Working example showing speedup

Run this example on [Google Colab](https://colab.research.google.com/gist/suraj813/ad9aebcbffbdd6d02b23ca7231130a30/channels-last-with-xnnpack.ipynb#scrollTo=xvJN73YWXgDF). Note that runtimes on Colab CPUs might not reflect accurate performance; it is recommended to run this code on your local machine.

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
import torch.backends.xnnpack
import time

print("XNNPACK is enabled: ", torch.backends.xnnpack.enabled, "\n")

N, C, H, W = 1, 3, 200, 200
x = torch.rand(N, C, H, W)
print("Contiguous shape: ", x.shape)
print("Contiguous stride: ", x.stride())
print()

xcl = x.to(memory_format=torch.channels_last)
print("Channels-Last shape: ", xcl.shape)
print("Channels-Last stride: ", xcl.stride())

## Outputs:

# XNNPACK is enabled: True

# Contiguous shape: torch.Size([1, 3, 200, 200])
# Contiguous stride: (120000, 40000, 200, 1)

# Channels-Last shape: torch.Size([1, 3, 200, 200])
# Channels-Last stride: (120000, 1, 600, 3)
```

The input shape stays the same for contiguous and channels-last formats. Internally, however, the tensor's layout has changed, as you can see in the strides. Now, the number of jumps required to go across channels is only 1 (instead of 40000 in the contiguous tensor).
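
As a quick sanity check, these stride values follow directly from the shape, assuming the standard stride formulas for the two layouts:

```python
# For an image tensor of shape (N, C, H, W) = (1, 3, 200, 200):
N, C, H, W = 1, 3, 200, 200

contiguous_strides    = (C * H * W, H * W, W, 1)  # (120000, 40000, 200, 1)
channels_last_strides = (H * W * C, 1, W * C, C)  # (120000, 1, 600, 3)
```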

This better data locality means convolution layers can access all the channels for a given pixel much faster. Now let's see how the memory format affects runtime:

```python
from torchvision.models import resnet34, resnet50, resnet101

m = resnet34(pretrained=False)
# m = resnet50(pretrained=False)
# m = resnet101(pretrained=False)

def get_optimized_model(mm):
    mm = mm.eval()
    scripted = torch.jit.script(mm)
    optimized = optimize_for_mobile(scripted)  # explicitly call the xnnpack rewrite
    return scripted, optimized


def compare_contiguous_CL(mm):
    # inference on contiguous
    start = time.perf_counter()
    for i in range(20):
        mm(x)
    end = time.perf_counter()
    print("Contiguous: ", end-start)

    # inference on channels-last
    start = time.perf_counter()
    for i in range(20):
        mm(xcl)
    end = time.perf_counter()
    print("Channels-Last: ", end-start)

with torch.inference_mode():
    scripted, optimized = get_optimized_model(m)

    print("Runtimes for torchscripted model: ")
    compare_contiguous_CL(scripted.eval())
    print()
    print("Runtimes for mobile-optimized model: ")
    compare_contiguous_CL(optimized.eval())


## Outputs (on an Intel Core i9 CPU):

# Runtimes for torchscripted model:
# Contiguous: 1.6711160129999598
# Channels-Last: 1.6678222839999535

# Runtimes for mobile-optimized model:
# Contiguous: 0.5712863490000473
# Channels-Last: 0.46113000699995155
```

## Conclusion

The memory layout of an input tensor can significantly impact a model’s running time. For vision models, prefer a **Channels Last** memory format to get the most out of your PyTorch models.

## References

- [Row/Column Major matrix storage order](https://en.wikipedia.org/wiki/Row-_and_column-major_order)
- [Loop order impact on performance](https://stackoverflow.com/questions/9936132/why-does-the-order-of-the-loops-affect-performance-when-iterating-over-a-2d-arra)
- [Cachegrind: a cache-miss profiler](https://courses.cs.washington.edu/courses/cse326/05wi/valgrind-doc/cg_main.html)
- [NHWC format explained](https://oneapi-src.github.io/oneDNN/dev_guide_understanding_memory_formats.html)
- [Why does PyTorch prefer NCHW?](https://discuss.pytorch.org/t/why-does-pytorch-prefer-using-nchw/83637/4)
- [XNNPACK](https://github.com/google/XNNPACK)
- [PyTorch memory format tutorial](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
- [Supported operators](https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support)