-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][amx] Prevent crash on invalid tile element type #155587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Fixes AMX tile type parser to prevent crashes on invalid element type.
@llvm/pr-subscribers-mlir-amx Author: Adam Siemieniuk (adam-smnk) ChangesFixes AMX tile type parser to prevent crashes on invalid element type. Full diff: https://github.com/llvm/llvm-project/pull/155587.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110cdf00ef..68990ef0dc0c3 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseGreater())
return nullptr;
- return TileType::get(shape, elementType);
+ return TileType::getChecked(
+ [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+ elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index a401770240d0a..5de9b3f82a868 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
// -----
+func.func @tile_element_type() {
+ // expected-error@+1 {{failed to verify 'elementType'}}
+ %0 = amx.tile_zero : !amx.tile<8x8xi16>
+ return
+}
+
+// -----
+
+func.func @tile_rank() {
+ // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
+ %0 = amx.tile_zero : !amx.tile<32xi8>
+ return
+}
+
+// -----
+
func.func @tile_col_4_byte_multiple() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
%0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
// -----
-func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
// -----
-func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
|
@llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesFixes AMX tile type parser to prevent crashes on invalid element type. Full diff: https://github.com/llvm/llvm-project/pull/155587.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110cdf00ef..68990ef0dc0c3 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseGreater())
return nullptr;
- return TileType::get(shape, elementType);
+ return TileType::getChecked(
+ [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+ elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index a401770240d0a..5de9b3f82a868 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
// -----
+func.func @tile_element_type() {
+ // expected-error@+1 {{failed to verify 'elementType'}}
+ %0 = amx.tile_zero : !amx.tile<8x8xi16>
+ return
+}
+
+// -----
+
+func.func @tile_rank() {
+ // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
+ %0 = amx.tile_zero : !amx.tile<32xi8>
+ return
+}
+
+// -----
+
func.func @tile_col_4_byte_multiple() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
%0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
// -----
-func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
// -----
-func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
|
Fixes AMX tile type parser to prevent crashes on invalid element type.