@@ -68,6 +68,37 @@ kernel void matmul(
68
68
}
69
69
}
70
70
71
+ template <typename T>
72
+ kernel void addmm (
73
+ constant T* mat1Data [[buffer(0 )]],
74
+ constant T* mat2Data [[buffer(1 )]],
75
+ device T* outputData [[buffer(2 )]],
76
+ constant T* biasData [[buffer(3 )]],
77
+ constant array<c10::metal::opmath_t<T>, 2>& alpha_beta [[buffer(4 )]],
78
+ constant array<ulong2, 4>& strides [[buffer(5 )]],
79
+ constant uint3& sizes [[buffer(6 )]],
80
+ uint2 tid [[thread_position_in_threadgroup]],
81
+ uint2 thread_id [[thread_position_in_grid]]) {
82
+ threadgroup T A_tile[TILE_DIM][TILE_DIM];
83
+ threadgroup T B_tile[TILE_DIM][TILE_DIM];
84
+
85
+ auto sum = matmul_inner<T>(
86
+ mat1Data,
87
+ mat2Data,
88
+ reinterpret_cast <constant array<ulong2, 3 >&>(strides),
89
+ sizes,
90
+ A_tile,
91
+ B_tile,
92
+ tid,
93
+ thread_id);
94
+ if (thread_id.y < sizes.x && thread_id.x < sizes.z ) {
95
+ auto bias =
96
+ biasData[thread_id.y * strides[3 ].x + thread_id.x * strides[3 ].y ];
97
+ outputData[thread_id.y * strides[2 ].x + thread_id.x * strides[2 ].y ] =
98
+ static_cast <T>(alpha_beta[0 ] * sum + alpha_beta[1 ] * bias);
99
+ }
100
+ }
101
+
71
102
template <typename T>
72
103
kernel void naive_bmm (
73
104
constant T* mat1Data [[buffer(0 )]],
@@ -613,38 +644,42 @@ kernel void applyPivots(
613
644
}
614
645
}
615
646
616
- #define INSTANTIATE_NAIVE_MM (DTYPE ) \
617
- template [[host_name(" matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
618
- constant DTYPE * mat1Data [[buffer(0 )]], \
619
- constant DTYPE * mat2Data [[buffer(1 )]], \
620
- device DTYPE * outputData [[buffer(2 )]], \
621
- constant array<ulong2, 3 > & strides [[buffer(3 )]], \
622
- constant uint3 & sizes [[buffer(4 )]], \
623
- uint2 tid [[thread_position_in_threadgroup]], \
624
- uint2 group_id [[threadgroup_position_in_grid]])
625
-
626
- #define INSTANTIATE_NAIVE_BMM (DTYPE ) \
647
+ #define INSTANTIATE_MM_OPS (DTYPE ) \
648
+ template [[host_name(" matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
649
+ constant DTYPE * mat1Data [[buffer(0 )]], \
650
+ constant DTYPE * mat2Data [[buffer(1 )]], \
651
+ device DTYPE * outputData [[buffer(2 )]], \
652
+ constant array<ulong2, 3 > & strides [[buffer(3 )]], \
653
+ constant uint3 & sizes [[buffer(4 )]], \
654
+ uint2 tid [[thread_position_in_threadgroup]], \
655
+ uint2 group_id [[threadgroup_position_in_grid]]); \
627
656
template [[host_name(" naive_bmm_" #DTYPE)]] kernel void naive_bmm<DTYPE>( \
628
657
constant DTYPE * mat1Data [[buffer(0 )]], \
629
658
constant DTYPE * mat2Data [[buffer(1 )]], \
630
659
device DTYPE * outputData [[buffer(2 )]], \
631
660
constant array<ulong, 9 > & strides [[buffer(3 )]], \
632
661
constant uint4 & sizes [[buffer(4 )]], \
633
662
uint3 tid [[thread_position_in_threadgroup]], \
634
- uint3 group_id [[threadgroup_position_in_grid]])
663
+ uint3 group_id [[threadgroup_position_in_grid]]); \
664
+ template [[host_name(" addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
665
+ constant DTYPE * mat1Data [[buffer(0 )]], \
666
+ constant DTYPE * mat2Data [[buffer(1 )]], \
667
+ device DTYPE * outputData [[buffer(2 )]], \
668
+ constant DTYPE * biasData [[buffer(3 )]], \
669
+ constant array<c10::metal::opmath_t <DTYPE>, 2 > & \
670
+ alpha_beta [[buffer(4 )]], \
671
+ constant array<ulong2, 4 > & strides [[buffer(5 )]], \
672
+ constant uint3 & sizes [[buffer(6 )]], \
673
+ uint2 tid [[thread_position_in_threadgroup]], \
674
+ uint2 group_id [[threadgroup_position_in_grid]])
635
675
636
- INSTANTIATE_NAIVE_MM (float );
637
- INSTANTIATE_NAIVE_MM (half);
638
- INSTANTIATE_NAIVE_MM (bfloat);
676
+ INSTANTIATE_MM_OPS (float );
677
+ INSTANTIATE_MM_OPS (half);
678
+ INSTANTIATE_MM_OPS (bfloat);
639
679
640
680
// Integral MM
641
- INSTANTIATE_NAIVE_MM (short );
642
- INSTANTIATE_NAIVE_MM (int );
643
- INSTANTIATE_NAIVE_MM (long );
644
- INSTANTIATE_NAIVE_MM (char );
645
- INSTANTIATE_NAIVE_MM (uchar);
646
- INSTANTIATE_NAIVE_BMM (short );
647
- INSTANTIATE_NAIVE_BMM (int );
648
- INSTANTIATE_NAIVE_BMM (long );
649
- INSTANTIATE_NAIVE_BMM (char );
650
- INSTANTIATE_NAIVE_BMM (uchar);
681
+ INSTANTIATE_MM_OPS (long );
682
+ INSTANTIATE_MM_OPS (int );
683
+ INSTANTIATE_MM_OPS (short );
684
+ INSTANTIATE_MM_OPS (char );
685
+ INSTANTIATE_MM_OPS (uchar);
0 commit comments