Propagate the optional alignment of vector.maskedload / vector.maskedstore
to the scalar memref ops produced by the masked load/store emulation.

Both converters now read getAlignment() from the masked op and pass
alignment.value_or(0), together with nontemporal = false, to
memref::LoadOp::create / memref::StoreOp::create. Two lit tests check that
`{alignment = 8 : i64}` appears on the emitted memref.load / memref.store.

NOTE(review): this patch text was recovered from a copy whose line breaks
had been collapsed and whose `<Value>` / `<uint64_t>` template arguments had
been stripped. Indentation, the position of the blank context line in the
last C++ hunk, and the template arguments are reconstructed; confirm against
the tree, or apply with `git apply --ignore-whitespace` if the context does
not match exactly.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index cb3e8dc67a1ae..c2002d832a1ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -64,6 +64,8 @@ struct VectorMaskedLoadOpConverter final
     Value mask = maskedLoadOp.getMask();
     Value base = maskedLoadOp.getBase();
     Value iValue = maskedLoadOp.getPassThru();
+    bool nontemporal = false;
+    std::optional<uint64_t> alignment = maskedLoadOp.getAlignment();
     auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -74,7 +76,8 @@ struct VectorMaskedLoadOpConverter final
           rewriter, loc, maskBit,
           [&](OpBuilder &builder, Location loc) {
             auto loadedValue =
-                memref::LoadOp::create(builder, loc, base, indices);
+                memref::LoadOp::create(builder, loc, base, indices, nontemporal,
+                                       alignment.value_or(0));
             auto combinedValue =
                 vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
             scf::YieldOp::create(builder, loc, combinedValue.getResult());
@@ -132,6 +135,8 @@ struct VectorMaskedStoreOpConverter final
     Value mask = maskedStoreOp.getMask();
     Value base = maskedStoreOp.getBase();
     Value value = maskedStoreOp.getValueToStore();
+    bool nontemporal = false;
+    std::optional<uint64_t> alignment = maskedStoreOp.getAlignment();
     auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -141,7 +146,8 @@ struct VectorMaskedStoreOpConverter final
       auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
       auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
-      memref::StoreOp::create(rewriter, loc, extractedValue, base, indices);
+      memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
+                              nontemporal, alignment.value_or(0));
 
       rewriter.setInsertionPointAfter(ifOp);
       indices.back() =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 3867f075af8e4..e74eb08339684 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -54,6 +54,22 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL: @vector_maskedload_with_alignment
+// CHECK: memref.load
+// CHECK-SAME: {alignment = 8 : i64}
+// CHECK: memref.load
+// CHECK-SAME: {alignment = 8 : i64}
+func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
 // CHECK-LABEL: @vector_maskedstore
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
@@ -93,3 +109,17 @@ func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
   vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
   return
 }
+
+// CHECK-LABEL: @vector_maskedstore_with_alignment
+// CHECK: memref.store
+// CHECK-SAME: {alignment = 8 : i64}
+// CHECK: memref.store
+// CHECK-SAME: {alignment = 8 : i64}
+func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
+  return
+}