Commit ccf48a6

[Pytorch][Onboarding][Autograd] Implement custom tanh attention operator
Following parts 1 and 2 of https://github.com/pytorch/pytorch/wiki/Autograd-Onboarding-Lab

NOTE: Do NOT merge this diff!

Learnings:
- When deriving the backward function analytically, it is easiest to break the forward function into steps and compute the gradient of each step by applying the chain rule.
- grad_a shows that we must account for both the local gradient and the upstream gradient contributions (see the sketch below).
- gradcheck and gradgradcheck are convenient ways to validate the analytical solution numerically.
- Generally, how to write a test and an operator.

Testing: Run `python3 test/test_autograd_lab.py`
1 parent 01f66d0 commit ccf48a6
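To make the grad_a learning concrete, here is a minimal sketch (not part of the commit; the tensor names are chosen to mirror the operator below) that checks the hand-derived sum of the upstream and local gradient contributions against plain autograd:

```python
import torch

# Stand-ins for the operator's intermediates: a is both an output and the
# input used to compute o, so the loss reaches a along two paths.
x = torch.randn(2, 2, dtype=torch.float64, requires_grad=True)
v = torch.randn(2, 3, dtype=torch.float64)

a = torch.tanh(x)
o = a @ v

# Upstream gradients for both outputs, as backward() would receive them.
grad_o = torch.ones_like(o)
grad_a = torch.ones_like(a)
(gx,) = torch.autograd.grad((o, a), x, grad_outputs=(grad_o, grad_a))

# Matches the hand-derived rule: (upstream + local) * tanh'(x).
expected = (grad_a + grad_o @ v.T) * (1 - a ** 2)
assert torch.allclose(gx, expected)
```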

File tree

2 files changed: +105 −0 lines


test/test_autograd_lab.py

Lines changed: 71 additions & 0 deletions
import torch
from torch.autograd.function import Function
from torch.testing._internal.common_utils import (
    gradcheck,
    gradgradcheck,
    TestCase,
    run_tests,
)

# mypy: ignore-errors


class Attention(Function):
    @staticmethod
    def forward(ctx, q, k, v):
        if q.dim() != 2 or k.dim() != 2 or v.dim() != 2:
            raise ValueError(
                f"Attention: Expected inputs to be 2D, got q = {q.dim()}D, k = {k.dim()}D, v = {v.dim()}D instead."
            )
        if q.size(0) != k.size(0) or q.size(0) != v.size(0):
            raise ValueError(
                f"Attention: Expected inputs to have the same first dimension, got q = {q.size(0)}, k = {k.size(0)}, v = {v.size(0)}."
            )
        if q.size(1) != k.size(1):
            raise ValueError(
                f"Attention: Expected q and k to have the same second dimension, got q = {q.size(1)}, k = {k.size(1)}."
            )

        # Forward pass: scores x, tanh attention weights a, output o.
        x = torch.matmul(q, k.transpose(0, 1))
        a = torch.tanh(x)
        o = torch.matmul(a, v)
        ctx.save_for_backward(q, k, v, a)
        return o, a

    @staticmethod
    def backward(ctx, grad_o, grad_a):
        q, k, v, a = ctx.saved_tensors
        grad_a_local = grad_o @ v.transpose(0, 1)
        grad_v = a.transpose(0, 1) @ grad_o
        # We have to add grad_a and grad_a_local together here because grad_a contains
        # the contributions from functions upstream that compute their own gradients
        # w.r.t. a (a is an output), while grad_a_local is the contribution that flows
        # through o within this function.
        grad_x = (grad_a + grad_a_local) * (1 - a ** 2)
        grad_q = grad_x @ k
        grad_k = grad_x.transpose(0, 1) @ q
        return grad_q, grad_k, grad_v


class TestAutogradLab(TestCase):
    def test_attention(self):
        q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
        k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
        v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True)

        gradcheck(Attention.apply, (q, k, v))
        gradgradcheck(Attention.apply, (q, k, v))

    def test_attention_mismatched_dims(self):
        test_cases = [
            ((3, 5), (4, 5), (3, 7)),  # q and k have different first dimensions
            ((3, 5), (3, 4), (3, 7)),  # q and k have different second dimensions
            ((3, 5), (3, 5), (4, 7)),  # q and v have different first dimensions
        ]
        for q_shape, k_shape, v_shape in test_cases:
            q = torch.randn(*q_shape, dtype=torch.float64, requires_grad=True)
            k = torch.randn(*k_shape, dtype=torch.float64, requires_grad=True)
            v = torch.randn(*v_shape, dtype=torch.float64, requires_grad=True)

            self.assertRaises(ValueError, Attention.apply, q, k, v)


if __name__ == "__main__":
    run_tests()
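As an illustration of how the two outputs feed the backward pass (a sketch assumed here, not included in the commit), Attention.apply can also be called directly; a loss that consumes both o and a makes the grad_a argument of backward non-trivial:

```python
import torch

q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True)

o, a = Attention.apply(q, k, v)

# A loss that touches both outputs exercises both grad_o and grad_a in backward().
loss = o.sum() + a.sum()
loss.backward()

print(q.grad.shape, k.grad.shape, v.grad.shape)
# torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 7])
```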

test/test_autograd_lab_derivation.md

Lines changed: 34 additions & 0 deletions
We have the following forward function declaration:

```
def forward(q, k, v):
    x = torch.matmul(q, k.transpose(0, 1))
    a = torch.tanh(x)
    o = torch.matmul(a, v)
    return o, a
```

Per https://docs.pytorch.org/docs/stable/notes/extending.html, the backward function has the following signature. Since the outputs are (o, a), we receive (grad_o, grad_a) as inputs; since the inputs are (q, k, v), we return (grad_q, grad_k, grad_v) as outputs.

```
def backward(grad_o, grad_a):
    ...
    return grad_q, grad_k, grad_v
```

I have derived the following gradients.
Note: Although an upstream grad_a is passed in as a parameter (because a is also an output), we still must compute the local contribution to dL/da that flows through o, and add the two before backpropagating into x.

```
o = a @ v
=> grad_a_local = grad_o @ v^T                     # dL/da (local contribution through o)
=> grad_v = a^T @ grad_o                           # dL/dv

a = tanh(x)
=> grad_x = (grad_a + grad_a_local) * (1 - a^2)    # dL/dx

x = q @ k^T
=> grad_q = grad_x @ k                             # dL/dq
=> grad_k = grad_x^T @ q                           # dL/dk
```
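As a quick numerical sanity check of the matmul rule above (a sketch, not part of the original file), grad_v = a^T @ grad_o can be verified with torch.autograd.grad:

```python
import torch

a = torch.randn(3, 4, dtype=torch.float64)
v = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
grad_o = torch.randn(3, 6, dtype=torch.float64)

o = a @ v
(gv,) = torch.autograd.grad(o, v, grad_outputs=grad_o)

# dL/dv for o = a @ v is a^T @ grad_o.
assert torch.allclose(gv, a.T @ grad_o)
```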
