Skip to content

Conversation

amd-eochoalo
Copy link
Contributor

@amd-eochoalo amd-eochoalo commented Aug 27, 2025

Propagate alignment from vector.maskedload and vector.maskedstore to memref.load and memref.store during VectorEmulateMaskedLoadStore pass.

Copy link

github-actions bot commented Aug 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@amd-eochoalo amd-eochoalo force-pushed the eochoa/2025-08-27/vector-emulate-masked-load-alignment branch from d2dc9bd to 52d6a7b Compare August 27, 2025 16:27
@amd-eochoalo amd-eochoalo self-assigned this Aug 27, 2025
@amd-eochoalo amd-eochoalo marked this pull request as ready for review August 27, 2025 17:13
@llvmbot
Copy link
Member

llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

Improvements that may follow in another NFC pull request: simplify the optional-alignment handling (e.g. `alignment.value_or(0)`).


Full diff: https://github.com/llvm/llvm-project/pull/155648.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp (+10-3)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+30)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index cb3e8dc67a1ae..a00a1352ce40d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -64,6 +64,9 @@ struct VectorMaskedLoadOpConverter final
     Value mask = maskedLoadOp.getMask();
     Value base = maskedLoadOp.getBase();
     Value iValue = maskedLoadOp.getPassThru();
+    bool nontemporal = false;
+    auto alignment = maskedLoadOp.getAlignment();
+    uint64_t align = alignment.has_value() ? alignment.value() : 0;
     auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -73,8 +76,8 @@ struct VectorMaskedLoadOpConverter final
       auto ifOp = scf::IfOp::create(
           rewriter, loc, maskBit,
           [&](OpBuilder &builder, Location loc) {
-            auto loadedValue =
-                memref::LoadOp::create(builder, loc, base, indices);
+            auto loadedValue = memref::LoadOp::create(
+                builder, loc, base, indices, nontemporal, align);
             auto combinedValue =
                 vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
             scf::YieldOp::create(builder, loc, combinedValue.getResult());
@@ -132,6 +135,9 @@ struct VectorMaskedStoreOpConverter final
     Value mask = maskedStoreOp.getMask();
     Value base = maskedStoreOp.getBase();
     Value value = maskedStoreOp.getValueToStore();
+    bool nontemporal = false;
+    auto alignment = maskedStoreOp.getAlignment();
+    uint64_t align = alignment.has_value() ? alignment.value() : 0;
     auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -141,7 +147,8 @@ struct VectorMaskedStoreOpConverter final
       auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
       auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
-      memref::StoreOp::create(rewriter, loc, extractedValue, base, indices);
+      memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
+                              nontemporal, align);
 
       rewriter.setInsertionPointAfter(ifOp);
       indices.back() =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 3867f075af8e4..e74eb08339684 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -54,6 +54,22 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL:  @vector_maskedload_with_alignment
+//       CHECK:       memref.load
+//       CHECK-SAME:  {alignment = 8 : i64}
+//       CHECK:       memref.load
+//       CHECK-SAME:  {alignment = 8 : i64}
+func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
 // CHECK-LABEL:  @vector_maskedstore
 //  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
 //   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
@@ -93,3 +109,17 @@ func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
   vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
   return
 }
+
+// CHECK-LABEL:  @vector_maskedstore_with_alignment
+//       CHECK:       memref.store
+//       CHECK-SAME:  {alignment = 8 : i64}
+//       CHECK:       memref.store
+//       CHECK-SAME:  {alignment = 8 : i64}
+func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
+  return
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants