Skip to content

Commit 533cc9a

Browse files
authored
[NVPTX] Limit a sparsity selector in sparse MMA intrinsics. (#154984)
This PR fixes NVPTX tests in LLVM testing by adding more limitations for a sparsity selector in sparse MMA intrinsics. The previous PR that is merged to llvm:main is [PR150950](#150950). The merge to llvm:main is d9c6b7b
1 parent 0f07235 commit 533cc9a

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,7 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
21612161
// The range [0;num_threads) is for the sparsity selector that indicates the threads
21622162
// which contribute metadata.
21632163
int num_threads = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
2164+
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "f16")),
21642165
!and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
21652166
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
21662167
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
@@ -2175,7 +2176,11 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
21752176
!eq(A.ptx_elt_type, "e3m2"),
21762177
!eq(A.ptx_elt_type, "e2m3"),
21772178
!eq(A.ptx_elt_type, "e2m1"))),
2178-
1, 4));
2179+
1,
2180+
!if(!and(!eq(A.geom, "m16n8k128"),
2181+
!or(!eq(A.ptx_elt_type, "s4"),
2182+
!eq(A.ptx_elt_type, "u4"))),
2183+
1, 4)));
21792184
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
21802185
Range<ArgIndex<pos>, 0, num_threads>];
21812186
}

llvm/test/CodeGen/NVPTX/wmma.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,7 @@ def sp_selector_gen(op):
11351135
# (geom, type) -> allowed selector range
11361136
range_01 = {
11371137
("m16n8k32", "bf16"),
1138+
("m16n8k32", "f16"),
11381139
("m16n8k16", "tf32"),
11391140
("m16n8k32", "u8"),
11401141
("m16n8k32", "s8"),
@@ -1154,6 +1155,11 @@ def sp_selector_gen(op):
11541155
"e2m1",
11551156
]:
11561157
return range(1)
1158+
if op.a.geom == "m16n8k128" and op.a.mma_type.ptx_type in [
1159+
"u4",
1160+
"s4",
1161+
]:
1162+
return range(1)
11571163
return range(4)
11581164

11591165

0 commit comments

Comments
 (0)