Skip to content

Commit 6fb3d59

Browse files
committed
[mlir] Remove 'valuesToRemoveIfDead' from PatternRewriter API
Summary: Remove 'valuesToRemoveIfDead' from PatternRewriter API. The removal functionality wasn't implemented and we decided [1] not to implement it in favor of having more powerful DCE approaches. [1] tensorflow/mlir#212 Reviewers: rriddle, bondhugula Reviewed By: rriddle Subscribers: liufengdb, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72545
1 parent 27f2e9a commit 6fb3d59

File tree

11 files changed

+39
-74
lines changed

11 files changed

+39
-74
lines changed

mlir/examples/toy/Ch3/mlir/ToyCombine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
4848
return matchFailure();
4949

5050
// Use the rewriter to perform the replacement.
51-
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
51+
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
5252
return matchSuccess();
5353
}
5454
};

mlir/examples/toy/Ch4/mlir/ToyCombine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
5353
return matchFailure();
5454

5555
// Use the rewriter to perform the replacement.
56-
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
56+
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
5757
return matchSuccess();
5858
}
5959
};

mlir/examples/toy/Ch5/mlir/ToyCombine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
5353
return matchFailure();
5454

5555
// Use the rewriter to perform the replacement.
56-
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
56+
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
5757
return matchSuccess();
5858
}
5959
};

mlir/examples/toy/Ch6/mlir/ToyCombine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
5353
return matchFailure();
5454

5555
// Use the rewriter to perform the replacement.
56-
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
56+
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
5757
return matchSuccess();
5858
}
5959
};

mlir/examples/toy/Ch7/mlir/ToyCombine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
7171
return matchFailure();
7272

7373
// Use the rewriter to perform the replacement.
74-
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
74+
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
7575
return matchSuccess();
7676
}
7777
};

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -318,33 +318,15 @@ class PatternRewriter : public OpBuilder {
318318

319319
/// This method performs the final replacement for a pattern, where the
320320
/// results of the operation are updated to use the specified list of SSA
321-
/// values. In addition to replacing and removing the specified operation,
322-
/// clients can specify a list of other nodes that this replacement may make
323-
/// (perhaps transitively) dead. If any of those values are dead, this will
324-
/// remove them as well.
325-
virtual void replaceOp(Operation *op, ValueRange newValues,
326-
ValueRange valuesToRemoveIfDead);
327-
void replaceOp(Operation *op, ValueRange newValues) {
328-
replaceOp(op, newValues, llvm::None);
329-
}
321+
/// values.
322+
virtual void replaceOp(Operation *op, ValueRange newValues);
330323

331324
/// Replaces the result op with a new op that is created without verification.
332325
/// The result values of the two ops must be the same types.
333326
template <typename OpTy, typename... Args>
334327
void replaceOpWithNewOp(Operation *op, Args &&... args) {
335328
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
336-
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
337-
}
338-
339-
/// Replaces the result op with a new op that is created without verification.
340-
/// The result values of the two ops must be the same types. This allows
341-
/// specifying a list of ops that may be removed if dead.
342-
template <typename OpTy, typename... Args>
343-
void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op,
344-
Args &&... args) {
345-
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
346-
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
347-
valuesToRemoveIfDead);
329+
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
348330
}
349331

350332
/// This method erases an operation that is known to have no uses.
@@ -405,10 +387,9 @@ class PatternRewriter : public OpBuilder {
405387
virtual void notifyOperationRemoved(Operation *op) {}
406388

407389
private:
408-
/// op and newOp are known to have the same number of results, replace the
409-
/// uses of op with uses of newOp
410-
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
411-
ValueRange valuesToRemoveIfDead);
390+
/// 'op' and 'newOp' are known to have the same number of results, replace the
391+
/// uses of op with uses of newOp.
392+
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
412393
};
413394

414395
//===----------------------------------------------------------------------===//

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
332332
//===--------------------------------------------------------------------===//
333333

334334
/// PatternRewriter hook for replacing the results of an operation.
335-
void replaceOp(Operation *op, ValueRange newValues,
336-
ValueRange valuesToRemoveIfDead) override;
335+
void replaceOp(Operation *op, ValueRange newValues) override;
337336
using PatternRewriter::replaceOp;
338337

339338
/// PatternRewriter hook for erasing a dead operation. The uses of this

mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
9090
rewriter.getContext());
9191
auto newConstOp =
9292
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
93-
rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
94-
qbarrier.getType(), newConstOp);
93+
rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
94+
newConstOp);
9595
return matchSuccess();
9696
}
9797

mlir/lib/Dialect/StandardOps/Ops.cpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,6 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
328328
SmallVector<int64_t, 4> newShapeConstants;
329329
newShapeConstants.reserve(memrefType.getRank());
330330
SmallVector<Value, 4> newOperands;
331-
SmallVector<Value, 4> droppedOperands;
332331

333332
unsigned dynamicDimPos = 0;
334333
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@@ -342,8 +341,6 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
342341
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
343342
// Dynamic shape dimension will be folded.
344343
newShapeConstants.push_back(constantIndexOp.getValue());
345-
// Record to check for zero uses later below.
346-
droppedOperands.push_back(constantIndexOp);
347344
} else {
348345
// Dynamic shape dimension not folded; copy operand from old memref.
349346
newShapeConstants.push_back(-1);
@@ -366,7 +363,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
366363
auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
367364
alloc.getType());
368365

369-
rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
366+
rewriter.replaceOp(alloc, {resultCast});
370367
return matchSuccess();
371368
}
372369
};
@@ -2447,7 +2444,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
24472444
return matchFailure();
24482445

24492446
SmallVector<Value, 4> newOperands;
2450-
SmallVector<Value, 4> droppedOperands;
24512447

24522448
// Fold dynamic offset operand if it is produced by a constant.
24532449
auto dynamicOffset = viewOp.getDynamicOffset();
@@ -2458,7 +2454,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
24582454
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
24592455
// Dynamic offset will be folded into the map.
24602456
newOffset = constantIndexOp.getValue();
2461-
droppedOperands.push_back(dynamicOffset);
24622457
} else {
24632458
// Unable to fold dynamic offset. Add it to 'newOperands' list.
24642459
newOperands.push_back(dynamicOffset);
@@ -2483,8 +2478,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
24832478
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
24842479
// Dynamic shape dimension will be folded.
24852480
newShapeConstants.push_back(constantIndexOp.getValue());
2486-
// Record to check for zero uses later below.
2487-
droppedOperands.push_back(constantIndexOp);
24882481
} else {
24892482
// Dynamic shape dimension not folded; copy operand from old memref.
24902483
newShapeConstants.push_back(dimSize);
@@ -2522,8 +2515,8 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
25222515
auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
25232516
viewOp.getOperand(0), newOperands);
25242517
// Insert a cast so we have the same type as the old memref type.
2525-
rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
2526-
newViewOp, viewOp.getType());
2518+
rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
2519+
viewOp.getType());
25272520
return matchSuccess();
25282521
}
25292522
};
@@ -2542,8 +2535,8 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
25422535
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
25432536
if (!allocOp)
25442537
return matchFailure();
2545-
rewriter.replaceOpWithNewOp<ViewOp>(memrefOperand, viewOp, viewOp.getType(),
2546-
allocOperand, viewOp.operands());
2538+
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2539+
viewOp.operands());
25472540
return matchSuccess();
25482541
}
25492542
};
@@ -2839,8 +2832,8 @@ class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
28392832
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
28402833
ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
28412834
// Insert a memref_cast for compatibility of the uses of the op.
2842-
rewriter.replaceOpWithNewOp<MemRefCastOp>(
2843-
subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType());
2835+
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
2836+
subViewOp.getType());
28442837
return matchSuccess();
28452838
}
28462839
};
@@ -2889,8 +2882,8 @@ class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
28892882
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
28902883
subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
28912884
// Insert a memref_cast for compatibility of the uses of the op.
2892-
rewriter.replaceOpWithNewOp<MemRefCastOp>(
2893-
subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType());
2885+
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
2886+
subViewOp.getType());
28942887
return matchSuccess();
28952888
}
28962889
};
@@ -2941,8 +2934,8 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
29412934
subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
29422935
subViewOp.sizes(), subViewOp.strides(), newMemRefType);
29432936
// Insert a memref_cast for compatibility of the uses of the op.
2944-
rewriter.replaceOpWithNewOp<MemRefCastOp>(
2945-
subViewOp.offsets(), subViewOp, newSubViewOp, subViewOp.getType());
2937+
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
2938+
subViewOp.getType());
29462939
return matchSuccess();
29472940
}
29482941
};

mlir/lib/IR/PatternMatch.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/IR/BlockAndValueMapping.h"
1111
#include "mlir/IR/Operation.h"
1212
#include "mlir/IR/Value.h"
13+
1314
using namespace mlir;
1415

1516
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
@@ -72,12 +73,8 @@ PatternRewriter::~PatternRewriter() {
7273

7374
/// This method performs the final replacement for a pattern, where the
7475
/// results of the operation are updated to use the specified list of SSA
75-
/// values. In addition to replacing and removing the specified operation,
76-
/// clients can specify a list of other nodes that this replacement may make
77-
/// (perhaps transitively) dead. If any of those ops are dead, this will
78-
/// remove them as well.
79-
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues,
80-
ValueRange valuesToRemoveIfDead) {
76+
/// values.
77+
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
8178
// Notify the rewriter subclass that we're about to replace this root.
8279
notifyRootReplaced(op);
8380

@@ -87,9 +84,6 @@ void PatternRewriter::replaceOp(Operation *op, ValueRange newValues,
8784

8885
notifyOperationRemoved(op);
8986
op->erase();
90-
91-
// TODO: Process the valuesToRemoveIfDead list, removing things and calling
92-
// the notifyOperationRemoved hook in the process.
9387
}
9488

9589
/// This method erases an operation that is known to have no uses. The uses of
@@ -129,15 +123,15 @@ Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
129123
return block->splitBlock(before);
130124
}
131125

132-
/// op and newOp are known to have the same number of results, replace the
126+
/// 'op' and 'newOp' are known to have the same number of results, replace the
133127
/// uses of op with uses of newOp
134-
void PatternRewriter::replaceOpWithResultsOfAnotherOp(
135-
Operation *op, Operation *newOp, ValueRange valuesToRemoveIfDead) {
128+
void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
129+
Operation *newOp) {
136130
assert(op->getNumResults() == newOp->getNumResults() &&
137131
"replacement op doesn't match results of original op");
138132
if (op->getNumResults() == 1)
139-
return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
140-
return replaceOp(op, newOp->getResults(), valuesToRemoveIfDead);
133+
return replaceOp(op, newOp->getResult(0));
134+
return replaceOp(op, newOp->getResults());
141135
}
142136

143137
/// Move the blocks that belong to "region" before the given position in

mlir/lib/Transforms/DialectConversion.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,7 @@ struct ConversionPatternRewriterImpl {
554554
TypeConverter::SignatureConversion &conversion);
555555

556556
/// PatternRewriter hook for replacing the results of an operation.
557-
void replaceOp(Operation *op, ValueRange newValues,
558-
ValueRange valuesToRemoveIfDead);
557+
void replaceOp(Operation *op, ValueRange newValues);
559558

560559
/// Notifies that a block was split.
561560
void notifySplitBlock(Block *block, Block *continuation);
@@ -757,8 +756,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
757756
}
758757

759758
void ConversionPatternRewriterImpl::replaceOp(Operation *op,
760-
ValueRange newValues,
761-
ValueRange valuesToRemoveIfDead) {
759+
ValueRange newValues) {
762760
assert(newValues.size() == op->getNumResults());
763761

764762
// Create mappings for each of the new result values.
@@ -838,11 +836,11 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
838836
ConversionPatternRewriter::~ConversionPatternRewriter() {}
839837

840838
/// PatternRewriter hook for replacing the results of an operation.
841-
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
842-
ValueRange valuesToRemoveIfDead) {
839+
void ConversionPatternRewriter::replaceOp(Operation *op,
840+
ValueRange newValues) {
843841
LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName()
844842
<< "\n");
845-
impl->replaceOp(op, newValues, valuesToRemoveIfDead);
843+
impl->replaceOp(op, newValues);
846844
}
847845

848846
/// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -852,7 +850,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
852850
LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName()
853851
<< "\n");
854852
SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
855-
impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None);
853+
impl->replaceOp(op, nullRepls);
856854
}
857855

858856
/// Apply a signature conversion to the entry block of the given region.

0 commit comments

Comments
 (0)