diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f47e356d6fe14..0f96442bc3756 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5575,6 +5575,34 @@ LogicalResult ShapeCastOp::verify() { return success(); } +namespace { + +/// Return true if `transpose` does not permute a pair of non-unit dims (a +/// scalable dim counts as non-unit). By `order preserving` we mean that the +/// flattened versions of the input and output vectors are (numerically) +/// identical. In other words `transpose` is effectively a shape cast. +bool isOrderPreserving(TransposeOp transpose) { + ArrayRef permutation = transpose.getPermutation(); + VectorType sourceType = transpose.getSourceVectorType(); + ArrayRef inShape = sourceType.getShape(); + ArrayRef inDimIsScalable = sourceType.getScalableDims(); + auto isNonScalableUnitDim = [&](int64_t dim) { + return inShape[dim] == 1 && !inDimIsScalable[dim]; + }; + int64_t current = 0; + for (auto p : permutation) { + if (!isNonScalableUnitDim(p)) { + if (p < current) { + return false; + } + current = p; + } + } + return true; +} + +} // namespace + OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { VectorType resultType = getType(); @@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { if (getSource().getType() == resultType) return getSource(); - // Y = shape_cast(shape_cast(X))) - // -> X, if X and Y have same type - // -> shape_cast(X) otherwise. 
- if (auto otherOp = getSource().getDefiningOp()) { - VectorType srcType = otherOp.getSource().getType(); - if (resultType == srcType) - return otherOp.getSource(); - setOperand(otherOp.getSource()); + // shape_cast(shape_cast(x)) -> shape_cast(x) + if (auto precedingShapeCast = getSource().getDefiningOp()) { + setOperand(precedingShapeCast.getSource()); return getResult(); } + // shape_cast(transpose(x)) -> shape_cast(x) + if (auto transpose = getSource().getDefiningOp()) { + // This folder does + // shape_cast(transpose) -> shape_cast + // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does + // shape_cast -> shape_cast(transpose) + // i.e. the complete opposite. When paired, these 2 patterns can cause + // infinite cycles in pattern rewriting. + // ConvertIllegalShapeCastOpsToTransposes only matches on scalable + // vectors, so by disabling this folder for scalable vectors the + // cycle is avoided. + // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is + // still needed. If it's not, then we can fold here. + if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) { + setOperand(transpose.getVector()); + return getResult(); + } + return {}; + } + // Y = shape_cast(broadcast(X)) // -> X, if X and Y have same type if (auto bcastOp = getSource().getDefiningOp()) { @@ -5619,7 +5662,7 @@ namespace { /// Helper function that computes a new vector type based on the input vector /// type by removing the trailing one dims: /// -/// vector<4x1x1xi1> --> vector<4x1> +/// vector<4x1x1xi1> --> vector<4x1xi1> /// static VectorType trimTrailingOneDims(VectorType oldType) { ArrayRef oldShape = oldType.getShape(); @@ -6086,6 +6129,32 @@ class FoldTransposeCreateMask final : public OpRewritePattern { } }; +/// Folds an order preserving transpose of a shape_cast into a new shape_cast. 
+class FoldTransposeShapeCast final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto shapeCastOp = + transposeOp.getVector().getDefiningOp(); + if (!shapeCastOp) + return failure(); + if (!isOrderPreserving(transposeOp)) + return failure(); + + VectorType resultType = transposeOp.getType(); + + // We don't need to check isValidShapeCast at this point, because it is + // guaranteed that merging the transpose into the shape_cast is a valid + // shape_cast, because an order preserving transpose only moves unit dims. + + rewriter.replaceOpWithNewOp(transposeOp, resultType, + shapeCastOp.getSource()); + return success(); + } +}; + /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is /// 'order preserving', where 'order preserving' means the flattened /// inputs and outputs of the transpose have identical (numerical) values. @@ -6184,8 +6253,8 @@ class FoldTransposeBroadcast : public OpRewritePattern { void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index e0ec9c66d3a48..99f0850000a16 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -8,6 +8,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) { %0 = vector.create_mask %c3, %c2 : vector<4x3xi1> return %0 : vector<4x3xi1> } + // ----- // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask @@ -3061,7 +3062,6 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>) return %1 : vector<4x8xf32> } - // ----- // CHECK-LABEL: @insert_scalar_poison_idx diff --git 
a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir index e97e147459de2..91ee0d335ecca 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir @@ -137,3 +137,113 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector< return %1 : vector<3x3x3xi8> } + +// ----- + +// Test of the shape_cast(transpose) -> shape_cast fold in ShapeCastOp::fold. +// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows: +// 1 -> 0 +// 2 -> 4 +// Because 0 < 4, this permutation is order preserving and effectively a shape_cast. +// CHECK-LABEL: @transpose_shape_cast +// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> { +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8> +func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> { + %0 = vector.transpose %arg, [1, 0, 3, 4, 2] + : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> + %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8> + return %1 : vector<4x4xi8> +} + +// ----- + +// Test of the shape_cast(transpose) -> shape_cast fold in ShapeCastOp::fold. +// In this test, the mapping of non-unit dimensions (1 and 2) is as follows: +// 1 -> 2 +// 2 -> 1 +// As this is not increasing (2 > 1), this transpose is not order +// preserving and cannot be treated as a shape_cast. 
+// CHECK-LABEL: @negative_transpose_shape_cast +// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> { +// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]] +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]] +// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8> +func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> { + %0 = vector.transpose %arg, [0, 2, 1, 3] + : vector<1x4x4x1xi8> to vector<1x4x4x1xi8> + %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8> + return %1 : vector<4x4xi8> +} + +// ----- + +// Test of the shape_cast(transpose) -> shape_cast fold in ShapeCastOp::fold. +// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for +// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes. +// CHECK-LABEL: @negative_transpose_shape_cast_scalable +// CHECK: vector.transpose +// CHECK: vector.shape_cast +func.func @negative_transpose_shape_cast_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8> + %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8> + return %1 : vector<[4]xi8> +} + +// ----- + +// Test of FoldTransposeShapeCast (transpose(shape_cast) -> shape_cast). +// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable +// vectors. +// CHECK-LABEL: @shape_cast_transpose_scalable +// CHECK: vector.shape_cast +// CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8> +func.func @shape_cast_transpose_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> { + %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8> + %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8> + return %1 : vector<[4]x1xi8> +} + +// ----- + +// Test of FoldTransposeShapeCast (transpose(shape_cast) -> shape_cast). +// A transpose that is 'order preserving' can be treated like a shape_cast. 
+// CHECK-LABEL: @shape_cast_transpose +// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> { +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8> +func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> { + %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8> + %1 = vector.transpose %0, [0, 2, 1] + : vector<6x1x1xi8> to vector<6x1x1xi8> + return %1 : vector<6x1x1xi8> +} + +// ----- + +// Test of FoldTransposeShapeCast (transpose(shape_cast) -> shape_cast). +// Scalable dimensions should be treated as non-unit dimensions. +// CHECK-LABEL: @shape_cast_transpose_scalable_unit +// CHECK: vector.shape_cast +// CHECK: vector.transpose +func.func @shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> { + %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8> + %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8> + return %1 : vector<4x[1]xi8> +} + +// ----- + +// Test of FoldTransposeShapeCast not applying: the transpose permutes non-unit dims. +// CHECK-LABEL: @negative_shape_cast_transpose +// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> { +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]] +// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8> +func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> { + %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8> + %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8> + return %1 : vector<2x3xi8> +}