-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[NVPTX] Limit a sparsity selector in sparse MMA intrinsics. #154984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Limit a sparsity selector in sparse MMA intrinsics. #154984
Conversation
sparsity selector in sparse MMA intrinsics.
@llvm/pr-subscribers-backend-nvptx Author: Kirill Vedernikov (kvederni) Changes: This PR fixes the NVPTX tests in LLVM by adding further restrictions on the sparsity selector in sparse MMA intrinsics. Full diff: https://github.com/llvm/llvm-project/pull/154984.diff 2 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index cd7a0bc9c4b48..130fa27e4f870 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2161,6 +2161,7 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
// The range [0;num_threads) is for the sparsity selector that indicates the threads
// which contribute metadata.
int num_threads = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "f16")),
!and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
@@ -2175,7 +2176,11 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
!eq(A.ptx_elt_type, "e3m2"),
!eq(A.ptx_elt_type, "e2m3"),
!eq(A.ptx_elt_type, "e2m1"))),
- 1, 4));
+ 1,
+ !if(!and(!eq(A.geom, "m16n8k128"),
+ !or(!eq(A.ptx_elt_type, "s4"),
+ !eq(A.ptx_elt_type, "u4"))),
+ 1, 4)));
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
Range<ArgIndex<pos>, 0, num_threads>];
}
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index f4f166c4018d0..6d73bce46da7c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1135,6 +1135,7 @@ def sp_selector_gen(op):
# (geom, type) -> allowed selector range
range_01 = {
("m16n8k32", "bf16"),
+ ("m16n8k32", "f16"),
("m16n8k16", "tf32"),
("m16n8k32", "u8"),
("m16n8k32", "s8"),
@@ -1154,6 +1155,11 @@ def sp_selector_gen(op):
"e2m1",
]:
return range(1)
+ if op.a.geom == "m16n8k128" and op.a.mma_type.ptx_type in [
+ "u4",
+ "s4",
+ ]:
+ return range(1)
return range(4)
|
@llvm/pr-subscribers-llvm-ir Author: Kirill Vedernikov (kvederni) Changes: This PR fixes the NVPTX tests in LLVM by adding further restrictions on the sparsity selector in sparse MMA intrinsics. Full diff: https://github.com/llvm/llvm-project/pull/154984.diff 2 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index cd7a0bc9c4b48..130fa27e4f870 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2161,6 +2161,7 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
// The range [0;num_threads) is for the sparsity selector that indicates the threads
// which contribute metadata.
int num_threads = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "f16")),
!and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
@@ -2175,7 +2176,11 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
!eq(A.ptx_elt_type, "e3m2"),
!eq(A.ptx_elt_type, "e2m3"),
!eq(A.ptx_elt_type, "e2m1"))),
- 1, 4));
+ 1,
+ !if(!and(!eq(A.geom, "m16n8k128"),
+ !or(!eq(A.ptx_elt_type, "s4"),
+ !eq(A.ptx_elt_type, "u4"))),
+ 1, 4)));
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
Range<ArgIndex<pos>, 0, num_threads>];
}
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index f4f166c4018d0..6d73bce46da7c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1135,6 +1135,7 @@ def sp_selector_gen(op):
# (geom, type) -> allowed selector range
range_01 = {
("m16n8k32", "bf16"),
+ ("m16n8k32", "f16"),
("m16n8k16", "tf32"),
("m16n8k32", "u8"),
("m16n8k32", "s8"),
@@ -1154,6 +1155,11 @@ def sp_selector_gen(op):
"e2m1",
]:
return range(1)
+ if op.a.geom == "m16n8k128" and op.a.mma_type.ptx_type in [
+ "u4",
+ "s4",
+ ]:
+ return range(1)
return range(4)
|
This PR fixes the NVPTX tests in LLVM by adding further restrictions on the sparsity selector in sparse MMA intrinsics.
The previous PR, already merged into llvm:main, is PR150950; its merge commit is d9c6b7b.