Commit 5a8f030

[MPS] Extend addmm to integral types

Fixes #154901
ghstack-source-id: c2c210e
Pull Request resolved: #160270

1 parent: 6f460a7
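
As a usage sketch of the user-visible effect (illustrative, not part of the commit): with this change, integer addmm dispatches to a Metal kernel on MPS instead of failing the floating-point type check that is removed below. Assumes an MPS-capable machine.

    import torch

    # Hypothetical example: integer addmm on the MPS backend. Before this
    # commit, the op raised "MPS device does not support addmm for non-float
    # input"; with it, the new Metal kernel handles integral dtypes.
    if torch.backends.mps.is_available():
        mat1 = torch.randint(-10, 10, (4, 8), dtype=torch.int32, device="mps")
        mat2 = torch.randint(-10, 10, (8, 6), dtype=torch.int32, device="mps")
        bias = torch.randint(-10, 10, (4, 6), dtype=torch.int32, device="mps")
        # out = beta * bias + alpha * (mat1 @ mat2); beta != 0 exercises the
        # new addmm kernel rather than the plain matmul fast path.
        out = torch.addmm(bias, mat1, mat2, beta=2, alpha=3)
        ref = torch.addmm(bias.cpu(), mat1.cpu(), mat2.cpu(), beta=2, alpha=3)
        assert torch.equal(out.cpu(), ref)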

File tree: 3 files changed, +119 −34 lines

aten/src/ATen/native/mps/kernels/LinearAlgebra.metal

Lines changed: 60 additions & 25 deletions
@@ -68,6 +68,37 @@ kernel void matmul(
   }
 }
 
+template <typename T>
+kernel void addmm(
+    constant T* mat1Data [[buffer(0)]],
+    constant T* mat2Data [[buffer(1)]],
+    device T* outputData [[buffer(2)]],
+    constant T* biasData [[buffer(3)]],
+    constant array<c10::metal::opmath_t<T>, 2>& alpha_beta [[buffer(4)]],
+    constant array<ulong2, 4>& strides [[buffer(5)]],
+    constant uint3& sizes [[buffer(6)]],
+    uint2 tid [[thread_position_in_threadgroup]],
+    uint2 thread_id [[thread_position_in_grid]]) {
+  threadgroup T A_tile[TILE_DIM][TILE_DIM];
+  threadgroup T B_tile[TILE_DIM][TILE_DIM];
+
+  auto sum = matmul_inner<T>(
+      mat1Data,
+      mat2Data,
+      reinterpret_cast<constant array<ulong2, 3>&>(strides),
+      sizes,
+      A_tile,
+      B_tile,
+      tid,
+      thread_id);
+  if (thread_id.y < sizes.x && thread_id.x < sizes.z) {
+    auto bias =
+        biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
+    outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
+        static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
+  }
+}
+
 template <typename T>
 kernel void naive_bmm(
     constant T* mat1Data [[buffer(0)]],
@@ -613,38 +644,42 @@ kernel void applyPivots(
   }
 }
 
-#define INSTANTIATE_NAIVE_MM(DTYPE)                                          \
-  template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>(        \
-      constant DTYPE * mat1Data [[buffer(0)]],                               \
-      constant DTYPE * mat2Data [[buffer(1)]],                               \
-      device DTYPE * outputData [[buffer(2)]],                               \
-      constant array<ulong2, 3> & strides [[buffer(3)]],                     \
-      constant uint3 & sizes [[buffer(4)]],                                  \
-      uint2 tid [[thread_position_in_threadgroup]],                          \
-      uint2 group_id [[threadgroup_position_in_grid]])
-
-#define INSTANTIATE_NAIVE_BMM(DTYPE)                                         \
+#define INSTANTIATE_MM_OPS(DTYPE)                                            \
+  template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>(        \
+      constant DTYPE * mat1Data [[buffer(0)]],                               \
+      constant DTYPE * mat2Data [[buffer(1)]],                               \
+      device DTYPE * outputData [[buffer(2)]],                               \
+      constant array<ulong2, 3> & strides [[buffer(3)]],                     \
+      constant uint3 & sizes [[buffer(4)]],                                  \
+      uint2 tid [[thread_position_in_threadgroup]],                          \
+      uint2 group_id [[threadgroup_position_in_grid]]);                      \
   template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm<DTYPE>(  \
      constant DTYPE * mat1Data [[buffer(0)]],                                \
      constant DTYPE * mat2Data [[buffer(1)]],                                \
      device DTYPE * outputData [[buffer(2)]],                                \
      constant array<ulong, 9> & strides [[buffer(3)]],                       \
      constant uint4 & sizes [[buffer(4)]],                                   \
      uint3 tid [[thread_position_in_threadgroup]],                           \
-      uint3 group_id [[threadgroup_position_in_grid]])
+      uint3 group_id [[threadgroup_position_in_grid]]);                      \
+  template [[host_name("addmm_" #DTYPE)]] kernel void addmm<DTYPE>(          \
+      constant DTYPE * mat1Data [[buffer(0)]],                               \
+      constant DTYPE * mat2Data [[buffer(1)]],                               \
+      device DTYPE * outputData [[buffer(2)]],                               \
+      constant DTYPE * biasData [[buffer(3)]],                               \
+      constant array<c10::metal::opmath_t<DTYPE>, 2> &                       \
+          alpha_beta [[buffer(4)]],                                          \
+      constant array<ulong2, 4> & strides [[buffer(5)]],                     \
+      constant uint3 & sizes [[buffer(6)]],                                  \
+      uint2 tid [[thread_position_in_threadgroup]],                          \
+      uint2 group_id [[threadgroup_position_in_grid]])
 
-INSTANTIATE_NAIVE_MM(float);
-INSTANTIATE_NAIVE_MM(half);
-INSTANTIATE_NAIVE_MM(bfloat);
+INSTANTIATE_MM_OPS(float);
+INSTANTIATE_MM_OPS(half);
+INSTANTIATE_MM_OPS(bfloat);
 
 // Integral MM
-INSTANTIATE_NAIVE_MM(short);
-INSTANTIATE_NAIVE_MM(int);
-INSTANTIATE_NAIVE_MM(long);
-INSTANTIATE_NAIVE_MM(char);
-INSTANTIATE_NAIVE_MM(uchar);
-INSTANTIATE_NAIVE_BMM(short);
-INSTANTIATE_NAIVE_BMM(int);
-INSTANTIATE_NAIVE_BMM(long);
-INSTANTIATE_NAIVE_BMM(char);
-INSTANTIATE_NAIVE_BMM(uchar);
+INSTANTIATE_MM_OPS(long);
+INSTANTIATE_MM_OPS(int);
+INSTANTIATE_MM_OPS(short);
+INSTANTIATE_MM_OPS(char);
+INSTANTIATE_MM_OPS(uchar);
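
For orientation, here is a minimal Python reference (an illustration, not code from this PR) of what each addmm thread above computes: a K-reduction via matmul_inner over shared-memory tiles, followed by the epilogue alpha * sum + beta * bias read through the output and bias stride pairs.

    # Reference semantics of the Metal addmm kernel, with dense 2-D lists
    # standing in for the strided buffers; names here are illustrative.
    def addmm_reference(mat1, mat2, bias, alpha, beta):
        M, K, N = len(mat1), len(mat2), len(mat2[0])  # sizes.x, sizes.y, sizes.z
        out = [[0] * N for _ in range(M)]
        for y in range(M):          # thread_id.y bound: sizes.x
            for x in range(N):      # thread_id.x bound: sizes.z
                acc = sum(mat1[y][k] * mat2[k][x] for k in range(K))  # matmul_inner
                out[y][x] = alpha * acc + beta * bias[y][x]           # epilogue
        return out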

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 59 additions & 1 deletion
@@ -112,6 +112,61 @@
   return output;
 }
 
+Tensor& do_metal_addmm(const Tensor& self,
+                       const Tensor& other,
+                       Tensor& output,
+                       const Scalar& alpha,
+                       const Scalar& beta,
+                       const Tensor& bias) {
+  if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
+    return do_metal_mm(self, other, output);
+  }
+  auto stream = getCurrentMPSStream();
+  auto device = MPSDevice::getInstance()->device();
+  auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output));
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other});
+      auto computeEncoder = stream->commandEncoder();
+      [computeEncoder setComputePipelineState:matmulPSO];
+      std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
+                                       static_cast<uint32_t>(self.size(1)),
+                                       static_cast<uint32_t>(output.size(1))};
+      std::array<int64_t, 8> strides = {self.stride(0),
+                                        self.stride(1),
+                                        other.stride(0),
+                                        other.stride(1),
+                                        output.stride(0),
+                                        output.stride(1),
+                                        bias.stride(0),
+                                        bias.stride(1)};
+      union {
+        std::array<int64_t, 2> i64;
+        std::array<int32_t, 2> i32;
+        std::array<float, 2> f32;
+      } alpha_beta;
+      if (output.scalar_type() == kLong) {
+        alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
+      } else if (c10::isIntegralType(output.scalar_type(), true)) {
+        alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
+      } else {
+        TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
+        alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
+      }
+      constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
+      uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM;
+      uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM;
+
+      MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1);
+      MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1);
+      mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes);
+      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
+      getMPSProfiler().endProfileKernel(matmulPSO);
+    }
+  });
+  return output;
+}
+
 std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
                                                                     const Tensor& self,
                                                                     const Tensor& other) {
@@ -644,7 +699,6 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
 
   TORCH_CHECK(output.is_mps());
   TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
-  TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input");
 
   TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
   checkAllSameGPU(__func__, args);
@@ -671,6 +725,10 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
     return output;
   }
 
+  if (use_metal_mm(self, other, output)) {
+    return do_metal_addmm(self, other, output, alpha, beta, *bias_);
+  }
+
   bool is_beta_non_zero = beta.toDouble() != 0.0;
 
   struct CachedGraph : public mps::MPSCachedGraph {
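
One detail worth calling out in do_metal_addmm above is the alpha_beta union: the host always binds the 16-byte i64 member at buffer(4), but fills it as two int64, two int32, or two float32 values so that the kernel's constant array<c10::metal::opmath_t<T>, 2>& sees the right representation for each dtype. A rough Python analogy of that packing (illustrative only; the exact opmath_t mappings are assumptions, not repo code):

    import struct

    # Pack (alpha, beta) the way the union is laid out: the kernel reads only
    # the first two elements of its opmath type, so the narrower layouts pad
    # out to the union's 16-byte size (the size of the always-bound i64 member).
    def pack_alpha_beta(alpha, beta, metal_dtype):
        if metal_dtype == "long":                          # assumed: opmath stays int64
            return struct.pack("<2q", alpha, beta)         # two int64 -> 16 bytes
        if metal_dtype in ("int", "short", "char", "uchar"):
            return struct.pack("<2i", alpha, beta).ljust(16, b"\0")  # two int32
        return struct.pack("<2f", alpha, beta).ljust(16, b"\0")      # two float32

    assert len(pack_alpha_beta(3, 2, "int")) == 16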

torch/testing/_internal/common_mps.py

Lines changed: 0 additions & 8 deletions
@@ -426,15 +426,7 @@ def mps_ops_modifier(
             torch.uint8,
             torch.int8,
         ],
-        "addmmdecomposed": [
-            torch.int16,
-            torch.int32,
-            torch.int64,
-            torch.uint8,
-            torch.int8,
-        ],
         "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
-        "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
         "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
         "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
         # returned output on CPU is float64