author: Ashkan Yousefpour, Davide Testuggine, Alex Sablayrolles, and Ilya Mironov
featured-img: 'assets/images/image-opacus.png'
---
## Introduction
In our [previous blog post](https://medium.com/pytorch/differential-privacy-series-part-1-dp-sgd-algorithm-explained-12512c3959a3), we went over the basics of the DP-SGD algorithm and introduced [Opacus](https://opacus.ai), a PyTorch library for training ML models with differential privacy. In this blog post, we explain how performance-improving vectorized computation is done in Opacus and why Opacus can compute “per-sample gradients” a lot faster than “microbatching” (read on to see what all these terms mean!).
## Context
The most straightforward way to get per-sample gradients from PyTorch’s autograd is to process the samples in a batch one at a time:

```python
for batch in DataLoader(train_dataset, batch_size):
    for x, y in zip(*batch):  # process the samples in the batch one by one
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()  # Now p.grad for this x is filled

        # Need to clone it to save it
        per_sample_gradients = [p.grad.detach().clone() for p in model.parameters()]
        model.zero_grad()  # p.grad is cumulative so we'd better reset it
```
While the above procedure (called the “micro batch method”, or “micro batching”) gives us the per-sample gradients we need, it is prohibitively slow.

One of the features of Opacus is “vectorized computation”: it can compute per-sample gradients a lot faster than microbatching (the speedups depend on the model, but we observed anywhere from ~10x for small MNIST examples to ~50x for Transformers). Microbatching is simply not fast enough to run experiments and conduct research.

So, how do we do vectorized computation in Opacus? We derive the per-sample gradient formula and implement a vectorized version of it. We will get to this soon. Let us mention that there are other methods (like [this](https://arxiv.org/abs/1510.01799) and [this](https://arxiv.org/abs/2009.03106)) that rely on computing the norm of the per-sample gradients directly. Since these approaches are based on computing the norm of the per-sample gradients, they require two passes of back-propagation: one pass to obtain the norm, and one pass to use the norm as a weight (see the links above for details). Although they are considered efficient, in Opacus we set out to be even more efficient (!) and do everything in one back-propagation pass.

In this blog post, we focus on the approach for efficiently computing per-sample gradients that is based on deriving the per-sample gradient formula and implementing a vectorized version of it. To keep this blog post short, we focus on simple linear layers, the building blocks of multi-layer perceptrons (MLPs). In our next blog post, we will talk about how we extend this approach to other layers (e.g., convolutions and LSTMs) in Opacus.
## Efficient Per-Sample Gradient Computation for MLP
To understand the idea for efficiently computing per-sample gradients, let’s start by talking about how autograd works in commonly used deep learning frameworks. We’ll focus on PyTorch from now on, but to the best of our knowledge the same applies to other frameworks (with the exception of JAX).

For simplicity of explanation, we focus on one linear layer in a neural network, with weight matrix W. Also, we omit the bias from the forward pass equation: assume the forward pass is denoted by Y=WX, where X is the input and Y is the output of the linear layer. If we are processing a single sample, X is a vector. On the other hand, if we are processing a batch (and that’s what we do in Opacus), X is a matrix of size B*d, with B rows (B is the batch size), where each row is an input vector of dimension d. Similarly, the output matrix Y would be of size B*r, where each row is the output vector corresponding to an element in the batch and r is the output dimension.

The forward pass can be written as the following equation that captures the computation for each element in the matrix Y:

$$Y_i^{(b)} = \sum_j W_{i,j}\, X_j^{(b)}$$

We will return to this equation shortly. $Y_i^{(b)}$ denotes the element at row b (batch b) and column i (remember that the dimension of Y is B*r).
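In code, with one sample per row of X (the same layout nn.Linear uses), this batched forward pass is just a matrix product; here is a quick check of the element-wise formula, using toy tensor names of our own:

```python
import torch

B, d, r = 2, 4, 3
X = torch.randn(B, d)      # one sample per row
W = torch.randn(r, d)
Y = X @ W.T                # batched forward pass, shape (B, r)

# Y[b, i] equals sum_j W[i, j] * X[b, j]
print(torch.allclose(Y[1, 0], (W[0] * X[1]).sum()))  # True
```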
In any machine learning problem, we normally need the derivative of the loss with respect to the weights W. Similarly, in Opacus we need the “per-sample” version of that, meaning the per-sample derivative of the loss with respect to the weights W. Let’s first get the derivative of the loss with respect to the weights, and soon we will get to the per-sample part.

To obtain the derivative of the loss with respect to the weights, we use the chain rule, whose general form is:

$$\frac{\partial L}{\partial z} = \sum_{b} \sum_{i'} \frac{\partial L}{\partial Y_{i'}^{(b)}}\, \frac{\partial Y_{i'}^{(b)}}{\partial z}$$

Now, we can replace $z$ with $W_{i,j}$ and get

$$\frac{\partial L}{\partial W_{i,j}} = \sum_{b} \sum_{i'} \frac{\partial L}{\partial Y_{i'}^{(b)}}\, \frac{\partial Y_{i'}^{(b)}}{\partial W_{i,j}}$$

We know from the equation $Y=WX$ that $\partial Y_{i'}^{(b)} / \partial W_{i,j}$ is $X_j^{(b)}$ when $i'=i$, and is 0 otherwise. Hence, we will have

$$\frac{\partial L}{\partial W_{i,j}} = \sum_{b} \frac{\partial L}{\partial Y_{i}^{(b)}}\, X_j^{(b)} \qquad (*)$$

This equation corresponds to a matrix multiplication in PyTorch.
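As a quick sanity check (a toy example of ours, with a simple squared loss), the batch-summed gradient that autograd puts in `W.grad` matches exactly this matrix product of the output gradient and the input:

```python
import torch

B, d, r = 5, 4, 3
X = torch.randn(B, d)
W = torch.randn(r, d, requires_grad=True)

Y = X @ W.T                  # forward pass, shape (B, r)
loss = (Y ** 2).sum()        # any scalar loss works here
loss.backward()

grad_Y = (2 * Y).detach()    # dL/dY for this particular loss, shape (B, r)
manual = grad_Y.T @ X        # equation (*): the sum over the batch is inside the matmul
print(torch.allclose(W.grad, manual))  # True
```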
As we can see, the gradient of the loss with respect to the weight relies on the gradient of the loss with respect to the output Y. In regular backpropagation, the gradients of the loss with respect to the weights (or simply put, the “gradients”) are computed for each layer, but they are reduced (i.e., summed up over the batch). Since Opacus requires computing **per-sample gradients**, what we need is the following:

$$\frac{\partial L^{(b)}}{\partial W_{i,j}} = \frac{\partial L}{\partial Y_{i}^{(b)}}\, X_j^{(b)} \qquad (**)$$

where $L^{(b)}$ is the loss on the b-th sample. Note that the two equations are very similar; one equation has the sum over the batch and the other one does not. Let’s now focus on how we compute the per-sample gradient (equation **) in Opacus efficiently.
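Before that, a small check (again a toy example of ours) that equation (**), evaluated for the whole batch at once with a single einsum, matches backpropagating each sample separately:

```python
import torch

B, d, r = 5, 4, 3
X = torch.randn(B, d)
W = torch.randn(r, d, requires_grad=True)

Y = X @ W.T
loss = (Y ** 2).sum()
grad_Y = torch.autograd.grad(loss, Y)[0]            # dL/dY, shape (B, r)

# Equation (**): one outer product per sample, no sum over the batch
per_sample = torch.einsum("bi,bj->bij", grad_Y, X)  # shape (B, r, d)

# Reference: microbatching, one backward pass per sample
for b in range(B):
    loss_b = ((X[b] @ W.T) ** 2).sum()
    grad_b, = torch.autograd.grad(loss_b, W)
    assert torch.allclose(per_sample[b], grad_b)
```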
A bit of notation and terminology. Recall that we used the notation Y = WX for the forward pass of a single layer of a neural network. When the neural network has more layers, a better notation would be $Z^{(l+1)} = W^{(l+1)} Z^{(l)}$, where $l$ indexes the layers of the neural network. In that case, we can call the gradients with respect to any activations $Z^{(l)}$ the “highway gradients” and the gradients with respect to the weights the “exit gradients”.

If we go with this definition, explaining the issue with autograd is a one-liner: highway gradients retain per-sample information, but exit gradients do not. Put differently, highway gradients are per-sample, but exit gradients are not necessarily. This is unfortunate because the per-sample exit gradients are exactly what we need!
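As a tiny illustration (a toy example of ours, not Opacus code): the exit gradient stored in `weight.grad` has no batch dimension, while the highway gradient with respect to the layer output keeps one row per sample:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
x = torch.randn(8, 4)            # batch of 8 samples
y = layer(x)
y.retain_grad()                  # keep the highway gradient for the activation y
y.sum().backward()

print(layer.weight.grad.shape)   # torch.Size([3, 4]): exit gradient, summed over the batch
print(y.grad.shape)              # torch.Size([8, 3]): highway gradient, per-sample
```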
To capture the quantities we need, Opacus relies on PyTorch hooks. Under the hood, PyTorch is event-based and will call the hooks at the right place. There are two kinds of hooks:

1. **Tensor hook**. The signature for this is ```hook(grad) -> Tensor or None```
2. **nn.Module hook**. There are two types of these:
    1. **Forward hook**. The signature for this is ```hook(module, input, output) -> None```
    2. **Backward hook**. The signature for this is ```hook(module, grad_input, grad_output) -> Tensor or None```

To learn more about these fundamental primitives, check out our [official tutorial](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks) on hooks, or one of the excellent explainers, such as [Paperspace’s](https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/) or this [Kaggle notebook](https://www.kaggle.com/sironghuang/understanding-pytorch-hooks). Finally, if you want to play with hooks more interactively, we also made a [notebook](https://colab.research.google.com/drive/1zDidGCNI3DJk1oSPIpmB89b5cCWyuHao?usp=sharing) for you.

We use two hooks, one forward hook and one backward hook. In the forward hook, we simply store the activations.
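Below is a minimal sketch of both hooks for a single linear layer (our toy version, not the actual Opacus implementation, which also handles the bias, inputs with extra dimensions, and the surrounding bookkeeping); the backward hook combines the stored activations with the gradient of the output in a single einsum call:

```python
import torch
import torch.nn as nn

def forward_hook(module, inputs, output):
    # Store the input activations X for use during the backward pass
    module._activations = inputs[0].detach()

def backward_hook(module, grad_input, grad_output):
    A = module._activations        # shape (B, d)
    G = grad_output[0].detach()    # dL/dY, shape (B, r)
    # Equation (**): one outer product per sample
    module.weight.grad_sample = torch.einsum("bi,bj->bij", G, A)

layer = nn.Linear(4, 3)
layer.register_forward_hook(forward_hook)
layer.register_full_backward_hook(backward_hook)

# requires_grad on the input makes sure the full backward hook fires in this toy example
x = torch.randn(8, 4, requires_grad=True)
layer(x).sum().backward()

print(layer.weight.grad_sample.shape)  # torch.Size([8, 3, 4])
print(torch.allclose(layer.weight.grad_sample.sum(dim=0), layer.weight.grad))  # True
```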
You can find the full implementation for the linear module [here](https://github.com/pytorch/opacus/blob/204328947145d1759fcb26171368fcff6d652ef6/opacus/grad_sample/linear.py). The actual code has some bookkeeping around the einsum call, but the einsum call is the main building block of the efficient per-sample gradient computation.

Since this post is already long, we refer the interested reader to this introduction to [einsum in PyTorch](https://rockt.github.io/2018/04/30/einsum) rather than getting into the details of einsum here. However, we really encourage you to check it out, as it’s kind of a magical thing! Just as an example, a matrix multiplication can be described in a single einsum call.
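For instance (a toy snippet of ours), a plain matrix multiplication in einsum notation reads:

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)

# "ij,jk->ik": multiply along the shared index j and keep indices i and k
C = torch.einsum("ij,jk->ik", A, B)
print(torch.allclose(C, A @ B))  # True
```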