|
| 1 | +import torch |
| 2 | +from torch.autograd.function import Function |
| 3 | +from torch.testing._internal.common_utils import ( |
| 4 | + gradcheck, |
| 5 | + gradgradcheck, |
| 6 | + TestCase, |
| 7 | + run_tests, |
| 8 | +) |
| 9 | + |
| 10 | +# mypy: ignore-errors |
| 11 | + |
| 12 | +class Attention(Function): |
| 13 | + @staticmethod |
| 14 | + def forward(ctx, q, k, v): |
| 15 | + if q.dim() != 2 or k.dim() != 2 or v.dim() != 2: |
| 16 | + raise ValueError( |
| 17 | + f"Attention: Expected inputs to be 2D, got q = {q.dim()}D, k = {k.dim()}D, v = {v.dim()}D instead." |
| 18 | + ) |
| 19 | + if q.size(0) != k.size(0) or q.size(0) != v.size(0): |
| 20 | + raise ValueError( |
| 21 | + f"Attention: Expected inputs to have the same first dimension, got q = {q.size(0)}, k = {k.size(0)}, v = {v.size(0)}." |
| 22 | + ) |
| 23 | + |
| 24 | + if q.size(1) != k.size(1): |
| 25 | + raise ValueError( |
| 26 | + f"Attention: Expected q and k to have the same second dimension, got q = {q.size(1)}, k = {k.size(1)}." |
| 27 | + ) |
| 28 | + |
| 29 | + x = torch.matmul(q, k.transpose(0, 1)) |
| 30 | + a = torch.tanh(x) |
| 31 | + o = torch.matmul(a, v) |
| 32 | + ctx.save_for_backward(q, k, v, a) |
| 33 | + return o, a |
| 34 | + |
| 35 | + @staticmethod |
| 36 | + def backward(ctx, grad_o, grad_a): |
| 37 | + q, k, v, a = ctx.saved_tensors |
| 38 | + grad_a_local = grad_o @ v.transpose(0, 1) |
| 39 | + grad_v = a.transpose(0, 1) @ grad_o |
| 40 | + # We have to add grad_a and ga together here because grad_a contains contributions |
| 41 | + # from functions upstream which compute their own gradients w.r.t a, while grad_a_local |
| 42 | + # is the contribution from this function |
| 43 | + grad_x = (grad_a + grad_a_local) * (1 - a ** 2) |
| 44 | + grad_q = grad_x @ k |
| 45 | + grad_k = grad_x.transpose(0, 1) @ q |
| 46 | + return grad_q, grad_k, grad_v |
| 47 | + |
| 48 | +class TestAutogradLab(TestCase): |
| 49 | + def test_attention(self): |
| 50 | + q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True) |
| 51 | + k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True) |
| 52 | + v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True) |
| 53 | + |
| 54 | + gradcheck(Attention.apply, (q, k, v)) |
| 55 | + gradgradcheck(Attention.apply, (q, k, v)) |
| 56 | + |
| 57 | + def test_attention_mismatched_dims(self): |
| 58 | + test_cases = [ |
| 59 | + ((3, 5), (4, 5), (3, 7)), # q and k have different first dimensions |
| 60 | + ((3, 5), (3, 4), (3, 7)), # q and k have different second dimensions |
| 61 | + ((3, 5), (3, 5), (4, 7)), # q and v have different first dimensions |
| 62 | + ] |
| 63 | + for q_shape, k_shape, v_shape in test_cases: |
| 64 | + q = torch.randn(*q_shape, dtype=torch.float64, requires_grad=True) |
| 65 | + k = torch.randn(*k_shape, dtype=torch.float64, requires_grad=True) |
| 66 | + v = torch.randn(*v_shape, dtype=torch.float64, requires_grad=True) |
| 67 | + |
| 68 | + self.assertRaises(ValueError, Attention.apply, q, k, v) |
| 69 | + |
| 70 | +if __name__ == "__main__": |
| 71 | + run_tests() |
0 commit comments