diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal index 92774f3ff266..4ba2bca720db 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal @@ -68,6 +68,37 @@ kernel void matmul( } } +template +kernel void addmm( + constant T* mat1Data [[buffer(0)]], + constant T* mat2Data [[buffer(1)]], + device T* outputData [[buffer(2)]], + constant T* biasData [[buffer(3)]], + constant array, 2>& alpha_beta [[buffer(4)]], + constant array& strides [[buffer(5)]], + constant uint3& sizes [[buffer(6)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 thread_id [[thread_position_in_grid]]) { + threadgroup T A_tile[TILE_DIM][TILE_DIM]; + threadgroup T B_tile[TILE_DIM][TILE_DIM]; + + auto sum = matmul_inner( + mat1Data, + mat2Data, + reinterpret_cast&>(strides), + sizes, + A_tile, + B_tile, + tid, + thread_id); + if (thread_id.y < sizes.x && thread_id.x < sizes.z) { + auto bias = + biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y]; + outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] = + static_cast(alpha_beta[0] * sum + alpha_beta[1] * bias); + } +} + template kernel void naive_bmm( constant T* mat1Data [[buffer(0)]], @@ -613,17 +644,15 @@ kernel void applyPivots( } } -#define INSTANTIATE_NAIVE_MM(DTYPE) \ - template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ - constant DTYPE * mat1Data [[buffer(0)]], \ - constant DTYPE * mat2Data [[buffer(1)]], \ - device DTYPE * outputData [[buffer(2)]], \ - constant array & strides [[buffer(3)]], \ - constant uint3 & sizes [[buffer(4)]], \ - uint2 tid [[thread_position_in_threadgroup]], \ - uint2 group_id [[threadgroup_position_in_grid]]) - -#define INSTANTIATE_NAIVE_BMM(DTYPE) \ +#define INSTANTIATE_MM_OPS(DTYPE) \ + template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ + constant DTYPE * mat1Data [[buffer(0)]], \ + constant DTYPE * mat2Data [[buffer(1)]], \ + device DTYPE * outputData [[buffer(2)]], \ + constant array & strides [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 tid [[thread_position_in_threadgroup]], \ + uint2 group_id [[threadgroup_position_in_grid]]); \ template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm( \ constant DTYPE * mat1Data [[buffer(0)]], \ constant DTYPE * mat2Data [[buffer(1)]], \ @@ -631,20 +660,26 @@ kernel void applyPivots( constant array & strides [[buffer(3)]], \ constant uint4 & sizes [[buffer(4)]], \ uint3 tid [[thread_position_in_threadgroup]], \ - uint3 group_id [[threadgroup_position_in_grid]]) + uint3 group_id [[threadgroup_position_in_grid]]); \ + template [[host_name("addmm_" #DTYPE)]] kernel void addmm( \ + constant DTYPE * mat1Data [[buffer(0)]], \ + constant DTYPE * mat2Data [[buffer(1)]], \ + device DTYPE * outputData [[buffer(2)]], \ + constant DTYPE * biasData [[buffer(3)]], \ + constant array, 2> & \ + alpha_beta [[buffer(4)]], \ + constant array & strides [[buffer(5)]], \ + constant uint3 & sizes [[buffer(6)]], \ + uint2 tid [[thread_position_in_threadgroup]], \ + uint2 group_id [[threadgroup_position_in_grid]]) -INSTANTIATE_NAIVE_MM(float); -INSTANTIATE_NAIVE_MM(half); -INSTANTIATE_NAIVE_MM(bfloat); +INSTANTIATE_MM_OPS(float); +INSTANTIATE_MM_OPS(half); +INSTANTIATE_MM_OPS(bfloat); // Integral MM -INSTANTIATE_NAIVE_MM(short); -INSTANTIATE_NAIVE_MM(int); -INSTANTIATE_NAIVE_MM(long); -INSTANTIATE_NAIVE_MM(char); -INSTANTIATE_NAIVE_MM(uchar); -INSTANTIATE_NAIVE_BMM(short); -INSTANTIATE_NAIVE_BMM(int); -INSTANTIATE_NAIVE_BMM(long); -INSTANTIATE_NAIVE_BMM(char); -INSTANTIATE_NAIVE_BMM(uchar); +INSTANTIATE_MM_OPS(long); +INSTANTIATE_MM_OPS(int); +INSTANTIATE_MM_OPS(short); +INSTANTIATE_MM_OPS(char); +INSTANTIATE_MM_OPS(uchar); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 3cdf0021e987..7a3dde679c05 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -112,6 +112,61 @@ return output; } +Tensor& do_metal_addmm(const Tensor& self, + const Tensor& other, + Tensor& output, + const Scalar& alpha, + const Scalar& beta, + const Tensor& bias) { + if (beta.toDouble() == 0 && alpha.toDouble() == 1) { + return do_metal_mm(self, other, output); + } + auto stream = getCurrentMPSStream(); + auto device = MPSDevice::getInstance()->device(); + auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output)); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other}); + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:matmulPSO]; + std::array sizes = {static_cast(self.size(0)), + static_cast(self.size(1)), + static_cast(output.size(1))}; + std::array strides = {self.stride(0), + self.stride(1), + other.stride(0), + other.stride(1), + output.stride(0), + output.stride(1), + bias.stride(0), + bias.stride(1)}; + union { + std::array i64; + std::array i32; + std::array f32; + } alpha_beta; + if (output.scalar_type() == kLong) { + alpha_beta.i64 = {alpha.toLong(), beta.toLong()}; + } else if (c10::isIntegralType(output.scalar_type(), true)) { + alpha_beta.i32 = {alpha.toInt(), beta.toInt()}; + } else { + TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type())); + alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()}; + } + constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs + uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM; + uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM; + + MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1); + mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + getMPSProfiler().endProfileKernel(matmulPSO); + } + }); + return output; +} + std::tuple do_mm(MPSGraph* graph, const Tensor& self, const Tensor& other) { @@ -644,7 +699,6 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const TORCH_CHECK(output.is_mps()); TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input"); TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}}; checkAllSameGPU(__func__, args); @@ -671,6 +725,10 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const return output; } + if (use_metal_mm(self, other, output)) { + return do_metal_addmm(self, other, output, alpha, beta, *bias_); + } + bool is_beta_non_zero = beta.toDouble() != 0.0; struct CachedGraph : public mps::MPSCachedGraph { diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 41bb2b96bd93..506bf5488f3c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1161,8 +1161,8 @@ def make_arg_conj(size): def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): - alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6) - beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2) + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6 if dtype.is_floating_point else 2) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2 if dtype.is_floating_point else 3) tests_list = [ ((2, 3), (2, 2), (2, 3), False), ((3, 3), (3, 3), (3, 3), False), diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index 58afc631d21b..1adc183db2d3 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -284,85 +284,6 @@ def mps_ops_modifier( "where", "byte", } - # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758 - MACOS_BEFORE_13_3_XFAILLIST = { - # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ - "cdist": [torch.float32], - # CPU Error: cpu not giving nan for x/0.0 - "atan2": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - # test blow pass on macOS 12 as it falls back to cpu - # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') - # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') - # Elements from index 30 and 5133 are both equal. - # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. - "argsort": [torch.float16, torch.int8, torch.uint8, torch.bool], - # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. - # The values of the sorted tensor match the CPU, - # but in case of the returned indices this results in undefined behaviour. - "sort": [torch.int8, torch.uint8, torch.bool, torch.float16], - # Unsupported dtypes - "cumsum": [torch.int64], - "cumprod": [torch.int64], - "cumulative_trapezoid": [torch.int64], - "masked.cumsum": [torch.int64], - "masked.cumprod": [torch.int64], - "linalg.vander": [torch.int64], - # Fail with `Expected 1.0 but got nan.` for empty tensors - # Caused by sample input at index 23: SampleInput( - # input=Tensor[size=(), device="mps:0", dtype=torch.float32], - # args=(0), - # kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'}, - # broadcasts_input=False, name='') - "masked.softmin": [torch.float32, torch.float16], - "masked.softmax": [torch.float32, torch.float16], - "masked.log_softmax": [torch.float32, torch.float16], - } - - MACOS_AFTER_13_1_XFAILLIST = { - # before macOS 13.2 it falls back to cpu and pass the forward pass - "grid_sampler_2d": [ - torch.float32, - torch.float16, - torch.bfloat16, - ], # Unsupported Border padding mode - } - - MACOS_13_3_XFAILLIST = { - # Failure due to precision issue for fp16 - # on both cpu and mps there are test cases that might produce inf result - # 'nn.functional.pairwise_distance': [torch.float16], - # test blow pass on macOS 12 as it falls back to cpu - # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') - # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') - # Elements from index 30 and 5133 are both equal. - # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. - "argsort": [ - torch.float16, - torch.int8, - torch.uint8, - torch.bool, - torch.bfloat16, - ], - # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. - # The values of the sorted tensor match the CPU, - # but in case of the returned indices this results in undefined behaviour. - "sort": [ - torch.int8, - torch.uint8, - torch.bool, - torch.float16, - torch.bfloat16, - ], - } MACOS_BEFORE_14_4_XFAILLIST = { # These ops work fine in 14.4 but fail in 14.2 or 13.x @@ -495,7 +416,6 @@ def mps_ops_modifier( torch.float16, ], # Unsupported dtypes - "dot": [torch.int64] if MACOS_VERSION < 14.0 else [], "histc": [torch.float16, torch.bfloat16], "index_add": [torch.int64], # GEMM on MPS is not supported for integral types @@ -506,19 +426,9 @@ def mps_ops_modifier( torch.uint8, torch.int8, ], - "addmmdecomposed": [ - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - "matmul": [torch.int64] if MACOS_VERSION < 14.0 else [], - "__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [], # returned output on CPU is float64 "bincount": [ torch.int16, @@ -623,6 +533,38 @@ def mps_ops_modifier( "linalg.matrix_rank": None, # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")` "arange": [torch.uint8], + # before macOS 13.2 it falls back to cpu and pass the forward pass + "grid_sampler_2d": [ + torch.float32, + torch.float16, + torch.bfloat16, + ], # Unsupported Border padding mode + # Failure due to precision issue for fp16 + # on both cpu and mps there are test cases that might produce inf result + # 'nn.functional.pairwise_distance': [torch.float16], + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + "argsort": [ + torch.float16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + ], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, + # but in case of the returned indices this results in undefined behaviour. + "sort": [ + torch.int8, + torch.uint8, + torch.bool, + torch.float16, + torch.bfloat16, + ], } EMPTY_OPS_SKIPLIST = { @@ -690,43 +632,6 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: ), ) - if ( - key in MACOS_BEFORE_13_3_XFAILLIST - and key not in xfail_exclusion - and (torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, - dtypes=MACOS_BEFORE_13_3_XFAILLIST[key], - ), - ) - - if ( - key in MACOS_AFTER_13_1_XFAILLIST - and key not in xfail_exclusion - and torch.backends.mps.is_macos13_or_newer(2) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, dtypes=MACOS_AFTER_13_1_XFAILLIST[key] - ), - ) - - if ( - key in MACOS_13_3_XFAILLIST - and key not in xfail_exclusion - and (MACOS_VERSION >= 13.3) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST[key] - ), - ) - # If ops is not supported for complex types, expect it to fail if key not in SUPPORTED_COMPLEX_OPS: addDecorator(