
[DO NOT MERGE] Autograd Onboarding Lab #160264

Open · wants to merge 4 commits into main

[Pytorch][Onboarding][Autograd] Implement custom tanh attention operator
Following parts 1 and 2 of https://github.com/pytorch/pytorch/wiki/Autograd-Onboarding-Lab

NOTE: Do NOT merge this diff!

Learnings:
- When deriving the backward function analytically, it is easiest to break the forward function into its individual steps and apply the chain rule to each one
- The handling of grad_a shows that we must account for both the upstream gradient passed into backward and the local gradient contribution, since a is also a returned output
- gradcheck and gradgradcheck validate the analytical solution against numerical finite-difference estimates (see the sketch after this list)
- More generally, how to write a custom operator and its tests
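
For the gradcheck point above, here is a minimal, hypothetical sketch of the idea gradcheck is built on (not its actual implementation): perturb each input entry, estimate the slope with central differences, and compare it against the analytical gradient.

```
import torch

def numerical_grad(fn, x, eps=1e-6):
    # Central finite differences for a scalar-valued fn of one tensor
    # (illustrative only; the real gradcheck is far more thorough).
    grad = torch.zeros_like(x)
    flat_x, flat_g = x.view(-1), grad.view(-1)
    for i in range(flat_x.numel()):
        orig = flat_x[i].item()
        flat_x[i] = orig + eps
        f_plus = fn(x).item()
        flat_x[i] = orig - eps
        f_minus = fn(x).item()
        flat_x[i] = orig
        flat_g[i] = (f_plus - f_minus) / (2 * eps)
    return grad

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
fn = lambda t: torch.tanh(t).sum()

numerical = numerical_grad(fn, x.detach().clone())
analytical, = torch.autograd.grad(fn(x), x)
assert torch.allclose(numerical, analytical, atol=1e-7)
```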

Testing:
Run `python3 test/test_autograd_lab.py`
justinHe123 committed Aug 9, 2025
commit ccf48a6faaa8c9bb0c68d38b8ba4cf434c67530a
71 changes: 71 additions & 0 deletions test/test_autograd_lab.py
@@ -0,0 +1,71 @@
import torch
from torch.autograd.function import Function
from torch.testing._internal.common_utils import (
gradcheck,
gradgradcheck,
TestCase,
run_tests,
)

# mypy: ignore-errors

class Attention(Function):
@staticmethod
def forward(ctx, q, k, v):
if q.dim() != 2 or k.dim() != 2 or v.dim() != 2:
raise ValueError(
f"Attention: Expected inputs to be 2D, got q = {q.dim()}D, k = {k.dim()}D, v = {v.dim()}D instead."
)
if q.size(0) != k.size(0) or q.size(0) != v.size(0):
raise ValueError(
f"Attention: Expected inputs to have the same first dimension, got q = {q.size(0)}, k = {k.size(0)}, v = {v.size(0)}."
)

if q.size(1) != k.size(1):
raise ValueError(
f"Attention: Expected q and k to have the same second dimension, got q = {q.size(1)}, k = {k.size(1)}."
)

x = torch.matmul(q, k.transpose(0, 1))
a = torch.tanh(x)
o = torch.matmul(a, v)
ctx.save_for_backward(q, k, v, a)
return o, a

@staticmethod
def backward(ctx, grad_o, grad_a):
q, k, v, a = ctx.saved_tensors
grad_a_local = grad_o @ v.transpose(0, 1)
grad_v = a.transpose(0, 1) @ grad_o
        # We have to add grad_a and grad_a_local together here: grad_a carries the
        # contributions from upstream functions that consume the output a, while
        # grad_a_local is the contribution from o = a @ v inside this function.
grad_x = (grad_a + grad_a_local) * (1 - a ** 2)
grad_q = grad_x @ k
grad_k = grad_x.transpose(0, 1) @ q
return grad_q, grad_k, grad_v

class TestAutogradLab(TestCase):
def test_attention(self):
q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True)

gradcheck(Attention.apply, (q, k, v))
gradgradcheck(Attention.apply, (q, k, v))

def test_attention_mismatched_dims(self):
test_cases = [
((3, 5), (4, 5), (3, 7)), # q and k have different first dimensions
((3, 5), (3, 4), (3, 7)), # q and k have different second dimensions
((3, 5), (3, 5), (4, 7)), # q and v have different first dimensions
]
for q_shape, k_shape, v_shape in test_cases:
q = torch.randn(*q_shape, dtype=torch.float64, requires_grad=True)
k = torch.randn(*k_shape, dtype=torch.float64, requires_grad=True)
v = torch.randn(*v_shape, dtype=torch.float64, requires_grad=True)

self.assertRaises(ValueError, Attention.apply, q, k, v)

if __name__ == "__main__":
run_tests()
34 changes: 34 additions & 0 deletions test/test_autograd_lab_derivation.md
@@ -0,0 +1,34 @@
We have the following forward function declaration:

```
def forward(q, k, v):
x = torch.matmul(q, k.transpose(0, 1))
a = torch.tanh(x)
o = torch.matmul(a, v)
return o, a
```
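
For context on shapes (matching the test inputs, where q and k are n x d and v is n x e):

```
q: (n, d)    k: (n, d)    v: (n, e)
x = q @ k^T   -> (n, n)
a = tanh(x)   -> (n, n)
o = a @ v     -> (n, e)
```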

Per https://docs.pytorch.org/docs/stable/notes/extending.html, the backward function will have the following signature. Since the forward outputs are (o, a), backward receives (grad_o, grad_a) as inputs; since the forward inputs are (q, k, v), backward returns (grad_q, grad_k, grad_v).

```
def backward(grad_o, grad_a):
...
return grad_q, grad_k, grad_v
```
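
As a usage sketch of why backward receives two incoming gradients here (this assumes the Attention class from test/test_autograd_lab.py is importable; the import path is hypothetical): when the downstream loss depends on both outputs, autograd supplies a gradient for each of them.

```
import torch
from test_autograd_lab import Attention  # hypothetical import path; adjust to how the test file is loaded

q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True)

o, a = Attention.apply(q, k, v)
loss = o.sum() + a.sum()  # the loss touches both outputs
loss.backward()           # backward(ctx, grad_o, grad_a) receives two upstream gradients
```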

I have derived the following gradients.
Note: although grad_a is passed in as a parameter (the upstream gradient with respect to the output a), we still must compute the local contribution to dL/da from o = a @ v and add the two together.

```
o = a @ v
=> grad_a_local = grad_o @ v^T                     # dL/da (local contribution)
=> grad_v = a^T @ grad_o                           # dL/dv
a = tanh(x)
=> grad_x = (grad_a + grad_a_local) * (1 - a^2)    # dL/dx (elementwise)
x = q @ k^T
=> grad_q = grad_x @ k                             # dL/dq
=> grad_k = grad_x^T @ q                           # dL/dk
```
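
For completeness, the tanh step uses the identity d tanh(x)/dx = 1 - tanh(x)^2 = 1 - a^2, and the matrix forms above follow from the element-wise chain rule. As one worked example, for the step x = q @ k^T:

$$
x_{ij} = \sum_m q_{im} k_{jm}
\quad\Rightarrow\quad
\frac{\partial L}{\partial q_{im}} = \sum_j \frac{\partial L}{\partial x_{ij}} \, k_{jm},
\quad\text{i.e. } \mathrm{grad\_q} = \mathrm{grad\_x} \, k .
$$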