[MPS] Extend addmm to integral types #160270

Closed · wants to merge 2 commits
Changes from 1 commit
[MPS] Extend addmm to integral types
Fixes #154901

[ghstack-poisoned]
malfet committed Aug 10, 2025
commit abe992a0ca9b6c9119d04514f5d5cd9865f66fcf
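For context, here is a minimal usage sketch of what this change enables: torch.addmm on integer tensors on the MPS backend. The shapes, dtypes, and alpha/beta values below are arbitrary illustrations, not taken from the PR's tests.

import torch

# Illustrative only: exercises the integer addmm path on MPS that this PR adds.
mat1 = torch.randint(-8, 8, (4, 5), dtype=torch.int32, device="mps")
mat2 = torch.randint(-8, 8, (5, 3), dtype=torch.int32, device="mps")
bias = torch.randint(-8, 8, (4, 3), dtype=torch.int32, device="mps")

# out = beta * bias + alpha * (mat1 @ mat2)
out = torch.addmm(bias, mat1, mat2, beta=2, alpha=3)

# Cross-check against the CPU implementation.
expected = torch.addmm(bias.cpu(), mat1.cpu(), mat2.cpu(), beta=2, alpha=3)
assert torch.equal(out.cpu(), expected)
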
48 changes: 48 additions & 0 deletions aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
@@ -68,6 +68,37 @@ kernel void matmul(
}
}

template <typename T>
kernel void addmm(
constant T* mat1Data [[buffer(0)]],
constant T* mat2Data [[buffer(1)]],
device T* outputData [[buffer(2)]],
constant T* biasData [[buffer(3)]],
constant array<long, 2>& alpha_beta [[buffer(4)]],
constant array<ulong2, 4>& strides [[buffer(5)]],
constant uint3& sizes [[buffer(6)]],
uint2 tid [[thread_position_in_threadgroup]],
uint2 thread_id [[thread_position_in_grid]]) {
threadgroup T A_tile[TILE_DIM][TILE_DIM];
Collaborator commented:
It's ugly, but can this be rewritten as a std::array too?

malfet (PR author) commented on Aug 10, 2025:
Threadgroups are a bit weird (i.e., this statement affects GPU occupancy). Let me give it a try in a separate PR and make sure it would not regress the perf...

threadgroup T B_tile[TILE_DIM][TILE_DIM];

auto sum = matmul_inner<T>(
mat1Data,
mat2Data,
reinterpret_cast<constant array<ulong2, 3>&>(strides),
sizes,
A_tile,
B_tile,
tid,
thread_id);
if (thread_id.y < sizes.x && thread_id.x < sizes.z) {
auto bias =
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
}
}

template <typename T>
kernel void naive_bmm(
constant T* mat1Data [[buffer(0)]],
@@ -633,6 +664,18 @@ kernel void applyPivots(
uint3 tid [[thread_position_in_threadgroup]], \
uint3 group_id [[threadgroup_position_in_grid]])

#define INSTANTIATE_NAIVE_ADDMM(DTYPE) \
template [[host_name("addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
constant DTYPE * mat2Data [[buffer(1)]], \
device DTYPE * outputData [[buffer(2)]], \
constant DTYPE * biasData [[buffer(3)]], \
constant array<long, 2> & alpha_beta [[buffer(4)]], \
constant array<ulong2, 4> & strides [[buffer(5)]], \
constant uint3 & sizes [[buffer(6)]], \
uint2 tid [[thread_position_in_threadgroup]], \
uint2 group_id [[threadgroup_position_in_grid]])

INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
INSTANTIATE_NAIVE_MM(bfloat);
Expand All @@ -648,3 +691,8 @@ INSTANTIATE_NAIVE_BMM(int);
INSTANTIATE_NAIVE_BMM(long);
INSTANTIATE_NAIVE_BMM(char);
INSTANTIATE_NAIVE_BMM(uchar);
INSTANTIATE_NAIVE_ADDMM(short);
INSTANTIATE_NAIVE_ADDMM(int);
INSTANTIATE_NAIVE_ADDMM(long);
INSTANTIATE_NAIVE_ADDMM(char);
INSTANTIATE_NAIVE_ADDMM(uchar);
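For reference, the epilogue in the addmm kernel above (alpha_beta[0] * sum + alpha_beta[1] * bias) follows the standard addmm contract. A minimal Python sketch of that semantics, for orientation only (the function name is made up and this is not the actual implementation):

# Reference semantics of the Metal addmm kernel above (illustrative sketch):
#   output[i, j] = alpha * sum_k(mat1[i, k] * mat2[k, j]) + beta * bias[i, j]
def addmm_reference(bias, mat1, mat2, alpha=1, beta=1):
    # Works for any 2-D operands that support @, e.g. torch tensors.
    return alpha * (mat1 @ mat2) + beta * bias
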
49 changes: 48 additions & 1 deletion aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -112,6 +112,50 @@
return output;
}

Tensor& do_metal_addmm(const Tensor& self,
const Tensor& other,
Tensor& output,
const Scalar& alpha,
const Scalar& beta,
const Tensor& bias) {
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
return do_metal_mm(self, other, output);
}
auto stream = getCurrentMPSStream();
auto device = MPSDevice::getInstance()->device();
auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other});
auto computeEncoder = stream->commandEncoder();
[computeEncoder setComputePipelineState:matmulPSO];
std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
static_cast<uint32_t>(self.size(1)),
static_cast<uint32_t>(output.size(1))};
std::array<int64_t, 8> strides = {self.stride(0),
self.stride(1),
other.stride(0),
other.stride(1),
output.stride(0),
output.stride(1),
bias.stride(0),
bias.stride(1)};
std::array<int64_t, 2> alpha_beta = {alpha.toInt(), beta.toInt()};
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM;
uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM;

MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1);
MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1);
mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta, strides, sizes);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
getMPSProfiler().endProfileKernel(matmulPSO);
}
});
return output;
return output;
Collaborator commented:
Double return? Surprised this didn't error.

malfet (PR author) replied:
Yeah, I hoped some of the linters would be triggered by it, but it seems this is fine...

}

std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
const Tensor& self,
const Tensor& other) {
@@ -644,7 +688,6 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const

TORCH_CHECK(output.is_mps());
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input");

TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
checkAllSameGPU(__func__, args);
@@ -671,6 +714,10 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
return output;
}

if (use_metal_mm(self, other, output)) {
return do_metal_addmm(self, other, output, alpha, beta, *bias_);
}

bool is_beta_non_zero = beta.toDouble() != 0.0;

struct CachedGraph : public mps::MPSCachedGraph {
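The host-side dispatch above covers the output with TILE_DIM x TILE_DIM threadgroups via ceiling division. A quick sketch of that sizing arithmetic, with a hypothetical output shape (the 100x37 example is made up):

# Illustrative: how do_metal_addmm sizes its dispatch grid with a 16x16 tile.
TILE_DIM = 16

def ceil_div(a, b):
    return (a + b - 1) // b

# Hypothetical 100x37 output: 3 threadgroups across the columns, 7 down the rows.
grid_x = ceil_div(37, TILE_DIM)   # -> 3
grid_y = ceil_div(100, TILE_DIM)  # -> 7
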
8 changes: 0 additions & 8 deletions torch/testing/_internal/common_mps.py
@@ -426,15 +426,7 @@ def mps_ops_modifier(
torch.uint8,
torch.int8,
],
"addmmdecomposed": [
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.int8,
],
"addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
"addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
"baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
"mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# returned output on CPU is float64