diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index a5f1b28152b9b..d2c6ba557b9bb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,12 +260,37 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
+/// Returns a copy of `shape` without unit dims.
+static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> reducedShape;
+  llvm::copy_if(shape, std::back_inserter(reducedShape),
+                [](int64_t dimSize) { return dimSize != 1; });
+  return reducedShape;
+}
+
+/// Converts OpFoldResults to int64_t shape without unit dims.
+static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
+  SmallVector<int64_t> reducedShape;
+  for (const auto size : mixedSizes) {
+    if (llvm::dyn_cast_if_present<Value>(size)) {
+      reducedShape.push_back(ShapedType::kDynamic);
+      continue;
+    }
+
+    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
+    if (value == 1)
+      continue;
+    reducedShape.push_back(value.getSExtValue());
+  }
+  return reducedShape;
+}
+
 /// Drops unit dimensions from the input MemRefType.
-static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
-                               ArrayRef<int64_t> sizes,
-                               ArrayRef<int64_t> strides) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(
-      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
+static MemRefType dropUnitDims(MemRefType inputType,
+                               ArrayRef<OpFoldResult> offsets,
+                               ArrayRef<OpFoldResult> sizes,
+                               ArrayRef<OpFoldResult> strides) {
+  auto targetShape = getReducedShape(sizes);
   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
       targetShape, inputType, offsets, sizes, strides);
   return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
@@ -277,17 +302,18 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
   MemRefType inputType = cast<MemRefType>(input.getType());
-  assert(inputType.hasStaticShape());
-  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
-  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
-  ArrayRef<int64_t> subViewSizes = inputType.getShape();
-  MemRefType resultType =
-      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
+  SmallVector<OpFoldResult> offsets(inputType.getRank(),
+                                    rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
+  SmallVector<OpFoldResult> strides(inputType.getRank(),
+                                    rewriter.getIndexAttr(1));
+  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
+
   if (canonicalizeStridedLayout(resultType) ==
       canonicalizeStridedLayout(inputType))
     return input;
-  return rewriter.create<memref::SubViewOp>(
-      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
+  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
+                                            sizes, strides);
 }
 
 /// Returns the number of dims that aren't unit dims.
@@ -295,12 +321,44 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
 }
 
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> reducedShape;
-  llvm::copy_if(shape, std::back_inserter(reducedShape),
-                [](int64_t dimSize) { return dimSize != 1; });
-  return reducedShape;
+/// Trims non-scalable one dimensions from `oldType` and returns the result
+/// type.
+static VectorType trimNonScalableUnitDims(VectorType oldType) {
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableDims;
+  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
+    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
+      continue;
+    newShape.push_back(dimSize);
+    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
+  }
+  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+}
+
+// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
+static FailureOr<Value>
+createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
+                                  vector::CreateMaskOp op) {
+  auto type = op.getType();
+  auto reducedType = trimNonScalableUnitDims(type);
+  if (reducedType.getRank() == type.getRank())
+    return failure();
+
+  SmallVector<Value> reducedOperands;
+  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
+           type.getShape(), type.getScalableDims(), op.getOperands())) {
+    if (dim == 1 && !dimIsScalable) {
+      // If the mask for the unit dim is not a constant of 1, do nothing.
+      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
+      if (!constant || (constant.value() != 1))
+        return failure();
+      continue;
+    }
+    reducedOperands.push_back(operand);
+  }
+  return rewriter
+      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
+      .getResult();
 }
 
 namespace {
@@ -320,9 +378,7 @@ class TransferReadDropUnitDimsPattern
     Value source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor types.
-    if (!sourceType || !sourceType.hasStaticShape())
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (!sourceType)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
@@ -335,23 +391,38 @@ class TransferReadDropUnitDimsPattern
       return failure();
     // Check if the reduced vector shape matches the reduced source shape.
     // Otherwise, this case is not supported yet.
-    int vectorReducedRank = getReducedRank(vectorType.getShape());
-    if (reducedRank != vectorReducedRank)
+    auto reducedVectorType = trimNonScalableUnitDims(vectorType);
+    if (reducedRank != reducedVectorType.getRank())
       return failure();
     if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
           return getConstantIntValue(v) != static_cast<int64_t>(0);
         }))
       return failure();
+
+    Value maskOp = transferReadOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
+                            "currently supported");
+      FailureOr<Value> rankReducedCreateMask =
+          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      if (failed(rankReducedCreateMask))
+        return failure();
+      maskOp = *rankReducedCreateMask;
+    }
+
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
-    auto reducedVectorType = VectorType::get(
-        getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
-        loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
+        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
+        transferReadOp.getPadding(), maskOp,
+        rewriter.getBoolArrayAttr(inBounds));
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
     rewriter.replaceOp(transferReadOp, shapeCast);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 2852e301888cc..735915d435653 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -82,6 +82,118 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
 //       CHECK:   %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
 //       CHECK:   vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
 
+func.func @transfer_read_dynamic_rank_reducing(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0], %pad {in_bounds = [true, true]} :
+    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+  return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @transfer_read_dynamic_rank_reducing
+//  CHECK-SAME:   %[[ARG:.+]]: memref<?x1xi8
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, strided<[?, ?], offset: ?>> to memref<?xi8, strided<[?], offset: ?>>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, strided<[?], offset: ?>>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing_1(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask_dim0 : index) -> vector<[16]x1xi8> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %pad = arith.constant 0 : i8
+  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+  return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_1
+//  CHECK-SAME:   %[[ARG:.+]]: memref<?x1xi8
+//  CHECK-SAME:   %[[MASK_DIM0:.+]]: index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[PAD:.+]] = arith.constant 0 : i8
+//       CHECK:   %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+//       CHECK:   %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, strided<[?, ?], offset: ?>> to memref<?xi8, strided<[?], offset: ?>>
+//       CHECK:   vector.transfer_read
%[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, strided<[?], offset: ?>>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing_2(
+      %arg : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>,
+      %mask_dim1 : index, %mask_dim4 : index) -> vector<1x[1]x3x1x[16]x1xi8> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %pad = arith.constant 0 : i8
+  %mask = vector.create_mask %c1, %mask_dim1, %c2, %c1, %mask_dim4, %c1 : vector<1x[1]x3x1x[16]x1xi1>
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true, true, true, true]} :
+    memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>, vector<1x[1]x3x1x[16]x1xi8>
+  return %v : vector<1x[1]x3x1x[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_2
+//  CHECK-SAME:   %[[ARG:.+]]: memref<1x?x3x1x?x1xi8
+//  CHECK-SAME:   %[[MASK_DIM1:.+]]: index, %[[MASK_DIM4:.+]]: index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[PAD:.+]] = arith.constant 0 : i8
+//       CHECK:   %[[MASK:.+]] = vector.create_mask %[[MASK_DIM1]], %[[C2]], %[[MASK_DIM4]] : vector<[1]x3x[16]xi1>
+//       CHECK:   %[[DIM1:.+]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+//       CHECK:   %[[DIM4:.+]] = memref.dim %[[ARG]], %[[C4]] : memref<1x?x3x1x?x1xi8, strided<[?, ?, ?, ?, ?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
+
+/// Only masks operands of vector.create_mask are currently supported.
+func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+  return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_1
+//  CHECK-SAME:   %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT: memref.subview
+
+/// Unit dim mask must be constant of 1.
+func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_2(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask_dim0 : index, %mask_dim1 : index) -> vector<[16]x1xi8> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %pad = arith.constant 0 : i8
+  %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[16]x1xi1>
+  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+  return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_2
+//  CHECK-SAME:   %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT: memref.subview
+
+/// Unit dim must be non-scalable.
+func.func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask_dim0 : index) -> vector<[16]x[1]xi8> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %pad = arith.constant 0 : i8
+  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x[1]xi1>
+  %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+    memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x[1]xi8>
+  return %v : vector<[16]x[1]xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim
+//  CHECK-SAME:   %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT: memref.subview
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {