Commit 7e6dc92

Merge pull request #833 from shiftlab/update-opacus_blog2

Update 9/15 blog with latex

2 parents c60b1fa + 2c239b3

1 file changed (+24, -25 lines)

_posts/2021-9-15-efficient-per-sample-gradient-computation-opacus.md

@@ -5,9 +5,9 @@ author: Ashkan Yousefpour, Davide Testuggine, Alex Sablayrolles, and Ilya Mirono
featured-img: 'assets/images/image-opacus.png'
---

## Introduction

In our [previous blog post](https://medium.com/pytorch/differential-privacy-series-part-1-dp-sgd-algorithm-explained-12512c3959a3), we went over the basics of the DP-SGD algorithm and introduced [Opacus](https://opacus.ai), a PyTorch library for training ML models with differential privacy. In this blog post, we explain how performance-improving vectorized computation is done in Opacus and why Opacus can compute “per-sample gradients” a lot faster than “microbatching” (read on to see what all these terms mean!).

## Context

@@ -24,10 +24,10 @@ for batch in Dataloader(train_dataset, batch_size):
    y_hat = model(x)
    loss = criterion(y_hat, y)
    loss.backward()  # Now p.grad for this x is filled

    # Need to clone it to save it
    per_sample_gradients = [p.grad.detach().clone() for p in model.parameters()]

    all_per_sample_gradients.append(per_sample_gradients)
    model.zero_grad()  # p.grad is cumulative so we'd better reset it
```
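
For reference, here is a minimal, self-contained sketch of the microbatching idea, assuming a toy linear model and synthetic data (all names and dimensions below are illustrative, not the post's exact code):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy setup, for illustration only: a single linear layer and random data.
model = nn.Linear(in_features=5, out_features=3)
criterion = nn.CrossEntropyLoss()
train_dataset = TensorDataset(torch.randn(32, 5), torch.randint(0, 3, (32,)))

for batch_x, batch_y in DataLoader(train_dataset, batch_size=8):
    all_per_sample_gradients = []  # one entry per sample in the batch
    for x, y in zip(batch_x, batch_y):
        y_hat = model(x)
        loss = criterion(y_hat.unsqueeze(0), y.unsqueeze(0))
        loss.backward()  # p.grad now holds the gradient for this single sample

        # Clone the per-sample gradients before they are reset
        per_sample_gradients = [p.grad.detach().clone() for p in model.parameters()]
        all_per_sample_gradients.append(per_sample_gradients)
        model.zero_grad()  # p.grad accumulates, so reset it before the next sample
```
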
@@ -38,55 +38,54 @@ While the above procedure (called the “micro batch method”, or “micro batc

One of the features of Opacus is “vectorized computation”: it can compute per-sample gradients a lot faster than microbatching (the exact speedup depends on the model, but we observed speedups from ~10x for small MNIST examples to ~50x for Transformers). Microbatching is simply not fast enough to run experiments and conduct research.

So, how do we do vectorized computation in Opacus? We derive the per-sample gradient formula and implement a vectorized version of it. We will get to this soon. Let us mention that there are other methods (like [this](https://arxiv.org/abs/1510.01799) and [this](https://arxiv.org/abs/2009.03106)) that rely on computing the norm of the per-sample gradients directly. Since these approaches are based on computing the norm of the per-sample gradients, they do two passes of back-propagation: one pass for obtaining the norm, and one pass for using the norm as a weight (see the links above for details). Although they are considered efficient, in Opacus we set out to be even more efficient (!) and do everything in one back-propagation pass.

In this blog post, we focus on the approach that derives the per-sample gradient formula and implements a vectorized version of it. To keep this post short, we focus on simple linear layers, the building blocks of multi-layer perceptrons (MLPs). In our next blog post, we will talk about how we extend this approach to other layers (e.g., convolutions or LSTMs) in Opacus.

## Efficient Per-Sample Gradient Computation for MLP

To understand the idea for efficiently computing per-sample gradients, let’s start by talking about how AutoGrad works in commonly used deep learning frameworks. We’ll focus on PyTorch from now on, but to the best of our knowledge the same applies to other frameworks (with the exception of Jax).

For simplicity of explanation, we focus on one linear layer in a neural network, with weight matrix W. Also, we omit the bias from the forward pass equation: assume the forward pass is denoted by Y = WX, where X is the input and Y is the output of the linear layer. If we are processing a single sample, X is a vector. On the other hand, if we are processing a batch (and that’s what we do in Opacus), X is a matrix of size B×d, with B rows (B is the batch size), where each row is an input vector of dimension d. Similarly, the output matrix Y would be of size B×r, where each row is the output vector corresponding to an element in the batch and r is the output dimension.

The forward pass can be written as the following equation that captures the computation for each element in the matrix Y:

$$Y_i^{(b)} = \sum_{j=1}^{d} W_{i,j} X_j^{(b)}$$

We will return to this equation shortly. $Y_i^{(b)}$ denotes the element at row b (batch b) and column i (remember that the dimension of Y is B×r).
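
To make the shapes concrete, here is a small sketch with illustrative dimensions (not from the post) checking that the batched forward pass produces a B×r output and matches the element-wise formula above:

```python
import torch

B, d, r = 4, 5, 3          # batch size, input dimension, output dimension (illustrative)
X = torch.randn(B, d)      # each row is one input sample
W = torch.randn(r, d)      # weight matrix of the linear layer

Y = X @ W.T                # batched forward pass, shape (B, r)

# Element-wise version of the same computation: Y[b, i] = sum_j W[i, j] * X[b, j]
Y_manual = torch.einsum("ij,bj->bi", W, X)

assert Y.shape == (B, r)
assert torch.allclose(Y, Y_manual, atol=1e-6)
```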

In any machine learning problem, we normally need the derivative of the loss with respect to weights W. Comparably, in Opacus we need the “per-sample” version of that, meaning, the per-sample derivative of the loss with respect to weights W. Let’s first get the derivative of the loss with respect to weights, and soon, we will get to the per-sample part.

To obtain the derivative of the loss with respect to weights, we use the chain rule, whose general form is:

$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial Y} \cdot \frac{\partial Y}{\partial z},$$

which can be written as

$$\frac{\partial L}{\partial z} = \sum_{b=1}^{B} \sum_{i'=1}^{r} \frac{\partial L}{\partial Y_{i'}^{(b)}} \frac{\partial Y_{i'}^{(b)}}{\partial z}.$$

Now, we can replace $z$ with $W_{i,j}$ and get

$$\frac{\partial L}{\partial W_{i,j}} = \sum_{b=1}^{B} \sum_{i'=1}^{r} \frac{\partial L}{\partial Y_{i'}^{(b)}} \frac{\partial Y_{i'}^{(b)}}{\partial W_{i,j}}.$$

We know from the equation $Y = WX$ that $\frac{\partial Y_{i'}^{(b)}}{\partial W_{i,j}}$ is $X_j^{(b)}$ when $i' = i$, and is 0 otherwise. Hence, we will have

$$\frac{\partial L}{\partial W_{i,j}} = \sum_{b=1}^{B} \frac{\partial L}{\partial Y_i^{(b)}} X_j^{(b)} \qquad (*)$$

This equation corresponds to a matrix multiplication in PyTorch.
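
As a sanity check of equation (*), here is a short sketch (illustrative, not the post's code) comparing autograd's accumulated gradient with the explicit product of the output gradient and the input:

```python
import torch
from torch import nn

B, d, r = 4, 5, 3
X = torch.randn(B, d)
layer = nn.Linear(d, r, bias=False)

loss = layer(X).sum()        # any scalar loss works for this check
loss.backward()              # layer.weight.grad is the batch-summed gradient

grad_Y = torch.ones(B, r)    # dL/dY for loss = Y.sum()
grad_W_manual = grad_Y.T @ X # equation (*): summed over the batch, shape (r, d)

assert torch.allclose(layer.weight.grad, grad_W_manual, atol=1e-6)
```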

As we can see, the gradient of loss with respect to the weight relies on the gradient of loss with respect to the output Y. In a regular backpropagation, the gradients of loss with respect to the weights (or simply put, the “gradients”) are computed for the output of each layer, but they are reduced (i.e., summed up over the batch). Since Opacus requires computing **per-sample gradients**, what we need is the following:

$$\frac{\partial L_{\mathrm{batch}(b)}}{\partial W_{i,j}} = \frac{\partial L}{\partial Y_i^{(b)}} X_j^{(b)} \qquad (**)$$

Note that the two equations are very similar; one equation has the sum over the batch and the other one does not. Let’s now focus on how we compute the per-sample gradient (equation ** ) in Opacus efficiently.
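
And here is a hedged sketch of how the per-sample version in equation (**) can be computed in one vectorized step with `torch.einsum` (again illustrative; the actual Opacus implementation is linked further below):

```python
import torch
from torch import nn

B, d, r = 4, 5, 3
X = torch.randn(B, d)
layer = nn.Linear(d, r, bias=False)

loss = layer(X).sum()
loss.backward()                     # regular, batch-summed gradient in layer.weight.grad

grad_Y = torch.ones(B, r)           # dL/dY for this loss, one row per sample
# Equation (**): one (r, d) gradient per sample, with no sum over the batch
per_sample_grads = torch.einsum("bi,bj->bij", grad_Y, X)   # shape (B, r, d)

# Summing the per-sample gradients over the batch recovers the regular gradient
assert torch.allclose(per_sample_grads.sum(dim=0), layer.weight.grad, atol=1e-6)
```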

<p align="center">
<img src="{{ site.baseurl }}/assets/images/image-opacus.png" width="560">
<br>
Figure 6. The partition boundary is in the middle of a skip connection
</p>

A bit of notation and terminology. Recall that we used the notation Y = WX for the forward pass of a single layer of a neural network. When the neural network has more layers, a better notation would be $Z^{(l+1)} = W^{(l+1)} Z^{(l)}$, where $l$ corresponds to each layer of the neural network. In that case, we can call the gradients with respect to any activations $Z^{(l)}$ the “highway gradients” and the gradients with respect to the weights the “exit gradients”.

If we go with this definition, explaining the issue with Autograd is a one-liner: highway gradients retain per-sample information, but exit gradients do not. Or, highway gradients are per-sample, but exit gradients are not necessarily. This is unfortunate because the per-sample exit gradients are exactly what we need!

@@ -104,7 +103,7 @@ Under the hood, PyTorch is event-based and will call the hooks at the right plac
2. **nn.Module hook**. There are two types of these:
    1. **Forward hook**. The signature for this is ```hook(module, input, output) -> None```
    2. **Backward hook**. The signature for this is ```hook(module, grad_input, grad_output) -> Tensor or None```

To learn more about these fundamental primitives, check out our [official tutorial](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks) on hooks, or one of the excellent explainers, such as [Paperspace’s](https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/) or this [Kaggle notebook](https://www.kaggle.com/sironghuang/understanding-pytorch-hooks). Finally, if you want to play with hooks more interactively, we also made a [notebook](https://colab.research.google.com/drive/1zDidGCNI3DJk1oSPIpmB89b5cCWyuHao?usp=sharing) for you.
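
As a toy illustration of the two `nn.Module` hook signatures above (this is only a demo, not Opacus's actual hooks), consider:

```python
import torch
from torch import nn

layer = nn.Linear(5, 3)

def forward_hook(module, input, output):
    # `input` is a tuple of the tensors passed to forward()
    print("forward:", input[0].shape, "->", output.shape)

def backward_hook(module, grad_input, grad_output):
    # `grad_output` holds the gradient of the loss w.r.t. this module's output
    print("backward:", grad_output[0].shape)

layer.register_forward_hook(forward_hook)
layer.register_full_backward_hook(backward_hook)  # modern replacement for register_backward_hook

out = layer(torch.randn(4, 5))
out.sum().backward()  # triggers the backward hook
```
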
We use two hooks, one forward hook and one backward hook. In the forward hook, we simply store the activations:
@@ -131,9 +130,9 @@ def compute_linear_grad_sample(input, grad_output):

You can find the full implementation for the linear module [here](https://github.com/pytorch/opacus/blob/204328947145d1759fcb26171368fcff6d652ef6/opacus/grad_sample/linear.py). The actual code has some bookkeeping around the einsum call, but the einsum call is the main building block of the efficient per-sample computation for us.

Since this post is already long, we refer the interested reader to [einsum in PyTorch](https://rockt.github.io/2018/04/30/einsum) and do not get into the details of einsum here. However, we really encourage you to check it out, as it’s kind of a magical thing! Just as an example, a matrix multiplication described by

$$C_{ij} = \sum_{k} A_{ik} B_{kj}$$

can be implemented beautifully in this line:
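
A sketch of such a one-liner with `torch.einsum` (illustrative; the post's exact line is not shown in this excerpt):

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)

# C[i, j] = sum_k A[i, k] * B[k, j]
C = torch.einsum("ik,kj->ij", A, B)

assert torch.allclose(C, A @ B, atol=1e-6)
```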