@@ -752,6 +752,19 @@ void replaceTransferOp(ConversionPatternRewriter &rewriter,
752
752
Operation *op, ArrayRef<Value> operands, Value dataPtr,
753
753
Value mask);
754
754
755
+ LogicalResult getLLVMTypeAndAlignment (LLVMTypeConverter &typeConverter,
756
+ Type type, LLVM::LLVMType &llvmType,
757
+ unsigned &align) {
758
+ auto convertedType = typeConverter.convertType (type);
759
+ if (!convertedType)
760
+ return failure ();
761
+
762
+ llvmType = convertedType.template cast <LLVM::LLVMType>();
763
+ auto dataLayout = typeConverter.getDialect ()->getLLVMModule ().getDataLayout ();
764
+ align = dataLayout.getPrefTypeAlignment (llvmType.getUnderlyingType ());
765
+ return success ();
766
+ }
767
+
755
768
template <>
756
769
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
757
770
LLVMTypeConverter &typeConverter,
@@ -764,10 +777,13 @@ void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
764
777
Value fill = rewriter.create <SplatOp>(loc, fillType, xferOp.padding ());
765
778
fill = rewriter.create <LLVM::DialectCastOp>(loc, toLLVMTy (fillType), fill);
766
779
767
- auto vecTy = toLLVMTy (xferOp.getVectorType ()).template cast <LLVM::LLVMType>();
768
- rewriter.replaceOpWithNewOp <LLVM::MaskedLoadOp>(
769
- op, vecTy, dataPtr, mask, ValueRange{fill},
770
- rewriter.getI32IntegerAttr (1 ));
780
+ LLVM::LLVMType vecTy;
781
+ unsigned align;
782
+ if (succeeded (getLLVMTypeAndAlignment (typeConverter, xferOp.getVectorType (),
783
+ vecTy, align)))
784
+ rewriter.replaceOpWithNewOp <LLVM::MaskedLoadOp>(
785
+ op, vecTy, dataPtr, mask, ValueRange{fill},
786
+ rewriter.getI32IntegerAttr (align));
771
787
}
772
788
773
789
template <>
@@ -777,8 +793,14 @@ void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
777
793
ArrayRef<Value> operands, Value dataPtr,
778
794
Value mask) {
779
795
auto adaptor = TransferWriteOpOperandAdaptor (operands);
780
- rewriter.replaceOpWithNewOp <LLVM::MaskedStoreOp>(
781
- op, adaptor.vector (), dataPtr, mask, rewriter.getI32IntegerAttr (1 ));
796
+
797
+ auto xferOp = cast<TransferWriteOp>(op);
798
+ LLVM::LLVMType vecTy;
799
+ unsigned align;
800
+ if (succeeded (getLLVMTypeAndAlignment (typeConverter, xferOp.getVectorType (),
801
+ vecTy, align)))
802
+ rewriter.replaceOpWithNewOp <LLVM::MaskedStoreOp>(
803
+ op, adaptor.vector (), dataPtr, mask, rewriter.getI32IntegerAttr (align));
782
804
}
783
805
784
806
static TransferReadOpOperandAdaptor
0 commit comments