Skip to content

Conversation

adam-smnk
Copy link
Contributor

Fixes AMX tile type parser to prevent crashes on invalid element type.

Fixes AMX tile type parser to prevent crashes on invalid element type.
@llvmbot
Copy link
Member

llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-mlir-amx

Author: Adam Siemieniuk (adam-smnk)

Changes

Fixes AMX tile type parser to prevent crashes on invalid element type.


Full diff: https://github.com/llvm/llvm-project/pull/155587.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (+3-1)
  • (modified) mlir/test/Dialect/AMX/invalid.mlir (+20-4)
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110cdf00ef..68990ef0dc0c3 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
   if (parser.parseGreater())
     return nullptr;
 
-  return TileType::get(shape, elementType);
+  return TileType::getChecked(
+      [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+      elementType);
 }
 
 void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index a401770240d0a..5de9b3f82a868 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
 
 // -----
 
+func.func @tile_element_type() {
+  // expected-error@+1 {{failed to verify 'elementType'}}
+  %0 = amx.tile_zero : !amx.tile<8x8xi16>
+  return
+}
+
+// -----
+
+func.func @tile_rank() {
+  // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
+  %0 = amx.tile_zero : !amx.tile<32xi8>
+  return
+}
+
+// -----
+
 func.func @tile_col_4_byte_multiple() {
   // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
   %0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
 
 // -----
 
-func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
   %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
   amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
 
 // -----
 
-func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
   %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
   amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>

@llvmbot
Copy link
Member

llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Fixes AMX tile type parser to prevent crashes on invalid element type.


Full diff: https://github.com/llvm/llvm-project/pull/155587.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (+3-1)
  • (modified) mlir/test/Dialect/AMX/invalid.mlir (+20-4)
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110cdf00ef..68990ef0dc0c3 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
   if (parser.parseGreater())
     return nullptr;
 
-  return TileType::get(shape, elementType);
+  return TileType::getChecked(
+      [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+      elementType);
 }
 
 void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index a401770240d0a..5de9b3f82a868 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
 
 // -----
 
+func.func @tile_element_type() {
+  // expected-error@+1 {{failed to verify 'elementType'}}
+  %0 = amx.tile_zero : !amx.tile<8x8xi16>
+  return
+}
+
+// -----
+
+func.func @tile_rank() {
+  // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
+  %0 = amx.tile_zero : !amx.tile<32xi8>
+  return
+}
+
+// -----
+
 func.func @tile_col_4_byte_multiple() {
   // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
   %0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
 
 // -----
 
-func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
   %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
   amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
 
 // -----
 
-func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
   %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
   amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>

@adam-smnk adam-smnk merged commit a4f67f3 into llvm:main Aug 27, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants