Skip to content

Commit b2c79c5

Browse files
[mlir][VectorOps] Extend VectorTransfer lowering to n-D memref with minor identity map
Summary: This revision extends the lowering of vector transfers to work with an n-D memref and a 1-D vector, where the permutation map is an identity on the most minor dimensions (1 for now). Differential Revision: https://reviews.llvm.org/D78925
1 parent a486edd commit b2c79c5

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,16 @@ getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
791791
return TransferWriteOpOperandAdaptor(operands);
792792
}
793793

794+
/// Returns true if `map` acts as an identity on its `rank` most minor
/// dimensions, i.e. result `i` equals `dim(numDims - rank + i)` for
/// 0 <= i < rank. Used to detect vector transfer permutation maps that
/// address a contiguous, most-minor slice of an n-D memref.
///
/// NOTE: only the first `rank` results are inspected; for vector transfer
/// ops the map's number of results matches the vector rank, so this is
/// sufficient at the current call sites.
static bool isMinorIdentity(AffineMap map, unsigned rank) {
  // Guard both the result count and the dim count: `map` may have more
  // results than dims, and `getNumDims() - rank` would otherwise wrap
  // around (unsigned underflow) when the map has fewer dims than `rank`.
  if (map.getNumResults() < rank || map.getNumDims() < rank)
    return false;
  unsigned startDim = map.getNumDims() - rank;
  for (unsigned i = 0; i < rank; ++i)
    if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
      return false;
  return true;
}
803+
794804
/// Conversion pattern that converts a 1-D vector transfer read/write op in a
795805
/// sequence of:
796806
/// 1. Bitcast to vector form.
@@ -810,9 +820,12 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
810820
ConversionPatternRewriter &rewriter) const override {
811821
auto xferOp = cast<ConcreteOp>(op);
812822
auto adaptor = getTransferOpAdapter(xferOp, operands);
813-
if (xferOp.getMemRefType().getRank() != 1)
823+
824+
if (xferOp.getVectorType().getRank() > 1 ||
825+
llvm::size(xferOp.indices()) == 0)
814826
return failure();
815-
if (!xferOp.permutation_map().isIdentity())
827+
if (!isMinorIdentity(xferOp.permutation_map(),
828+
xferOp.getVectorType().getRank()))
816829
return failure();
817830

818831
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
@@ -844,17 +857,18 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
844857
loc, toLLVMTy(vectorCmpType), linearIndices);
845858

846859
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
847-
Value offsetIndex = *(xferOp.indices().begin());
848-
offsetIndex = rewriter.create<IndexCastOp>(
849-
loc, vectorCmpType.getElementType(), offsetIndex);
860+
// TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
861+
// `k` dimensions here.
862+
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
863+
Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
864+
offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
850865
Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
851866
Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
852867

853868
// 4. Let dim the memref dimension, compute the vector comparison mask:
854869
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
855-
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0);
856-
dim =
857-
rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(), dim);
870+
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
871+
dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
858872
dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
859873
Value mask =
860874
rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,39 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
828828
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
829829
// CHECK-SAME: {alignment = 1 : i32} :
830830
// CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
831+
832+
// Tests lowering of a transfer_read from a 2-D memref into a 1-D vector where
// the permutation map (d0, d1) -> (d1) is an identity on the most minor
// dimension: only the most-minor index (%base1) and memref dimension should
// feed the generated masked load (see the CHECK lines below).
func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
833+
  %f7 = constant 7.0: f32
834+
  %f = vector.transfer_read %A[%base0, %base1], %f7
835+
      {permutation_map = affine_map<(d0, d1) -> (d1)>} :
836+
    memref<?x?xf32>, vector<17xf32>
837+
  return %f: vector<17xf32>
838+
}
839+
// CHECK-LABEL: func @transfer_read_2d_to_1d
840+
// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: !llvm.i64, %[[BASE_1:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
841+
//
842+
// Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
843+
// CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
844+
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
845+
// Here we check we properly use %BASE_1
846+
// CHECK: %[[offsetVec2:.*]] = llvm.insertelement %[[BASE_1]], %[[offsetVec]][%[[c0]] :
847+
// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
848+
// CHECK: %[[offsetVec3:.*]] = llvm.shufflevector %[[offsetVec2]], %{{.*}} [
849+
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
850+
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
851+
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
852+
//
853+
// Let dim the memref dimension, compute the vector comparison mask:
854+
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
855+
// Here we check we properly use %DIM[1]
856+
// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
857+
// CHECK-SAME: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
858+
// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
859+
// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
860+
// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
861+
// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
862+
// CHECK: %[[dimVec3:.*]] = llvm.shufflevector %[[dimVec2]], %{{.*}} [
863+
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
864+
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
865+
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
866+
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">

0 commit comments

Comments (0)