6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
9
10
#include " mlir/Dialect/Linalg/IR/Linalg.h"
10
11
#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
11
12
#include " mlir/Dialect/Linalg/Utils/Utils.h"
12
13
#include " mlir/Dialect/Tensor/IR/Tensor.h"
14
+ #include " mlir/Dialect/UB/IR/UBOps.h"
13
15
#include " mlir/Dialect/Utils/IndexingUtils.h"
14
16
#include " mlir/IR/Dominance.h"
15
17
#include " llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1236
1238
ControlPropagationFn controlFn;
1237
1239
};
1238
1240
1241
+ // This struct contains infomation about extract_slice dims.
1242
+ struct SliceDimInfo {
1243
+ OpFoldResult offset;
1244
+ OpFoldResult sliceSize;
1245
+ OpFoldResult outputSize;
1246
+ };
1247
+
1248
+ // / Return the first input extract slice operand, if present, for the current
1249
+ // / generic op.
1250
+ static FailureOr<OpOperand *> getSliceOperand (GenericOp genericOp) {
1251
+ OpOperand *sliceOperand = nullptr ;
1252
+ for (auto operand : genericOp.getDpsInputOperands ()) {
1253
+ auto extractOp = operand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1254
+ if (!extractOp)
1255
+ continue ;
1256
+ sliceOperand = operand;
1257
+ break ;
1258
+ }
1259
+ if (!sliceOperand) {
1260
+ return failure ();
1261
+ }
1262
+ return sliceOperand;
1263
+ }
1264
+
1265
+ // Return a map of dims that have partial slices on them so that other operands
1266
+ // can use this information. Also return a bool mentioning if a reduction dim
1267
+ // has a non full slice as that can be used to fold the original extract slice.
1268
+ static FailureOr<llvm::DenseMap<int64_t , SliceDimInfo>>
1269
+ getPartialSliceDimInfo (GenericOp genericOp, OpOperand *sliceOperand) {
1270
+ tensor::ExtractSliceOp producerSliceOp =
1271
+ sliceOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1272
+ assert (producerSliceOp && " expect a valid ExtractSliceOp" );
1273
+ llvm::DenseMap<int64_t , SliceDimInfo> partialSliceDimMap;
1274
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets ();
1275
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes ();
1276
+
1277
+ SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult (
1278
+ genericOp.getContext (), producerSliceOp.getSourceType ().getShape ());
1279
+
1280
+ for (auto [idx, expr] : llvm::enumerate (
1281
+ genericOp.getMatchingIndexingMap (sliceOperand).getResults ())) {
1282
+ // If we have a full slice in a dimension then we dont need to add it to
1283
+ // the partial slice map.
1284
+ if (isConstantIntValue (offsets[idx], 0 ) &&
1285
+ isEqualConstantIntOrValue (sizes[idx], shape[idx])) {
1286
+ continue ;
1287
+ }
1288
+ // We only support partial slices of AffineDimExprs so bail-out if thats not
1289
+ // the case.
1290
+ if (!isa<AffineDimExpr>(expr)) {
1291
+ return failure ();
1292
+ }
1293
+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1294
+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition ();
1295
+ partialSliceDimMap[dimPos] = sliceDimInfo;
1296
+ }
1297
+ // Next check if the dims with partial slice info are used in non
1298
+ // AffineDimExpr in other operands and if they are then bail-out.
1299
+ for (OpOperand &operand : genericOp->getOpOperands ()) {
1300
+ if (operand == *sliceOperand) {
1301
+ continue ;
1302
+ }
1303
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap (&operand);
1304
+ if (llvm::any_of (IndexingMap.getResults (), [&](AffineExpr expr) {
1305
+ if (isa<AffineDimExpr>(expr)) {
1306
+ return false ;
1307
+ }
1308
+ WalkResult status = expr.walk ([&](AffineExpr expr) {
1309
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1310
+ if (partialSliceDimMap.contains (dimExpr.getPosition ())) {
1311
+ return WalkResult::interrupt ();
1312
+ }
1313
+ }
1314
+ return WalkResult::advance ();
1315
+ });
1316
+ if (status.wasInterrupted ()) {
1317
+ return true ;
1318
+ }
1319
+ return false ;
1320
+ })) {
1321
+ return failure ();
1322
+ }
1323
+ }
1324
+ return partialSliceDimMap;
1325
+ }
1326
+
1327
+ static FailureOr<std::tuple<GenericOp, Value>>
1328
+ pushDownExtractSliceOpThroughGenericOp (RewriterBase &rewriter,
1329
+ GenericOp genericOp,
1330
+ ControlPropagationFn controlFn) {
1331
+ if (genericOp.getNumResults () != 1 )
1332
+ return rewriter.notifyMatchFailure (
1333
+ genericOp, " propagation through multi-result generic is unsupported." );
1334
+ if (hasGatherSemantics (genericOp))
1335
+ return rewriter.notifyMatchFailure (
1336
+ genericOp,
1337
+ " propagation through generic with gather semantics is unsupported." );
1338
+ // Collect the sliced operand, if present.
1339
+ auto maybeSliceOperand = getSliceOperand (genericOp);
1340
+ if (failed (maybeSliceOperand))
1341
+ return failure ();
1342
+ OpOperand *sliceOperand = *maybeSliceOperand;
1343
+ unsigned OperandIndex = sliceOperand->getOperandNumber ();
1344
+
1345
+ if (!controlFn (sliceOperand))
1346
+ return failure ();
1347
+
1348
+ tensor::ExtractSliceOp producerSliceOp =
1349
+ sliceOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1350
+ assert (producerSliceOp && " expect a valid ExtractSliceOp" );
1351
+
1352
+ if (producerSliceOp.getSource ().getType ().getRank () !=
1353
+ producerSliceOp.getResult ().getType ().getRank ()) {
1354
+ return rewriter.notifyMatchFailure (
1355
+ genericOp,
1356
+ " propagation of rank-reducing extract slice is unsupported." );
1357
+ }
1358
+
1359
+ SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides ();
1360
+ if (!areAllConstantIntValue (strides, 1 ))
1361
+ return rewriter.notifyMatchFailure (
1362
+ genericOp, " propagation of strided extract slice is unsupported." );
1363
+
1364
+ // check if we can support the propagation of this extractSlice
1365
+ // through the generic op and if so return the dimensions that
1366
+
1367
+ auto maybePartialSliceDimMap =
1368
+ getPartialSliceDimInfo (genericOp, sliceOperand);
1369
+
1370
+ if (failed (maybePartialSliceDimMap)) {
1371
+ return failure ();
1372
+ }
1373
+
1374
+ auto partialSliceDimMap = *maybePartialSliceDimMap;
1375
+
1376
+ SmallVector<utils::IteratorType> iterators =
1377
+ genericOp.getIteratorTypesArray ();
1378
+ bool hasPartialReductionDimSlice =
1379
+ llvm::any_of (partialSliceDimMap, [&](const auto &slice) {
1380
+ int64_t sliceDim = slice.first ;
1381
+ return iterators[sliceDim] == utils::IteratorType::reduction;
1382
+ });
1383
+
1384
+ // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
1385
+ Location loc = genericOp->getLoc ();
1386
+ AffineExpr dim0, dim1;
1387
+ bindDims (rewriter.getContext (), dim0, dim1);
1388
+ auto subMap = AffineMap::get (2 , 0 , {dim0 - dim1});
1389
+ auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1390
+ return affine::makeComposedFoldedAffineApply (rewriter, loc, subMap,
1391
+ {v1, v2});
1392
+ };
1393
+
1394
+ MLIRContext *ctx = genericOp.getContext ();
1395
+ SmallVector<Value> paddedInputs;
1396
+ for (auto [idx, operand] : llvm::enumerate (genericOp.getDpsInputOperands ())) {
1397
+ if (idx == OperandIndex && !hasPartialReductionDimSlice) {
1398
+ paddedInputs.push_back (producerSliceOp.getSource ());
1399
+ continue ;
1400
+ }
1401
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap (operand);
1402
+ SmallVector<OpFoldResult> operandLowPads (IndexingMap.getNumResults (),
1403
+ getAsIndexOpFoldResult (ctx, 0 ));
1404
+ SmallVector<OpFoldResult> operandHighPads (IndexingMap.getNumResults (),
1405
+ getAsIndexOpFoldResult (ctx, 0 ));
1406
+ for (auto [idx, expr] : llvm::enumerate (IndexingMap.getResults ())) {
1407
+ if (!isa<AffineDimExpr>(expr)) {
1408
+ continue ;
1409
+ }
1410
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1411
+ if (!partialSliceDimMap.contains (dimExpr.getPosition ())) {
1412
+ continue ;
1413
+ }
1414
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition ()];
1415
+ operandLowPads[idx] = sliceDimInfo.offset ;
1416
+ operandHighPads[idx] =
1417
+ sub (sub (sliceDimInfo.outputSize , sliceDimInfo.offset ),
1418
+ sliceDimInfo.sliceSize );
1419
+ }
1420
+ auto paddingValue = ub::PoisonOp::create (
1421
+ rewriter, loc, getElementTypeOrSelf (operand->get ().getType ()));
1422
+ auto paddedOperand = tensor::PadOp::create (
1423
+ rewriter, loc, Type (), operand->get (), operandLowPads, operandHighPads,
1424
+ paddingValue, /* nofold=*/ false );
1425
+ paddedInputs.push_back (paddedOperand);
1426
+ }
1427
+ AffineMap outputIndexingMap =
1428
+ genericOp.getMatchingIndexingMap (genericOp.getDpsInitOperand (0 ));
1429
+
1430
+ auto outputShapeType =
1431
+ llvm::cast<ShapedType>(genericOp.getDpsInitOperand (0 )->get ().getType ());
1432
+ SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector (
1433
+ outputShapeType.getShape (),
1434
+ [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr (sz); });
1435
+ SmallVector<OpFoldResult> newSizes = OutputShape;
1436
+ SmallVector<OpFoldResult> outputLowPads (outputIndexingMap.getNumResults (),
1437
+ getAsIndexOpFoldResult (ctx, 0 ));
1438
+ SmallVector<OpFoldResult> outputHighPads (outputIndexingMap.getNumResults (),
1439
+ getAsIndexOpFoldResult (ctx, 0 ));
1440
+ SmallVector<OpFoldResult> newStrides (outputIndexingMap.getNumResults (),
1441
+ getAsIndexOpFoldResult (ctx, 1 ));
1442
+ for (auto [idx, expr] : llvm::enumerate (outputIndexingMap.getResults ())) {
1443
+ if (!isa<AffineDimExpr>(expr)) {
1444
+ continue ;
1445
+ }
1446
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1447
+ if (!partialSliceDimMap.contains (dimExpr.getPosition ())) {
1448
+ continue ;
1449
+ }
1450
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition ()];
1451
+ outputLowPads[idx] = sliceDimInfo.offset ;
1452
+ outputHighPads[idx] = sub (sub (sliceDimInfo.outputSize , sliceDimInfo.offset ),
1453
+ sliceDimInfo.sliceSize );
1454
+ OutputShape[idx] = sliceDimInfo.outputSize ;
1455
+ newSizes[idx] = sliceDimInfo.sliceSize ;
1456
+ }
1457
+ Value newPadOutput;
1458
+ auto outputElType =
1459
+ getElementTypeOrSelf (genericOp.getDpsInits ()[0 ].getType ());
1460
+ if (isGenericOutsNotUsed (genericOp)) {
1461
+ newPadOutput =
1462
+ tensor::EmptyOp::create (rewriter, loc, OutputShape, outputElType);
1463
+ } else {
1464
+ auto paddingValue = ub::PoisonOp::create (rewriter, loc, outputElType);
1465
+ newPadOutput = tensor::PadOp::create (
1466
+ rewriter, loc, Type (), genericOp.getDpsInits ()[0 ], outputLowPads,
1467
+ outputHighPads, paddingValue, /* nofold=*/ false );
1468
+ }
1469
+
1470
+ auto newGenericOp = linalg::GenericOp::create (
1471
+ rewriter, loc, newPadOutput.getType (), paddedInputs, {newPadOutput},
1472
+ genericOp.getIndexingMapsArray (), genericOp.getIteratorTypesArray (),
1473
+ /* bodyBuild=*/ nullptr , linalg::getPrunedAttributeList (genericOp));
1474
+ rewriter.cloneRegionBefore (genericOp.getRegion (), newGenericOp.getRegion (),
1475
+ newGenericOp.getRegion ().begin ());
1476
+
1477
+ auto extractOp = tensor::ExtractSliceOp::create (
1478
+ rewriter, loc,
1479
+ newGenericOp.getTiedOpResult (newGenericOp.getDpsInitOperand (0 )),
1480
+ outputLowPads, newSizes, newStrides);
1481
+ Value extractRes = extractOp.getResult ();
1482
+
1483
+ return std::make_tuple (newGenericOp, extractRes);
1484
+ }
1485
+
1486
+ class PushDownExtractSliceOpThroughGenericOp final
1487
+ : public OpRewritePattern<GenericOp> {
1488
+ public:
1489
+ PushDownExtractSliceOpThroughGenericOp (MLIRContext *context,
1490
+ ControlPropagationFn fun)
1491
+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1492
+
1493
+ LogicalResult matchAndRewrite (GenericOp genericOp,
1494
+ PatternRewriter &rewriter) const override {
1495
+ auto genericAndRepl =
1496
+ pushDownExtractSliceOpThroughGenericOp (rewriter, genericOp, controlFn);
1497
+ if (failed (genericAndRepl))
1498
+ return failure ();
1499
+ rewriter.replaceOp (genericOp, std::get<1 >(*genericAndRepl));
1500
+ return success ();
1501
+ }
1502
+
1503
+ private:
1504
+ ControlPropagationFn controlFn;
1505
+ };
1506
+
1239
1507
} // namespace
1240
1508
1241
1509
void mlir::linalg::populateDataLayoutPropagationPatterns (
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
1247
1515
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1248
1516
patterns.getContext (), controlPackUnPackPropagation);
1249
1517
}
1518
+
1519
+ void mlir::linalg::populateExtractSliceSinkingPatterns (
1520
+ RewritePatternSet &patterns,
1521
+ const ControlPropagationFn &controlPackUnPackPropagation) {
1522
+ patterns.insert <PushDownExtractSliceOpThroughGenericOp>(
1523
+ patterns.getContext (), controlPackUnPackPropagation);
1524
+ }
0 commit comments