Skip to content

Conversation

nirvedhmeshram
Copy link
Contributor

This PR adds a datalayout propagation pattern to push down extract slice through generic op. It adds a different populate function since there may be conditions where a user doesn't want this pattern but wants the other patterns e.g. extract slice is used as a special op when it comes to tiling.

@llvmbot
Copy link
Member

llvmbot commented Aug 18, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

This PR adds a datalayout propagation pattern to push down extract slice through generic op. It adds a different populate function since there may be conditions where a user doesn't want this pattern but wants the other patterns e.g. extract slice is used as a special op when it comes to tiling.


Patch is 20.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154162.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+269)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+110)
  • (modified) mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a91fcfe..046920f5ccd54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,6 +1914,11 @@ void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation);
 
+/// Patterns to bubble up or down extract slice across other operations.
+void populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..16d6ac23b0208 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,266 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// This struct contains infomation about extract_slice dims.
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  unsigned operandIndex;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    operandIndex = idx;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return std::make_tuple(sliceOperand, operandIndex);
+}
+
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonZeroSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next check if the dims with non zero slice info are used as non
+  // AffineDimExpr and if they are then bail-out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand) {
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the unPacked operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid UnPackOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // check if we can support the propagation of this extractSlice
+  // through the generic op and if so return the dimensions that
+
+  auto maybeNonZeroSliceDimMap =
+      getNonZeroSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+        operandLowPads[idx] = sliceDimInfo.offset;
+        operandHighPads[idx] =
+            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                sliceDimInfo.sliceSize);
+      }
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = OutputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+      outputLowPads[idx] = sliceDimInfo.offset;
+      outputHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+      OutputShape[idx] = sliceDimInfo.outputSize;
+      newSizes[idx] = sliceDimInfo.sliceSize;
+    }
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+
+  } else {
+
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1509,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
               PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..723eecb52351b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK:         %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG1]]
 // CHECK:         return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %arg1
+// CHECK:           tensor.yield %[[POISON]] : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]   
+// CHECK-SAME:    outs(%[[EMPTY]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK:         return %[[EXTRACT]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x5x128xbf16>
+  return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]          
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<128x?x128xbf16>
+  return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
+
+// -----
+
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK:         %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           tensor.yield %[[POISON_F32]] : f32
+// CHECK:         %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK:         %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK:           tensor.yield %[[POISON_BF16]] : bf16
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[PADDED]]
+// CHECK-SAME:    outs(%[[PADDED1]]
+// CHECK:         %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK:         return %[[EXTRACT1]]
+
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..449d28fc528b1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 18, 2025

@llvm/pr-subscribers-mlir

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

This PR adds a datalayout propagation pattern to push down extract slice through generic op. It adds a different populate function since there may be conditions where a user doesn't want this pattern but wants the other patterns e.g. extract slice is used as a special op when it comes to tiling.


Patch is 20.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154162.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+269)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+110)
  • (modified) mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a91fcfe..046920f5ccd54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,6 +1914,11 @@ void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation);
 
+/// Patterns to bubble up or down extract slice across other operations.
+void populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..16d6ac23b0208 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,266 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// This struct contains infomation about extract_slice dims.
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  unsigned operandIndex;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    operandIndex = idx;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return std::make_tuple(sliceOperand, operandIndex);
+}
+
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonZeroSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next check if the dims with non zero slice info are used as non
+  // AffineDimExpr and if they are then bail-out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand) {
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the unPacked operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid UnPackOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // check if we can support the propagation of this extractSlice
+  // through the generic op and if so return the dimensions that
+
+  auto maybeNonZeroSliceDimMap =
+      getNonZeroSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+        operandLowPads[idx] = sliceDimInfo.offset;
+        operandHighPads[idx] =
+            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                sliceDimInfo.sliceSize);
+      }
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = OutputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+      outputLowPads[idx] = sliceDimInfo.offset;
+      outputHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+      OutputShape[idx] = sliceDimInfo.outputSize;
+      newSizes[idx] = sliceDimInfo.sliceSize;
+    }
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+
+  } else {
+
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1509,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
               PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..723eecb52351b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK:         %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG1]]
 // CHECK:         return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %arg1
+// CHECK:           tensor.yield %[[POISON]] : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]   
+// CHECK-SAME:    outs(%[[EMPTY]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK:         return %[[EXTRACT]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x5x128xbf16>
+  return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]          
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<128x?x128xbf16>
+  return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
+
+// -----
+
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK:         %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           tensor.yield %[[POISON_F32]] : f32
+// CHECK:         %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK:         %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK:           tensor.yield %[[POISON_BF16]] : bf16
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[PADDED]]
+// CHECK-SAME:    outs(%[[PADDED1]]
+// CHECK:         %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK:         return %[[EXTRACT1]]
+
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..449d28fc528b1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6...
[truncated]

@nirvedhmeshram nirvedhmeshram force-pushed the push_down_extract branch 2 times, most recently from c2a5eb9 to 7b9d96e Compare August 18, 2025 17:40
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the overall logic makes sense. Just have a few clarifying comments and clean ups.

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one comment about the tests

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
nirvedhmeshram added a commit to nirvedhmeshram/iree that referenced this pull request Aug 26, 2025
nirvedhmeshram added a commit to nirvedhmeshram/iree that referenced this pull request Aug 27, 2025
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
@nirvedhmeshram nirvedhmeshram merged commit de6f750 into llvm:main Aug 27, 2025
9 checks passed
nirvedhmeshram added a commit to nirvedhmeshram/iree that referenced this pull request Aug 27, 2025
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
nirvedhmeshram added a commit to nirvedhmeshram/iree that referenced this pull request Aug 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants