Skip to content

Commit cdfaf5a

Browse files
committed
Update on "[MPS] Extend addmm to integral types"
Fixes #154901 [ghstack-poisoned]
1 parent abe992a commit cdfaf5a

File tree

2 files changed

+43
-45
lines changed

2 files changed

+43
-45
lines changed

aten/src/ATen/native/mps/kernels/LinearAlgebra.metal

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ kernel void addmm(
7474
constant T* mat2Data [[buffer(1)]],
7575
device T* outputData [[buffer(2)]],
7676
constant T* biasData [[buffer(3)]],
77-
constant array<long, 2>& alpha_beta [[buffer(4)]],
77+
constant array<c10::metal::opmath_t<T>, 2>& alpha_beta [[buffer(4)]],
7878
constant array<ulong2, 4>& strides [[buffer(5)]],
7979
constant uint3& sizes [[buffer(6)]],
8080
uint2 tid [[thread_position_in_threadgroup]],
@@ -644,55 +644,42 @@ kernel void applyPivots(
644644
}
645645
}
646646

647-
#define INSTANTIATE_NAIVE_MM(DTYPE) \
648-
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
649-
constant DTYPE * mat1Data [[buffer(0)]], \
650-
constant DTYPE * mat2Data [[buffer(1)]], \
651-
device DTYPE * outputData [[buffer(2)]], \
652-
constant array<ulong2, 3> & strides [[buffer(3)]], \
653-
constant uint3 & sizes [[buffer(4)]], \
654-
uint2 tid [[thread_position_in_threadgroup]], \
655-
uint2 group_id [[threadgroup_position_in_grid]])
656-
657-
#define INSTANTIATE_NAIVE_BMM(DTYPE) \
647+
#define INSTANTIATE_MM_OPS(DTYPE) \
648+
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
649+
constant DTYPE * mat1Data [[buffer(0)]], \
650+
constant DTYPE * mat2Data [[buffer(1)]], \
651+
device DTYPE * outputData [[buffer(2)]], \
652+
constant array<ulong2, 3> & strides [[buffer(3)]], \
653+
constant uint3 & sizes [[buffer(4)]], \
654+
uint2 tid [[thread_position_in_threadgroup]], \
655+
uint2 group_id [[threadgroup_position_in_grid]]); \
658656
template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm<DTYPE>( \
659657
constant DTYPE * mat1Data [[buffer(0)]], \
660658
constant DTYPE * mat2Data [[buffer(1)]], \
661659
device DTYPE * outputData [[buffer(2)]], \
662660
constant array<ulong, 9> & strides [[buffer(3)]], \
663661
constant uint4 & sizes [[buffer(4)]], \
664662
uint3 tid [[thread_position_in_threadgroup]], \
665-
uint3 group_id [[threadgroup_position_in_grid]])
666-
667-
#define INSTANTIATE_NAIVE_ADDMM(DTYPE) \
668-
template [[host_name("addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
669-
constant DTYPE * mat1Data [[buffer(0)]], \
670-
constant DTYPE * mat2Data [[buffer(1)]], \
671-
device DTYPE * outputData [[buffer(2)]], \
672-
constant DTYPE * biasData [[buffer(3)]], \
673-
constant array<long, 2> & alpha_beta [[buffer(4)]], \
674-
constant array<ulong2, 4> & strides [[buffer(5)]], \
675-
constant uint3 & sizes [[buffer(6)]], \
676-
uint2 tid [[thread_position_in_threadgroup]], \
663+
uint3 group_id [[threadgroup_position_in_grid]]); \
664+
template [[host_name("addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
665+
constant DTYPE * mat1Data [[buffer(0)]], \
666+
constant DTYPE * mat2Data [[buffer(1)]], \
667+
device DTYPE * outputData [[buffer(2)]], \
668+
constant DTYPE * biasData [[buffer(3)]], \
669+
constant array<c10::metal::opmath_t<DTYPE>, 2> & \
670+
alpha_beta [[buffer(4)]], \
671+
constant array<ulong2, 4> & strides [[buffer(5)]], \
672+
constant uint3 & sizes [[buffer(6)]], \
673+
uint2 tid [[thread_position_in_threadgroup]], \
677674
uint2 group_id [[threadgroup_position_in_grid]])
678675

679-
INSTANTIATE_NAIVE_MM(float);
680-
INSTANTIATE_NAIVE_MM(half);
681-
INSTANTIATE_NAIVE_MM(bfloat);
676+
INSTANTIATE_MM_OPS(float);
677+
INSTANTIATE_MM_OPS(half);
678+
INSTANTIATE_MM_OPS(bfloat);
682679

683680
// Integral MM
684-
INSTANTIATE_NAIVE_MM(short);
685-
INSTANTIATE_NAIVE_MM(int);
686-
INSTANTIATE_NAIVE_MM(long);
687-
INSTANTIATE_NAIVE_MM(char);
688-
INSTANTIATE_NAIVE_MM(uchar);
689-
INSTANTIATE_NAIVE_BMM(short);
690-
INSTANTIATE_NAIVE_BMM(int);
691-
INSTANTIATE_NAIVE_BMM(long);
692-
INSTANTIATE_NAIVE_BMM(char);
693-
INSTANTIATE_NAIVE_BMM(uchar);
694-
INSTANTIATE_NAIVE_ADDMM(short);
695-
INSTANTIATE_NAIVE_ADDMM(int);
696-
INSTANTIATE_NAIVE_ADDMM(long);
697-
INSTANTIATE_NAIVE_ADDMM(char);
698-
INSTANTIATE_NAIVE_ADDMM(uchar);
681+
INSTANTIATE_MM_OPS(long);
682+
INSTANTIATE_MM_OPS(int);
683+
INSTANTIATE_MM_OPS(short);
684+
INSTANTIATE_MM_OPS(char);
685+
INSTANTIATE_MM_OPS(uchar);

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,31 @@
140140
output.stride(1),
141141
bias.stride(0),
142142
bias.stride(1)};
143-
std::array<int64_t, 2> alpha_beta = {alpha.toInt(), beta.toInt()};
143+
union {
144+
std::array<int64_t, 2> i64;
145+
std::array<int32_t, 2> i32;
146+
std::array<float, 2> f32;
147+
} alpha_beta;
148+
if (output.scalar_type() == kLong) {
149+
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
150+
} else if (c10::isIntegralType(output.scalar_type(), true)) {
151+
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
152+
} else {
153+
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
154+
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
155+
}
144156
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
145157
uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM;
146158
uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM;
147159

148160
MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1);
149161
MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1);
150-
mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta, strides, sizes);
162+
mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes);
151163
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
152164
getMPSProfiler().endProfileKernel(matmulPSO);
153165
}
154166
});
155167
return output;
156-
return output;
157168
}
158169

159170
std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,

0 commit comments

Comments (0)