Skip to content

Commit de6f750

Browse files
[Linalg] Add pattern to push down extract slice through linalg generic op (#154162)
This PR adds a data layout propagation pattern to push down extract slice through a generic op. It adds a separate populate function since there may be situations where a user doesn't want this pattern but does want the other patterns, e.g. when extract slice is used as a special op during tiling. --------- Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent 3addfd0 commit de6f750

File tree

4 files changed

+395
-0
lines changed

4 files changed

+395
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
19181918
RewritePatternSet &patterns,
19191919
const ControlPropagationFn &controlPackUnPackPropagation);
19201920

1921+
/// Patterns to sink extract slice across other operations.
1922+
void populateExtractSliceSinkingPatterns(
1923+
RewritePatternSet &patterns,
1924+
const ControlPropagationFn &controlPackUnPackPropagation);
1925+
19211926
/// Pattern to remove dead operands and results of `linalg.generic` operations.
19221927
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
19231928
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
910
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1011
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1112
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1213
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/Dialect/UB/IR/UBOps.h"
1315
#include "mlir/Dialect/Utils/IndexingUtils.h"
1416
#include "mlir/IR/Dominance.h"
1517
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
12361238
ControlPropagationFn controlFn;
12371239
};
12381240

1241+
// This struct contains information about extract_slice dims.
1242+
struct SliceDimInfo {
1243+
OpFoldResult offset;
1244+
OpFoldResult sliceSize;
1245+
OpFoldResult outputSize;
1246+
};
1247+
1248+
/// Return the first input extract slice operand, if present, for the current
1249+
/// generic op.
1250+
static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
1251+
OpOperand *sliceOperand = nullptr;
1252+
for (auto operand : genericOp.getDpsInputOperands()) {
1253+
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
1254+
if (!extractOp)
1255+
continue;
1256+
sliceOperand = operand;
1257+
break;
1258+
}
1259+
if (!sliceOperand) {
1260+
return failure();
1261+
}
1262+
return sliceOperand;
1263+
}
1264+
1265+
// Return a map of dims that have partial slices on them so that other operands
1266+
// can use this information. Also return a bool indicating whether a reduction dim
1267+
// has a non-full slice, as that can be used to fold the original extract slice.
1268+
static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
1269+
getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
1270+
tensor::ExtractSliceOp producerSliceOp =
1271+
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1272+
assert(producerSliceOp && "expect a valid ExtractSliceOp");
1273+
llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
1274+
SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
1275+
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
1276+
1277+
SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
1278+
genericOp.getContext(), producerSliceOp.getSourceType().getShape());
1279+
1280+
for (auto [idx, expr] : llvm::enumerate(
1281+
genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
1282+
// If we have a full slice in a dimension then we don't need to add it to
1283+
// the partial slice map.
1284+
if (isConstantIntValue(offsets[idx], 0) &&
1285+
isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
1286+
continue;
1287+
}
1288+
// We only support partial slices of AffineDimExprs, so bail out if that's not
1289+
// the case.
1290+
if (!isa<AffineDimExpr>(expr)) {
1291+
return failure();
1292+
}
1293+
SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1294+
int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
1295+
partialSliceDimMap[dimPos] = sliceDimInfo;
1296+
}
1297+
// Next, check if the dims with partial slice info are used in non-AffineDimExpr
1298+
// results of other operands' indexing maps and, if they are, bail out.
1299+
for (OpOperand &operand : genericOp->getOpOperands()) {
1300+
if (operand == *sliceOperand) {
1301+
continue;
1302+
}
1303+
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
1304+
if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
1305+
if (isa<AffineDimExpr>(expr)) {
1306+
return false;
1307+
}
1308+
WalkResult status = expr.walk([&](AffineExpr expr) {
1309+
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1310+
if (partialSliceDimMap.contains(dimExpr.getPosition())) {
1311+
return WalkResult::interrupt();
1312+
}
1313+
}
1314+
return WalkResult::advance();
1315+
});
1316+
if (status.wasInterrupted()) {
1317+
return true;
1318+
}
1319+
return false;
1320+
})) {
1321+
return failure();
1322+
}
1323+
}
1324+
return partialSliceDimMap;
1325+
}
1326+
1327+
static FailureOr<std::tuple<GenericOp, Value>>
1328+
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
1329+
GenericOp genericOp,
1330+
ControlPropagationFn controlFn) {
1331+
if (genericOp.getNumResults() != 1)
1332+
return rewriter.notifyMatchFailure(
1333+
genericOp, "propagation through multi-result generic is unsupported.");
1334+
if (hasGatherSemantics(genericOp))
1335+
return rewriter.notifyMatchFailure(
1336+
genericOp,
1337+
"propagation through generic with gather semantics is unsupported.");
1338+
// Collect the sliced operand, if present.
1339+
auto maybeSliceOperand = getSliceOperand(genericOp);
1340+
if (failed(maybeSliceOperand))
1341+
return failure();
1342+
OpOperand *sliceOperand = *maybeSliceOperand;
1343+
unsigned OperandIndex = sliceOperand->getOperandNumber();
1344+
1345+
if (!controlFn(sliceOperand))
1346+
return failure();
1347+
1348+
tensor::ExtractSliceOp producerSliceOp =
1349+
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1350+
assert(producerSliceOp && "expect a valid ExtractSliceOp");
1351+
1352+
if (producerSliceOp.getSource().getType().getRank() !=
1353+
producerSliceOp.getResult().getType().getRank()) {
1354+
return rewriter.notifyMatchFailure(
1355+
genericOp,
1356+
"propagation of rank-reducing extract slice is unsupported.");
1357+
}
1358+
1359+
SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
1360+
if (!areAllConstantIntValue(strides, 1))
1361+
return rewriter.notifyMatchFailure(
1362+
genericOp, "propagation of strided extract slice is unsupported.");
1363+
1364+
// Check if we can support the propagation of this extractSlice
1365+
// through the generic op and, if so, compute the partial slice info per dimension.
1366+
1367+
auto maybePartialSliceDimMap =
1368+
getPartialSliceDimInfo(genericOp, sliceOperand);
1369+
1370+
if (failed(maybePartialSliceDimMap)) {
1371+
return failure();
1372+
}
1373+
1374+
auto partialSliceDimMap = *maybePartialSliceDimMap;
1375+
1376+
SmallVector<utils::IteratorType> iterators =
1377+
genericOp.getIteratorTypesArray();
1378+
bool hasPartialReductionDimSlice =
1379+
llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
1380+
int64_t sliceDim = slice.first;
1381+
return iterators[sliceDim] == utils::IteratorType::reduction;
1382+
});
1383+
1384+
// Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
1385+
Location loc = genericOp->getLoc();
1386+
AffineExpr dim0, dim1;
1387+
bindDims(rewriter.getContext(), dim0, dim1);
1388+
auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
1389+
auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1390+
return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
1391+
{v1, v2});
1392+
};
1393+
1394+
MLIRContext *ctx = genericOp.getContext();
1395+
SmallVector<Value> paddedInputs;
1396+
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
1397+
if (idx == OperandIndex && !hasPartialReductionDimSlice) {
1398+
paddedInputs.push_back(producerSliceOp.getSource());
1399+
continue;
1400+
}
1401+
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
1402+
SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
1403+
getAsIndexOpFoldResult(ctx, 0));
1404+
SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
1405+
getAsIndexOpFoldResult(ctx, 0));
1406+
for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
1407+
if (!isa<AffineDimExpr>(expr)) {
1408+
continue;
1409+
}
1410+
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1411+
if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1412+
continue;
1413+
}
1414+
SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1415+
operandLowPads[idx] = sliceDimInfo.offset;
1416+
operandHighPads[idx] =
1417+
sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1418+
sliceDimInfo.sliceSize);
1419+
}
1420+
auto paddingValue = ub::PoisonOp::create(
1421+
rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
1422+
auto paddedOperand = tensor::PadOp::create(
1423+
rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
1424+
paddingValue, /*nofold=*/false);
1425+
paddedInputs.push_back(paddedOperand);
1426+
}
1427+
AffineMap outputIndexingMap =
1428+
genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
1429+
1430+
auto outputShapeType =
1431+
llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
1432+
SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
1433+
outputShapeType.getShape(),
1434+
[&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
1435+
SmallVector<OpFoldResult> newSizes = OutputShape;
1436+
SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
1437+
getAsIndexOpFoldResult(ctx, 0));
1438+
SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
1439+
getAsIndexOpFoldResult(ctx, 0));
1440+
SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
1441+
getAsIndexOpFoldResult(ctx, 1));
1442+
for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
1443+
if (!isa<AffineDimExpr>(expr)) {
1444+
continue;
1445+
}
1446+
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1447+
if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1448+
continue;
1449+
}
1450+
SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1451+
outputLowPads[idx] = sliceDimInfo.offset;
1452+
outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1453+
sliceDimInfo.sliceSize);
1454+
OutputShape[idx] = sliceDimInfo.outputSize;
1455+
newSizes[idx] = sliceDimInfo.sliceSize;
1456+
}
1457+
Value newPadOutput;
1458+
auto outputElType =
1459+
getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
1460+
if (isGenericOutsNotUsed(genericOp)) {
1461+
newPadOutput =
1462+
tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
1463+
} else {
1464+
auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
1465+
newPadOutput = tensor::PadOp::create(
1466+
rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
1467+
outputHighPads, paddingValue, /*nofold=*/false);
1468+
}
1469+
1470+
auto newGenericOp = linalg::GenericOp::create(
1471+
rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
1472+
genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
1473+
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
1474+
rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
1475+
newGenericOp.getRegion().begin());
1476+
1477+
auto extractOp = tensor::ExtractSliceOp::create(
1478+
rewriter, loc,
1479+
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
1480+
outputLowPads, newSizes, newStrides);
1481+
Value extractRes = extractOp.getResult();
1482+
1483+
return std::make_tuple(newGenericOp, extractRes);
1484+
}
1485+
1486+
class PushDownExtractSliceOpThroughGenericOp final
1487+
: public OpRewritePattern<GenericOp> {
1488+
public:
1489+
PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
1490+
ControlPropagationFn fun)
1491+
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1492+
1493+
LogicalResult matchAndRewrite(GenericOp genericOp,
1494+
PatternRewriter &rewriter) const override {
1495+
auto genericAndRepl =
1496+
pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
1497+
if (failed(genericAndRepl))
1498+
return failure();
1499+
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1500+
return success();
1501+
}
1502+
1503+
private:
1504+
ControlPropagationFn controlFn;
1505+
};
1506+
12391507
} // namespace
12401508

12411509
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
12471515
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
12481516
patterns.getContext(), controlPackUnPackPropagation);
12491517
}
1518+
1519+
void mlir::linalg::populateExtractSliceSinkingPatterns(
1520+
RewritePatternSet &patterns,
1521+
const ControlPropagationFn &controlPackUnPackPropagation) {
1522+
patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
1523+
patterns.getContext(), controlPackUnPackPropagation);
1524+
}

0 commit comments

Comments
 (0)