-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[Linalg] Add pattern to push down extract slice through linalg generic op #154162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Nirvedh Meshram (nirvedhmeshram) Changes: This PR adds a data layout propagation pattern to push down extract slice through a generic op. It adds a separate populate function since there may be situations where a user doesn't want this pattern but wants the other patterns, e.g. when extract slice is used as a special op for tiling. Patch is 20.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154162.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a91fcfe..046920f5ccd54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,6 +1914,11 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
+/// Populates `patterns` with patterns that propagate `tensor.extract_slice`
+/// across other operations. At present this registers a single pattern that
+/// pushes an extract_slice down through a `linalg.generic` consumer.
+/// `controlPackUnPackPropagation` is the control function handed to each
+/// pattern (the name is shared with the pack/unpack populate function).
+void populateExtractSlicePropagationPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation);
+
/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..16d6ac23b0208 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,266 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
ControlPropagationFn controlFn;
};
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+ OpFoldResult offset; // Offset of the slice along this dimension.
+ OpFoldResult sliceSize; // Size of the slice along this dimension.
+ OpFoldResult outputSize; // Size of the slice source along this dimension.
+};
+
+/// Returns the first DPS input operand of `genericOp` that is defined by a
+/// `tensor.extract_slice`, together with its position among the DPS inputs.
+/// Returns failure if no input is produced by an extract_slice.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (!operand->get().getDefiningOp<tensor::ExtractSliceOp>())
+      continue;
+    // Return straight from the loop so that no index variable is ever left
+    // uninitialized on the "not found" path.
+    return std::make_tuple(operand, static_cast<unsigned>(idx));
+  }
+  return failure();
+}
+
+/// Gathers the slicing information of `producerSliceOp`, keyed by the
+/// iteration dimension of `genericOp` that indexes each sliced tensor
+/// dimension of `sliceOperand`.
+///
+/// A tensor dimension counts as sliced unless its offset is the constant 0 and
+/// its size equals the corresponding source dimension. The returned tuple
+/// holds the map from iteration dimension to its SliceDimInfo and a flag that
+/// is true when at least one of those dimensions is a reduction iterator.
+///
+/// Fails if a sliced tensor dimension is not indexed by a plain AffineDimExpr,
+/// or if any other operand references a sliced iteration dimension from
+/// within a compound (non-AffineDimExpr) indexing expression.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonZeroSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> sliceOffsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sliceSizes = producerSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> sourceShape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t dimSize) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), dimSize);
+      });
+
+  // Step 1: record every tensor dimension that the slice actually narrows.
+  llvm::DenseMap<int64_t, SliceDimInfo> slicedDims;
+  bool slicesReductionDim = false;
+  AffineMap sliceOperandMap = genericOp.getMatchingIndexingMap(sliceOperand);
+  for (auto [resultIdx, resultExpr] :
+       llvm::enumerate(sliceOperandMap.getResults())) {
+    bool coversWholeDim =
+        isConstantIntValue(sliceOffsets[resultIdx], 0) &&
+        isEqualConstantIntOrValue(sliceSizes[resultIdx],
+                                  sourceShape[resultIdx]);
+    if (coversWholeDim)
+      continue;
+    // A narrowed dimension must map to exactly one iteration dimension.
+    if (!isa<AffineDimExpr>(resultExpr))
+      return failure();
+    int64_t dimPos = cast<AffineDimExpr>(resultExpr).getPosition();
+    slicedDims[dimPos] = SliceDimInfo{
+        sliceOffsets[resultIdx], sliceSizes[resultIdx], sourceShape[resultIdx]};
+    if (iteratorTypes[dimPos] == utils::IteratorType::reduction)
+      slicesReductionDim = true;
+  }
+
+  // Step 2: bail out if any other operand uses one of the recorded iteration
+  // dimensions inside an expression that is not a plain AffineDimExpr.
+  auto mentionsSlicedDim = [&](AffineExpr compoundExpr) {
+    return compoundExpr
+        .walk([&](AffineExpr subExpr) {
+          if (auto dimExpr = dyn_cast<AffineDimExpr>(subExpr)) {
+            if (slicedDims.count(dimExpr.getPosition()) != 0)
+              return WalkResult::interrupt();
+          }
+          return WalkResult::advance();
+        })
+        .wasInterrupted();
+  };
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand)
+      continue;
+    AffineMap operandMap = genericOp.getMatchingIndexingMap(&operand);
+    for (AffineExpr resultExpr : operandMap.getResults()) {
+      if (isa<AffineDimExpr>(resultExpr))
+        continue;
+      if (mentionsSlicedDim(resultExpr))
+        return failure();
+    }
+  }
+  return std::make_tuple(slicedDims, slicesReductionDim);
+}
+
+/// Pushes a `tensor.extract_slice` that produces an input of `genericOp` below
+/// the generic op.
+///
+/// Preconditions (failure is returned otherwise): the generic has exactly one
+/// result and no gather semantics, one of its inputs is defined by an
+/// extract_slice, `controlFn` accepts that operand, the slice is not
+/// rank-reducing, all slice strides are the constant 1, and
+/// getNonZeroSliceDimInfo succeeds.
+///
+/// Rewrite: the slice operand is replaced by the slice source, unless a sliced
+/// dimension is a reduction iterator, in which case the slice result itself is
+/// padded. Every other input is padded with `ub.poison`; on each dimension
+/// indexed by a sliced iteration dimension the low pad is the slice offset and
+/// the high pad is `sourceSize - offset - sliceSize`. The init is replaced by
+/// a `tensor.empty` of the enlarged shape when the outs are unused, and by a
+/// poison-padded init otherwise. A new generic with the same maps, iterators
+/// and body is created on these operands, and an extract_slice of its result
+/// recovers the original result.
+///
+/// Returns the new generic op and the extracted value. The original op is not
+/// replaced here; that is left to the caller.
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Find the input that is produced by a tensor.extract_slice, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned sliceOperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+  // Rank-reducing slices are not supported.
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  // Only unit-stride slices are supported.
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  // Check whether this extract_slice can be propagated through the generic op
+  // and, if so, collect the iteration dimensions it slices.
+  auto maybeNonZeroSliceDimMap =
+      getNonZeroSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+  if (failed(maybeNonZeroSliceDimMap))
+    return failure();
+
+  // Bind by reference; the map lives in `maybeNonZeroSliceDimMap` and does not
+  // need to be copied.
+  const llvm::DenseMap<int64_t, SliceDimInfo> &nonZeroSliceDimMap =
+      std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+  // High padding for a sliced dimension: sourceSize - offset - sliceSize.
+  auto computeHighPad = [&](const SliceDimInfo &info) {
+    return sub(sub(info.outputSize, info.offset), info.sliceSize);
+  };
+  // Returns the slice info of the iteration dimension that `expr` names, or
+  // nullptr when `expr` is not a plain dimension or that dimension is not
+  // sliced.
+  auto lookupSliceDim = [&](AffineExpr expr) -> const SliceDimInfo * {
+    if (!isa<AffineDimExpr>(expr))
+      return nullptr;
+    auto it = nonZeroSliceDimMap.find(cast<AffineDimExpr>(expr).getPosition());
+    if (it == nonZeroSliceDimMap.end())
+      return nullptr;
+    return &it->second;
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    // Without a sliced reduction dimension the slice operand can simply be
+    // replaced by the (unsliced) source.
+    if (idx == sliceOperandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [resultIdx, expr] : llvm::enumerate(indexingMap.getResults())) {
+      const SliceDimInfo *sliceDimInfo = lookupSliceDim(expr);
+      if (!sliceDimInfo)
+        continue;
+      operandLowPads[resultIdx] = sliceDimInfo->offset;
+      operandHighPads[resultIdx] = computeHighPad(*sliceDimInfo);
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  // NOTE(review): sizes are taken from the static init shape as index
+  // attributes; presumably a dynamic init dimension that is not sliced is not
+  // expected here - confirm.
+  SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = outputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [resultIdx, expr] :
+       llvm::enumerate(outputIndexingMap.getResults())) {
+    const SliceDimInfo *sliceDimInfo = lookupSliceDim(expr);
+    if (!sliceDimInfo)
+      continue;
+    outputLowPads[resultIdx] = sliceDimInfo->offset;
+    outputHighPads[resultIdx] = computeHighPad(*sliceDimInfo);
+    // The new generic produces the full source extent; the trailing
+    // extract_slice narrows it back to the slice size.
+    outputShape[resultIdx] = sliceDimInfo->outputSize;
+    newSizes[resultIdx] = sliceDimInfo->sliceSize;
+  }
+
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    // The init values are never read, so an empty tensor of the enlarged
+    // shape is sufficient.
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+/// Rewrite pattern on linalg.generic that delegates to
+/// pushDownExtractSliceOpThroughGenericOp and, on success, replaces the
+/// matched op with the value extracted from the newly created generic. The
+/// stored control function is forwarded unchanged to that helper.
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<std::tuple<GenericOp, Value>> rewritten =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (succeeded(rewritten)) {
+      // Element 1 of the tuple is the extract_slice result that stands in
+      // for the original generic's result.
+      Value replacement = std::get<1>(*rewritten);
+      rewriter.replaceOp(genericOp, replacement);
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1509,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
+
+/// Registers the extract_slice push-down pattern for linalg.generic in
+/// `patterns`, forwarding `controlPackUnPackPropagation` as its control
+/// function.
+void mlir::linalg::populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      context, controlPackUnPackPropagation);
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..723eecb52351b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG1]]
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+// The slice narrows only tensor dim 2 of %arg0, which the generic indexes with
+// the parallel dimension d0: the slice source feeds the new generic directly,
+// the other input is padded with poison, and the result is sliced afterwards.
+module {
+ func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x5x128xbf16>
+ return %0 : tensor<?x5x128xbf16>
+ }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON:.+]] = ub.poison : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG1]]
+// CHECK: tensor.yield %[[POISON]] : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK: return %[[EXTRACT]]
+
+// -----
+
+// Negative test: the slice has a non-zero offset on tensor dim 1, which the
+// generic indexes with `d2 + d3` rather than a plain dimension, so the
+// extract_slice must stay above the generic.
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x5x128xbf16>
+ return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+// Negative test: the slice narrows tensor dim 1 of the second input, which is
+// indexed by d2; the first input uses d2 inside the compound expression
+// `d2 + d3`, so the extract_slice must stay above the generic.
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<128x?x128xbf16>
+ return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+// The slice narrows both the parallel dimension d0 and the reduction
+// dimension d1, and the body reads %out. The CHECK lines expect the
+// extract_slice to be kept and padded with poison, the init to be padded as
+// well, and the generic result to be sliced afterwards.
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK: %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: tensor.yield %[[POISON_F32]] : f32
+// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK: tensor.yield %[[POISON_BF16]] : bf16
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PADDED]]
+// CHECK-SAME: outs(%[[PADDED1]]
+// CHECK: %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK: return %[[EXTRACT1]]
+
+
+// -----
+
+// Negative test: the extract_slice is rank-reducing (3-D source, 2-D result),
+// so it must stay above the generic.
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..449d28fc528b1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Nirvedh Meshram (nirvedhmeshram) Changes: This PR adds a data layout propagation pattern to push down extract slice through a generic op. It adds a separate populate function since there may be situations where a user doesn't want this pattern but wants the other patterns, e.g. when extract slice is used as a special op for tiling. Patch is 20.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154162.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a91fcfe..046920f5ccd54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,6 +1914,11 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
+/// Populates `patterns` with patterns that propagate `tensor.extract_slice`
+/// across other operations. At present this registers a single pattern that
+/// pushes an extract_slice down through a `linalg.generic` consumer.
+/// `controlPackUnPackPropagation` is the control function handed to each
+/// pattern (the name is shared with the pack/unpack populate function).
+void populateExtractSlicePropagationPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation);
+
/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..16d6ac23b0208 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,266 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
ControlPropagationFn controlFn;
};
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+ OpFoldResult offset; // Offset of the slice along this dimension.
+ OpFoldResult sliceSize; // Size of the slice along this dimension.
+ OpFoldResult outputSize; // Size of the slice source along this dimension.
+};
+
+/// Returns the first DPS input operand of `genericOp` that is defined by a
+/// `tensor.extract_slice`, together with its position among the DPS inputs.
+/// Returns failure if no input is produced by an extract_slice.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (!operand->get().getDefiningOp<tensor::ExtractSliceOp>())
+      continue;
+    // Return straight from the loop so that no index variable is ever left
+    // uninitialized on the "not found" path.
+    return std::make_tuple(operand, static_cast<unsigned>(idx));
+  }
+  return failure();
+}
+
+/// Gathers the slicing information of `producerSliceOp`, keyed by the
+/// iteration dimension of `genericOp` that indexes each sliced tensor
+/// dimension of `sliceOperand`.
+///
+/// A tensor dimension counts as sliced unless its offset is the constant 0 and
+/// its size equals the corresponding source dimension. The returned tuple
+/// holds the map from iteration dimension to its SliceDimInfo and a flag that
+/// is true when at least one of those dimensions is a reduction iterator.
+///
+/// Fails if a sliced tensor dimension is not indexed by a plain AffineDimExpr,
+/// or if any other operand references a sliced iteration dimension from
+/// within a compound (non-AffineDimExpr) indexing expression.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonZeroSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> sliceOffsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sliceSizes = producerSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> sourceShape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t dimSize) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), dimSize);
+      });
+
+  // Step 1: record every tensor dimension that the slice actually narrows.
+  llvm::DenseMap<int64_t, SliceDimInfo> slicedDims;
+  bool slicesReductionDim = false;
+  AffineMap sliceOperandMap = genericOp.getMatchingIndexingMap(sliceOperand);
+  for (auto [resultIdx, resultExpr] :
+       llvm::enumerate(sliceOperandMap.getResults())) {
+    bool coversWholeDim =
+        isConstantIntValue(sliceOffsets[resultIdx], 0) &&
+        isEqualConstantIntOrValue(sliceSizes[resultIdx],
+                                  sourceShape[resultIdx]);
+    if (coversWholeDim)
+      continue;
+    // A narrowed dimension must map to exactly one iteration dimension.
+    if (!isa<AffineDimExpr>(resultExpr))
+      return failure();
+    int64_t dimPos = cast<AffineDimExpr>(resultExpr).getPosition();
+    slicedDims[dimPos] = SliceDimInfo{
+        sliceOffsets[resultIdx], sliceSizes[resultIdx], sourceShape[resultIdx]};
+    if (iteratorTypes[dimPos] == utils::IteratorType::reduction)
+      slicesReductionDim = true;
+  }
+
+  // Step 2: bail out if any other operand uses one of the recorded iteration
+  // dimensions inside an expression that is not a plain AffineDimExpr.
+  auto mentionsSlicedDim = [&](AffineExpr compoundExpr) {
+    return compoundExpr
+        .walk([&](AffineExpr subExpr) {
+          if (auto dimExpr = dyn_cast<AffineDimExpr>(subExpr)) {
+            if (slicedDims.count(dimExpr.getPosition()) != 0)
+              return WalkResult::interrupt();
+          }
+          return WalkResult::advance();
+        })
+        .wasInterrupted();
+  };
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand)
+      continue;
+    AffineMap operandMap = genericOp.getMatchingIndexingMap(&operand);
+    for (AffineExpr resultExpr : operandMap.getResults()) {
+      if (isa<AffineDimExpr>(resultExpr))
+        continue;
+      if (mentionsSlicedDim(resultExpr))
+        return failure();
+    }
+  }
+  return std::make_tuple(slicedDims, slicesReductionDim);
+}
+
+/// Pushes a `tensor.extract_slice` that produces an input of `genericOp` below
+/// the generic op.
+///
+/// Preconditions (failure is returned otherwise): the generic has exactly one
+/// result and no gather semantics, one of its inputs is defined by an
+/// extract_slice, `controlFn` accepts that operand, the slice is not
+/// rank-reducing, all slice strides are the constant 1, and
+/// getNonZeroSliceDimInfo succeeds.
+///
+/// Rewrite: the slice operand is replaced by the slice source, unless a sliced
+/// dimension is a reduction iterator, in which case the slice result itself is
+/// padded. Every other input is padded with `ub.poison`; on each dimension
+/// indexed by a sliced iteration dimension the low pad is the slice offset and
+/// the high pad is `sourceSize - offset - sliceSize`. The init is replaced by
+/// a `tensor.empty` of the enlarged shape when the outs are unused, and by a
+/// poison-padded init otherwise. A new generic with the same maps, iterators
+/// and body is created on these operands, and an extract_slice of its result
+/// recovers the original result.
+///
+/// Returns the new generic op and the extracted value. The original op is not
+/// replaced here; that is left to the caller.
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Find the input that is produced by a tensor.extract_slice, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned sliceOperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+  // Rank-reducing slices are not supported.
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  // Only unit-stride slices are supported.
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  // Check whether this extract_slice can be propagated through the generic op
+  // and, if so, collect the iteration dimensions it slices.
+  auto maybeNonZeroSliceDimMap =
+      getNonZeroSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+  if (failed(maybeNonZeroSliceDimMap))
+    return failure();
+
+  // Bind by reference; the map lives in `maybeNonZeroSliceDimMap` and does not
+  // need to be copied.
+  const llvm::DenseMap<int64_t, SliceDimInfo> &nonZeroSliceDimMap =
+      std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+  // High padding for a sliced dimension: sourceSize - offset - sliceSize.
+  auto computeHighPad = [&](const SliceDimInfo &info) {
+    return sub(sub(info.outputSize, info.offset), info.sliceSize);
+  };
+  // Returns the slice info of the iteration dimension that `expr` names, or
+  // nullptr when `expr` is not a plain dimension or that dimension is not
+  // sliced.
+  auto lookupSliceDim = [&](AffineExpr expr) -> const SliceDimInfo * {
+    if (!isa<AffineDimExpr>(expr))
+      return nullptr;
+    auto it = nonZeroSliceDimMap.find(cast<AffineDimExpr>(expr).getPosition());
+    if (it == nonZeroSliceDimMap.end())
+      return nullptr;
+    return &it->second;
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    // Without a sliced reduction dimension the slice operand can simply be
+    // replaced by the (unsliced) source.
+    if (idx == sliceOperandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [resultIdx, expr] : llvm::enumerate(indexingMap.getResults())) {
+      const SliceDimInfo *sliceDimInfo = lookupSliceDim(expr);
+      if (!sliceDimInfo)
+        continue;
+      operandLowPads[resultIdx] = sliceDimInfo->offset;
+      operandHighPads[resultIdx] = computeHighPad(*sliceDimInfo);
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  // NOTE(review): sizes are taken from the static init shape as index
+  // attributes; presumably a dynamic init dimension that is not sliced is not
+  // expected here - confirm.
+  SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = outputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [resultIdx, expr] :
+       llvm::enumerate(outputIndexingMap.getResults())) {
+    const SliceDimInfo *sliceDimInfo = lookupSliceDim(expr);
+    if (!sliceDimInfo)
+      continue;
+    outputLowPads[resultIdx] = sliceDimInfo->offset;
+    outputHighPads[resultIdx] = computeHighPad(*sliceDimInfo);
+    // The new generic produces the full source extent; the trailing
+    // extract_slice narrows it back to the slice size.
+    outputShape[resultIdx] = sliceDimInfo->outputSize;
+    newSizes[resultIdx] = sliceDimInfo->sliceSize;
+  }
+
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    // The init values are never read, so an empty tensor of the enlarged
+    // shape is sufficient.
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+/// Rewrite pattern on linalg.generic that delegates to
+/// pushDownExtractSliceOpThroughGenericOp and, on success, replaces the
+/// matched op with the value extracted from the newly created generic. The
+/// stored control function is forwarded unchanged to that helper.
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<std::tuple<GenericOp, Value>> rewritten =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (succeeded(rewritten)) {
+      // Element 1 of the tuple is the extract_slice result that stands in
+      // for the original generic's result.
+      Value replacement = std::get<1>(*rewritten);
+      rewriter.replaceOp(genericOp, replacement);
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1509,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
+
+/// Adds `PushDownExtractSliceOpThroughGenericOp` to `patterns`, constructing
+/// it with the pattern set's context and the supplied control function.
+void mlir::linalg::populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      ctx, controlPackUnPackPropagation);
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..723eecb52351b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG1]]
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+// Positive test: %arg0 is sliced on its last dim (offset %arg3, size %arg3)
+// before feeding the generic. The CHECK lines expect the generic to consume
+// %arg0 directly together with a poison-padded %arg1, to write into a
+// tensor.empty of the un-sliced shape, and to be followed by an extract_slice
+// of its result.
+// All SSA values are referenced through FileCheck captures so the test does
+// not depend on the names/numbers the printer happens to assign.
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON:.+]] = ub.poison : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG1]]
+// CHECK: tensor.yield %[[POISON]] : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK: return %[[EXTRACT]]
+
+// -----
+
+// Negative test: the slice has a dynamic offset (%arg3) on dim 1 of %arg0, and
+// the generic indexes that dim with `d2 + d3` rather than a single loop dim.
+// The CHECK lines require the generic's result to be returned directly, i.e.
+// no extract_slice is created after it.
+// NOTE(review): going by the test name, the non-dim expression is presumably
+// what blocks the push-down -- confirm against the pattern implementation.
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x5x128xbf16>
+ return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+// Negative test: here the slice is taken on dim 1 of the second input (%arg1),
+// which that operand's map indexes with the plain loop dim d2. The same loop
+// dim also appears inside `d2 + d3` in the first input's map.
+// The CHECK lines require the generic's result to be returned directly, i.e.
+// no extract_slice is created after it.
+// NOTE(review): presumably the `d2 + d3` use in the other operand is what
+// blocks the push-down -- confirm against the pattern implementation.
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<128x?x128xbf16>
+ return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+// Positive test where the generic's body reads its `outs` operand (%out feeds
+// arith.addf) and the input is sliced on both the parallel dim d0 and the
+// reduction dim d1.
+// The CHECK lines expect, in order: the original extract_slice of %arg0 to
+// still be present and then padded with an f32 poison value; the init %arg1
+// to be padded with a bf16 poison value (low = %arg2, high = an affine.apply
+// of %arg2); the generic to run on the two padded tensors; and finally an
+// extract_slice of the generic's result at offset %arg2 with size %arg2.
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK: %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: tensor.yield %[[POISON_F32]] : f32
+// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK: tensor.yield %[[POISON_BF16]] : bf16
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PADDED]]
+// CHECK-SAME: outs(%[[PADDED1]]
+// CHECK: %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK: return %[[EXTRACT1]]
+
+
+// -----
+
+// Negative test: the extract_slice is rank-reducing (a 3-D source with a
+// unit-size leading dim yields a 2-D tensor<?x?xf32>). The CHECK lines
+// require the generic's result to be returned directly, i.e. no
+// extract_slice is created after it.
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..449d28fc528b1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6...
[truncated]
|
efef821
to
0063270
Compare
c2a5eb9
to
7b9d96e
Compare
7b9d96e
to
4eebe21
Compare
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the overall logic makes sense. Just have a few clarifying comments and clean ups.
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one comment about the tests
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
This PR adds a data layout propagation pattern that pushes an extract slice down through a generic op. The pattern is registered through a separate populate function, since there may be cases where a user wants the other propagation patterns but not this one, e.g., when extract slice is treated as a special op for tiling.