@@ -74,7 +74,7 @@ kernel void addmm(
74
74
constant T* mat2Data [[buffer(1 )]],
75
75
device T* outputData [[buffer(2 )]],
76
76
constant T* biasData [[buffer(3 )]],
77
- constant array<long , 2>& alpha_beta [[buffer(4 )]],
77
+ constant array<c10::metal::opmath_t<T> , 2>& alpha_beta [[buffer(4 )]],
78
78
constant array<ulong2, 4>& strides [[buffer(5 )]],
79
79
constant uint3& sizes [[buffer(6 )]],
80
80
uint2 tid [[thread_position_in_threadgroup]],
@@ -644,55 +644,42 @@ kernel void applyPivots(
644
644
}
645
645
}
646
646
647
- #define INSTANTIATE_NAIVE_MM (DTYPE ) \
648
- template [[host_name(" matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
649
- constant DTYPE * mat1Data [[buffer(0 )]], \
650
- constant DTYPE * mat2Data [[buffer(1 )]], \
651
- device DTYPE * outputData [[buffer(2 )]], \
652
- constant array<ulong2, 3 > & strides [[buffer(3 )]], \
653
- constant uint3 & sizes [[buffer(4 )]], \
654
- uint2 tid [[thread_position_in_threadgroup]], \
655
- uint2 group_id [[threadgroup_position_in_grid]])
656
-
657
- #define INSTANTIATE_NAIVE_BMM (DTYPE ) \
647
+ #define INSTANTIATE_MM_OPS (DTYPE ) \
648
+ template [[host_name(" matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
649
+ constant DTYPE * mat1Data [[buffer(0 )]], \
650
+ constant DTYPE * mat2Data [[buffer(1 )]], \
651
+ device DTYPE * outputData [[buffer(2 )]], \
652
+ constant array<ulong2, 3 > & strides [[buffer(3 )]], \
653
+ constant uint3 & sizes [[buffer(4 )]], \
654
+ uint2 tid [[thread_position_in_threadgroup]], \
655
+ uint2 group_id [[threadgroup_position_in_grid]]); \
658
656
template [[host_name(" naive_bmm_" #DTYPE)]] kernel void naive_bmm<DTYPE>( \
659
657
constant DTYPE * mat1Data [[buffer(0 )]], \
660
658
constant DTYPE * mat2Data [[buffer(1 )]], \
661
659
device DTYPE * outputData [[buffer(2 )]], \
662
660
constant array<ulong, 9 > & strides [[buffer(3 )]], \
663
661
constant uint4 & sizes [[buffer(4 )]], \
664
662
uint3 tid [[thread_position_in_threadgroup]], \
665
- uint3 group_id [[threadgroup_position_in_grid]])
666
-
667
- #define INSTANTIATE_NAIVE_ADDMM (DTYPE ) \
668
- template [[host_name(" addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
669
- constant DTYPE * mat1Data [[buffer(0 )]], \
670
- constant DTYPE * mat2Data [[buffer(1 )]], \
671
- device DTYPE * outputData [[buffer(2 )]], \
672
- constant DTYPE * biasData [[buffer(3 )]], \
673
- constant array<long , 2 > & alpha_beta [[buffer(4 )]], \
674
- constant array<ulong2, 4 > & strides [[buffer(5 )]], \
675
- constant uint3 & sizes [[buffer(6 )]], \
676
- uint2 tid [[thread_position_in_threadgroup]], \
663
+ uint3 group_id [[threadgroup_position_in_grid]]); \
664
+ template [[host_name(" addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
665
+ constant DTYPE * mat1Data [[buffer(0 )]], \
666
+ constant DTYPE * mat2Data [[buffer(1 )]], \
667
+ device DTYPE * outputData [[buffer(2 )]], \
668
+ constant DTYPE * biasData [[buffer(3 )]], \
669
+ constant array<c10::metal::opmath_t <DTYPE>, 2 > & \
670
+ alpha_beta [[buffer(4 )]], \
671
+ constant array<ulong2, 4 > & strides [[buffer(5 )]], \
672
+ constant uint3 & sizes [[buffer(6 )]], \
673
+ uint2 tid [[thread_position_in_threadgroup]], \
677
674
uint2 group_id [[threadgroup_position_in_grid]])
678
675
679
- INSTANTIATE_NAIVE_MM (float );
680
- INSTANTIATE_NAIVE_MM (half);
681
- INSTANTIATE_NAIVE_MM (bfloat);
676
+ INSTANTIATE_MM_OPS (float );
677
+ INSTANTIATE_MM_OPS (half);
678
+ INSTANTIATE_MM_OPS (bfloat);
682
679
683
680
// Integral MM
684
- INSTANTIATE_NAIVE_MM (short );
685
- INSTANTIATE_NAIVE_MM (int );
686
- INSTANTIATE_NAIVE_MM (long );
687
- INSTANTIATE_NAIVE_MM (char );
688
- INSTANTIATE_NAIVE_MM (uchar);
689
- INSTANTIATE_NAIVE_BMM (short );
690
- INSTANTIATE_NAIVE_BMM (int );
691
- INSTANTIATE_NAIVE_BMM (long );
692
- INSTANTIATE_NAIVE_BMM (char );
693
- INSTANTIATE_NAIVE_BMM (uchar);
694
- INSTANTIATE_NAIVE_ADDMM (short );
695
- INSTANTIATE_NAIVE_ADDMM (int );
696
- INSTANTIATE_NAIVE_ADDMM (long );
697
- INSTANTIATE_NAIVE_ADDMM (char );
698
- INSTANTIATE_NAIVE_ADDMM (uchar);
681
+ INSTANTIATE_MM_OPS (long );
682
+ INSTANTIATE_MM_OPS (int );
683
+ INSTANTIATE_MM_OPS (short );
684
+ INSTANTIATE_MM_OPS (char );
685
+ INSTANTIATE_MM_OPS (uchar);
0 commit comments