Skip to content

Commit a23f190

Browse files
whchungftynse
authored andcommitted
[mlir][vector] set alignment when lowering transfer_read and transfer_write.
When emitting masked load / store, set alignment from data layout. Differential Revision: https://reviews.llvm.org/D79246
1 parent a31f4c5 commit a23f190

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,19 @@ void replaceTransferOp(ConversionPatternRewriter &rewriter,
752752
Operation *op, ArrayRef<Value> operands, Value dataPtr,
753753
Value mask);
754754

755+
LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
756+
Type type, LLVM::LLVMType &llvmType,
757+
unsigned &align) {
758+
auto convertedType = typeConverter.convertType(type);
759+
if (!convertedType)
760+
return failure();
761+
762+
llvmType = convertedType.template cast<LLVM::LLVMType>();
763+
auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
764+
align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
765+
return success();
766+
}
767+
755768
template <>
756769
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
757770
LLVMTypeConverter &typeConverter,
@@ -764,10 +777,13 @@ void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
764777
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
765778
fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
766779

767-
auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
768-
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
769-
op, vecTy, dataPtr, mask, ValueRange{fill},
770-
rewriter.getI32IntegerAttr(1));
780+
LLVM::LLVMType vecTy;
781+
unsigned align;
782+
if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
783+
vecTy, align)))
784+
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
785+
op, vecTy, dataPtr, mask, ValueRange{fill},
786+
rewriter.getI32IntegerAttr(align));
771787
}
772788

773789
template <>
@@ -777,8 +793,14 @@ void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
777793
ArrayRef<Value> operands, Value dataPtr,
778794
Value mask) {
779795
auto adaptor = TransferWriteOpOperandAdaptor(operands);
780-
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
781-
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
796+
797+
auto xferOp = cast<TransferWriteOp>(op);
798+
LLVM::LLVMType vecTy;
799+
unsigned align;
800+
if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
801+
vecTy, align)))
802+
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
803+
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
782804
}
783805

784806
static TransferReadOpOperandAdaptor

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
818818
// CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> :
819819
// CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>">
820820
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
821-
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
821+
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} :
822822
// CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
823823

824824
//
@@ -850,7 +850,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
850850
//
851851
// 5. Rewrite as a masked write.
852852
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
853-
// CHECK-SAME: {alignment = 1 : i32} :
853+
// CHECK-SAME: {alignment = 128 : i32} :
854854
// CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
855855

856856
func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {

0 commit comments

Comments
 (0)