From ee12c53a4c89e1c12b6534846af0c5c9c98d8328 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 27 Aug 2025 09:58:51 +0200 Subject: [PATCH] [mlir][amx] Prevent crash on invalid tile element type Fixes AMX tile type parser to prevent crashes on invalid element type. --- mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 4 +++- mlir/test/Dialect/AMX/invalid.mlir | 24 ++++++++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 6f3110cdf00ef..68990ef0dc0c3 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) { if (parser.parseGreater()) return nullptr; - return TileType::get(shape, elementType); + return TileType::getChecked( + [&] { return parser.emitError(parser.getNameLoc()); }, shape, + elementType); } void amx::TileType::print(AsmPrinter &os) const { diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir index a401770240d0a..5de9b3f82a868 100644 --- a/mlir/test/Dialect/AMX/invalid.mlir +++ b/mlir/test/Dialect/AMX/invalid.mlir @@ -16,6 +16,22 @@ func.func @tile_col_width() { // ----- +func.func @tile_element_type() { + // expected-error@+1 {{failed to verify 'elementType'}} + %0 = amx.tile_zero : !amx.tile<8x8xi16> + return +} + +// ----- + +func.func @tile_rank() { + // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}} + %0 = amx.tile_zero : !amx.tile<32xi8> + return +} + +// ----- + func.func @tile_col_4_byte_multiple() { // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}} %0 = amx.tile_zero : !amx.tile<16x5xi8> @@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() { // ----- -func.func @load_base_tilesize(%arg0: memref) { +func.func @load_base_tile_size(%arg0: memref) { %0 = arith.constant 0 : index // expected-error@+1 {{'amx.tile_load' op bad column width: 68}} %1 = amx.tile_load %arg0[%0, %0] : memref into !amx.tile<16x17xf32> @@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref) { // ----- -func.func @store_base_tilesize(%arg0: memref, %arg1: !amx.tile<16x17xf32>) { +func.func @store_base_tile_size(%arg0: memref, %arg1: !amx.tile<16x17xf32>) { %0 = arith.constant 0 : index // expected-error@+1 {{'amx.tile_store' op bad column width: 68}} amx.tile_store %arg0[%0, %0], %arg1 : memref, !amx.tile<16x17xf32> @@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref, %arg1: !amx.tile<16x17xf3 // ----- -func.func @load_base_indexsize(%arg0: memref) { +func.func @load_base_index_size(%arg0: memref) { %0 = arith.constant 0 : index // expected-error@+1 {{'amx.tile_load' op requires 2 indices}} %1 = amx.tile_load %arg0[%0] : memref into !amx.tile<16x16xf32> @@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref) { // ----- -func.func @store_base_indexsize(%arg0: memref, %arg1: !amx.tile<16x16xf32>) { +func.func @store_base_index_size(%arg0: memref, %arg1: !amx.tile<16x16xf32>) { %0 = arith.constant 0 : index // expected-error@+1 {{'amx.tile_store' op requires 2 indices}} amx.tile_store %arg0[%0], %arg1 : memref, !amx.tile<16x16xf32>