diff --git a/aten/src/ATen/native/Attention.cpp b/aten/src/ATen/native/Attention.cpp
new file mode 100644
index 000000000000..a203a74d3f46
--- /dev/null
+++ b/aten/src/ATen/native/Attention.cpp
@@ -0,0 +1,36 @@
+// This include is needed for the core Tensor class
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/core/Tensor.h>
+#include <ATen/native/Attention.h>
+
+// The includes below are required to call functions we want within the at namespace
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/matmul.h>
+#include <ATen/ops/tanh.h>
+#endif
+
+namespace at::native {
+  std::tuple<at::Tensor, at::Tensor> attention(const at::Tensor & query,
+                                               const at::Tensor & key,
+                                               const at::Tensor & value
+  ) {
+    TORCH_CHECK(query.dim() == 2 && key.dim() == 2 && value.dim() == 2,
+        "Expected input tensors to be 2D, but got query: ", query.dim(),
+        ", key: ", key.dim(),
+        ", value: ", value.dim());
+    TORCH_CHECK(query.sym_size(0) == key.sym_size(0) && query.sym_size(0) == value.sym_size(0),
+        "Expected input tensors to have the same first dimension, but got query: ",
+        query.sym_size(0), ", key: ", key.sym_size(0), ", value: ",
+        value.sym_size(0));
+    TORCH_CHECK(query.sym_size(1) == key.sym_size(1),
+        "Expected query and key to have the same second dimension, but got query: ",
+        query.sym_size(1), ", key: ", key.sym_size(1));
+    auto a = at::tanh(at::matmul(query, key.transpose(-2, -1)));
+    auto o = at::matmul(a, value);
+    return std::make_tuple(o, a);
+  }
+} // namespace at::native
+
diff --git a/aten/src/ATen/native/Attention.h b/aten/src/ATen/native/Attention.h
new file mode 100644
index 000000000000..1b21182d43de
--- /dev/null
+++ b/aten/src/ATen/native/Attention.h
@@ -0,0 +1,8 @@
+#include <ATen/core/Tensor.h>
+
+namespace at::native {
+  std::tuple<at::Tensor, at::Tensor> attention(const at::Tensor & query,
+                                               const at::Tensor & key,
+                                               const at::Tensor & value
+  );
+} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e7492f4c379a..df415638cc7e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1053,6 +1053,11 @@
 - func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
   variants: function

+- func: attention(Tensor query, Tensor key, Tensor value) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: attention
+
 - func: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   variants: function, method
   structured_delegate: baddbmm.out
diff --git a/test/test_autograd_lab.py b/test/test_autograd_lab.py
new file mode 100644
index 000000000000..02338864aee7
--- /dev/null
+++ b/test/test_autograd_lab.py
@@ -0,0 +1,104 @@
+import torch
+from torch.autograd.function import Function
+from torch.testing._internal.common_utils import (
+    gradcheck,
+    gradgradcheck,
+    TestCase,
+    run_tests,
+)
+
+# mypy: ignore-errors
+
+class Attention(Function):
+    @staticmethod
+    def forward(ctx, q, k, v):
+        if q.dim() != 2 or k.dim() != 2 or v.dim() != 2:
+            raise ValueError(
+                f"Attention: Expected inputs to be 2D, got q = {q.dim()}D, k = {k.dim()}D, v = {v.dim()}D instead."
+            )
+        if q.size(0) != k.size(0) or q.size(0) != v.size(0):
+            raise ValueError(
+                f"Attention: Expected inputs to have the same first dimension, got q = {q.size(0)}, k = {k.size(0)}, v = {v.size(0)}."
+            )
+
+        if q.size(1) != k.size(1):
+            raise ValueError(
+                f"Attention: Expected q and k to have the same second dimension, got q = {q.size(1)}, k = {k.size(1)}."
+            )
+
+        x = torch.matmul(q, k.transpose(0, 1))
+        a = torch.tanh(x)
+        o = torch.matmul(a, v)
+        ctx.save_for_backward(q, k, v, a)
+        ctx.set_materialize_grads(False)
+        return o, a
+
+    @staticmethod
+    def backward(ctx, grad_o, grad_a):
+        # If both gradients are None, return early
+        if grad_o is None and grad_a is None:
+            return None, None, None
+
+        q, k, v, a = ctx.saved_tensors
+        # We have to add grad_a and the local term (grad_o @ v^T) together here because grad_a
+        # carries contributions from upstream functions that consume a directly, while the
+        # local term is this function's own contribution through o
+        grad_a_local = None
+        if grad_a is not None:
+            grad_a_local = grad_a
+
+        grad_v = None
+        if grad_o is not None:
+            term = grad_o @ v.transpose(0, 1)
+            grad_a_local = grad_a_local + term if grad_a_local is not None else term
+            grad_v = a.transpose(0, 1) @ grad_o
+
+        grad_x = grad_a_local * (1 - a ** 2)
+        grad_q = grad_x @ k
+        grad_k = grad_x.transpose(0, 1) @ q
+        return grad_q, grad_k, grad_v
+
+class TestAutogradLab(TestCase):
+    def test_attention(self):
+        q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
+        k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
+        v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True)
+
+        gradcheck(Attention.apply, (q, k, v))
+        gradgradcheck(Attention.apply, (q, k, v))
+
+    def test_attention_mismatched_dims(self):
+        test_cases = [
+            ((3, 5), (4, 5), (3, 7)),  # q and k have different first dimensions
+            ((3, 5), (3, 4), (3, 7)),  # q and k have different second dimensions
+            ((3, 5), (3, 5), (4, 7)),  # q and v have different first dimensions
+        ]
+        for q_shape, k_shape, v_shape in test_cases:
+            q = torch.randn(*q_shape, dtype=torch.float64, requires_grad=True)
+            k = torch.randn(*k_shape, dtype=torch.float64, requires_grad=True)
+            v = torch.randn(*v_shape, dtype=torch.float64, requires_grad=True)
+
+            self.assertRaises(ValueError, Attention.apply, q, k, v)
+
+    def test_attention_native(self):
+        q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
+        k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
+        v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True)
+
+        gradcheck(torch.ops.aten.attention, (q, k, v))
+        gradgradcheck(torch.ops.aten.attention, (q, k, v))
+
+    def test_attention_native_mismatched_dims(self):
+        test_cases = [
+            ((3, 5), (4, 5), (3, 7)),  # q and k have different first dimensions
+            ((3, 5), (3, 4), (3, 7)),  # q and k have different second dimensions
+            ((3, 5), (3, 5), (4, 7)),  # q and v have different first dimensions
+        ]
+        for q_shape, k_shape, v_shape in test_cases:
+            q = torch.randn(*q_shape, dtype=torch.float64, requires_grad=True)
+            k = torch.randn(*k_shape, dtype=torch.float64, requires_grad=True)
+            v = torch.randn(*v_shape, dtype=torch.float64, requires_grad=True)
+
+            self.assertRaises(RuntimeError, torch.ops.aten.attention, q, k, v)
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_autograd_lab_derivation.md b/test/test_autograd_lab_derivation.md
new file mode 100644
index 000000000000..54511a37e641
--- /dev/null
+++ b/test/test_autograd_lab_derivation.md
@@ -0,0 +1,34 @@
+We have the following forward function definition:
+
+```
+def forward(q, k, v):
+    x = torch.matmul(q, k.transpose(0, 1))
+    a = torch.tanh(x)
+    o = torch.matmul(a, v)
+    return o, a
+```
+
+Per https://docs.pytorch.org/docs/stable/notes/extending.html, we will have the following backward function definition. Since the outputs are (o, a), we should receive (grad_o, grad_a) as inputs. Since the inputs are (q, k, v), we should have (grad_q, grad_k, grad_v) as outputs.
+
+```
+def backward(grad_o, grad_a):
+    ...
+    return grad_q, grad_k, grad_v
+```
+
+I have derived the following gradients:
+Note: Although grad_a is passed in as a parameter, we still must compute a local grad_a term from o = a @ v and add it to the incoming grad_a to get the full gradient contribution.
+
+```
+o = av
+  => grad_a = grad_o @ v^T   # dL/da
+  => grad_v = a^T @ grad_o   # dL/dv
+
+a = tanh(x)
+  => grad_x = grad_a * (1 - a^2)   # dL/dx
+
+x = qk^T
+  => grad_q = grad_x @ k   # dL/dq
+  => grad_k = grad_x^T @ q   # dL/dk
+```
+
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index c050c6cbdc4c..dc48532d45d2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -336,6 +336,10 @@
 - name: atanh_(Tensor(a!) self) -> Tensor(a!)
   self: not_implemented("inplace version of atanh")

+- name: attention(Tensor query, Tensor key, Tensor value) -> (Tensor, Tensor)
+  output_differentiability: [True, True]
+  query, key, value: attention_backward(grads[0], grads[1], result1, query, key, value)
+
 - name: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
   self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
   result: auto_linear
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 8e13d4267edb..840c9c08a4bf 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -7464,4 +7464,36 @@ Tensor values_backward(const Tensor& grad, const Tensor& self) {
   return grad_self;
 }

+std::tuple<at::Tensor, at::Tensor, at::Tensor> attention_backward(
+    const at::Tensor & grad_o,
+    const at::Tensor & grad_a,
+    const at::Tensor & result_a,
+    const at::Tensor & query,
+    const at::Tensor & key,
+    const at::Tensor & value
+  ) {
+  Tensor grad_query, grad_key, grad_value;
+  // Return undefined tensors if neither grad_o nor grad_a is defined, since we cannot compute gradients.
+  if (!(grad_o.defined() || grad_a.defined())) {
+    return std::make_tuple(grad_query, grad_key, grad_value);
+  }
+
+  Tensor grad_a_local;
+
+  if (grad_a.defined()) {
+    grad_a_local = grad_a.clone();
+  }
+
+  if (grad_o.defined()) {
+    auto term = grad_o.mm(value.t());
+    grad_a_local = grad_a_local.defined() ? grad_a_local + term : term;
+    grad_value = result_a.t().mm(grad_o);
+  }
+  // Assume grad_a_local is now defined, since at least one of grad_o or grad_a was defined
+  auto grad_x = grad_a_local * (1 - result_a.pow(2));
+
+  grad_query = grad_x.mm(key);
+  grad_key = grad_x.t().mm(query);
+  return std::make_tuple(grad_query, grad_key, grad_value);
+}
 } // namespace torch::autograd::generated::details
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 96864e165a95..e9bf70d2b7bf 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -1149,4 +1149,13 @@ mkldnn_rnn_layer_differentiable_backward(

 Tensor values_backward(const Tensor& grad, const Tensor& self);

+std::tuple<at::Tensor, at::Tensor, at::Tensor> attention_backward(
+    const at::Tensor & grad_o,
+    const at::Tensor & grad_a,
+    const at::Tensor & result_a,
+    const at::Tensor & query,
+    const at::Tensor & key,
+    const at::Tensor & value
+    );
+
 } // namespace torch::autograd::generated::details
diff --git a/torch/overrides.py b/torch/overrides.py
index fe7af6bc4ff0..859bae1615b5 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -475,6 +475,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
         torch.atleast_1d: lambda *tensors: -1,
         torch.atleast_2d: lambda *tensors: -1,
         torch.atleast_3d: lambda *tensors: -1,
+        torch.attention: lambda query, key, value: -1,
         torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
         torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
         torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 41bb2b96bd93..0a66d6fe5a21 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7621,6 +7621,28 @@ def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs)
         yield SampleInput(make_tensor_partial(shape))
     yield SampleInput([make_tensor_partial(shape) for shape in shapes])

+def sample_inputs_attention(op_info, device, dtype, requires_grad, **kwargs):
+    shapes = (
+        ((S, 5), (S, 5), (S, 7)),
+        ((S, 3), (S, 3), (S, 4)),
+    )
+    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape_q, shape_k, shape_v in shapes:
+        yield SampleInput(make_tensor_partial(shape_q), make_tensor_partial(shape_k), make_tensor_partial(shape_v))
+
+def error_inputs_attention(op_info, device, **kwargs):
+    shapes = (
+        ((S, 5), (S + 1, 5), (S, 7)),  # q, k different first dims
+        ((S, 5), (S, 5), (S + 1, 7)),  # q, v different first dims
+        ((S, 5), (S, 6), (S, 7)),  # q, k different second dims
+    )
+    make_tensor_partial = partial(make_tensor, dtype=torch.float32, device=device)
+    for shape_q, shape_k, shape_v in shapes:
+        yield ErrorInput(
+            SampleInput(make_tensor_partial(shape_q), make_tensor_partial(shape_k), make_tensor_partial(shape_v)),
+            error_regex='Expected'
+        )
+
 def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs):
     cases: tuple[tuple, tuple] = (  # type: ignore[assignment]
         ((S, 2, 1), (S, 3, 1)),
@@ -18470,6 +18492,24 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            ),
            sample_inputs_func=sample_inputs_atleast1d2d3d,
            ),
+    OpInfo('attention',
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_attention,
+           error_inputs_func=error_inputs_attention,
+           supports_autograd=True,
+           supports_out=False,
+           supports_forward_ad=False,
+           supports_fwgrad_bwgrad=False,
+           skips=(
+               # Seems like this is getting demoted to torch.bfloat16 for some reason, skipping for now
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', dtypes=[torch.float32]),
+               # Errors with forward AD not implemented
+               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_jvpvjp', dtypes=[torch.float32]),
+               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpvjp', dtypes=[torch.float32]),
+               # Errors with "hit the vmap fallback which is currently disabled"
+               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule', dtypes=[torch.float32]),
+               DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', dtypes=[torch.float32]),
+           )),
     OpInfo('flatten',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
            ref=reference_flatten,
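
As a quick cross-check of the derivation in test_autograd_lab_derivation.md, the sketch below (not part of the diff) replays the same forward math with differentiable primitives and compares torch.autograd.grad against the hand-derived gradients used in Attention.backward and attention_backward. The helper names attention_reference and check_manual_backward are illustrative only.

```
import torch

def attention_reference(q, k, v):
    # Same forward math as the new op: a = tanh(q @ k^T), o = a @ v
    a = torch.tanh(torch.matmul(q, k.transpose(0, 1)))
    return torch.matmul(a, v), a

def check_manual_backward():
    q = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
    k = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
    v = torch.randn(3, 7, dtype=torch.float64, requires_grad=True)

    o, a = attention_reference(q, k, v)
    grad_o = torch.randn_like(o)
    grad_a = torch.randn_like(a)
    # Reference gradients from autograd, seeding both outputs (o, a)
    ref_q, ref_k, ref_v = torch.autograd.grad((o, a), (q, k, v), (grad_o, grad_a))

    # Manual gradients, mirroring Attention.backward / attention_backward
    with torch.no_grad():
        grad_a_total = grad_a + grad_o @ v.transpose(0, 1)  # upstream + local contribution
        grad_x = grad_a_total * (1 - a ** 2)                # tanh'(x) = 1 - tanh(x)^2
        grad_q = grad_x @ k
        grad_k = grad_x.transpose(0, 1) @ q
        grad_v = a.transpose(0, 1) @ grad_o

    torch.testing.assert_close(grad_q, ref_q)
    torch.testing.assert_close(grad_k, ref_k)
    torch.testing.assert_close(grad_v, ref_v)

if __name__ == "__main__":
    check_manual_backward()
    print("manual attention backward matches autograd")
```

Seeding both outputs here mirrors output_differentiability: [True, True] in derivatives.yaml: the incoming grad_a and the local grad_o @ v^T term are summed before being pushed back through tanh.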