-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][Transforms] Add support for ConversionPatternRewriter::replaceAllUsesWith
#155244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
5205679
to
a8719ed
Compare
b217bce
to
88d9afc
Compare
88d9afc
to
93ad0c4
Compare
93ad0c4
to
1962033
Compare
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit generalizes
Note for LLVM integration: Replace You can temporarily reactivate the old behavior by calling Patch is 22.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155244.diff 7 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..7b0b9cef9c5bd 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
- void replaceAllUsesWith(Value from, Value to) {
+ virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf18836c6..fe48e45a9b98c 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -854,15 +854,26 @@ class ConversionPatternRewriter final : public PatternRewriter {
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
- /// Replace all the uses of the block argument `from` with `to`. This
- /// function supports both 1:1 and 1:N replacements.
+ /// Replace all the uses of `from` with `to`. The type of `from` and `to` is
+ /// allowed to differ. The conversion driver will try to reconcile all type
+ /// mismatches that still exist at the end of the conversion with
+ /// materializations. This function supports both 1:1 and 1:N replacements.
///
- /// Note: If `allowPatternRollback` is set to "true", this function replaces
- /// all current and future uses of the block argument. This same block
- /// block argument must not be replaced multiple times. Uses are not replaced
- /// immediately but in a delayed fashion. Patterns may still see the original
- /// uses when inspecting IR.
- void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+ /// Note: If `allowPatternRollback` is set to "true", this function behaves
+ /// slightly different:
+ ///
+ /// 1. All current and future uses of `from` are replaced. The same value must
+ /// not be replaced multiple times. That's an API violation.
+ /// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
+ /// may still see the original uses when inspecting IR.
+ /// 3. Uses within the same block that appear before the defining operation
+ /// of the replacement value are not replaced. This allows users to
+ /// perform certain replaceAllUsesExcept-style replacements, even though
+ /// such API is not directly supported.
+ void replaceAllUsesWith(Value from, ValueRange to);
+ void replaceAllUsesWith(Value from, Value to) override {
+ replaceAllUsesWith(from, ValueRange{to});
+ }
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 42c76ed475b4c..93fe2edad5274 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
- rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+ rewriter.replaceAllUsesWith(arg, valueArg);
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ba109d96cf13..d72429298754f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -277,13 +277,14 @@ class IRRewrite {
InlineBlock,
MoveBlock,
BlockTypeConversion,
- ReplaceBlockArg,
// Operation rewrites
MoveOperation,
ModifyOperation,
ReplaceOperation,
CreateOperation,
- UnresolvedMaterialization
+ UnresolvedMaterialization,
+ // Value rewrites
+ ReplaceValue
};
virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::ReplaceBlockArg;
+ rewrite->getKind() <= Kind::BlockTypeConversion;
}
protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
Block *block;
};
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+ /// Return the value that this rewrite operates on.
+ Value getValue() const { return value; }
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceValue;
+ }
+
+protected:
+ ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Value value)
+ : IRRewrite(kind, rewriterImpl), value(value) {}
+
+ // The value that this rewrite operates on.
+ Value value;
+};
+
/// Creation of a block. Block creations are immediately reflected in the IR.
/// There is no extra work to commit the rewrite. During rollback, the newly
/// created block is erased.
@@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite {
Block *newBlock;
};
-/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// Replacing a value. This rewrite is not immediately reflected in the
/// IR. An internal IR mapping is updated, but the actual replacement is delayed
/// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
+class ReplaceValueRewrite : public ValueRewrite {
public:
- ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, BlockArgument arg,
- const TypeConverter *converter)
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
+ ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
+ const TypeConverter *converter)
+ : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::ReplaceBlockArg;
+ return rewrite->getKind() == Kind::ReplaceValue;
}
void commit(RewriterBase &rewriter) override;
@@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
void rollback() override;
private:
- BlockArgument arg;
-
- /// The current type converter when the block argument was replaced.
+ /// The current type converter when the value was replaced.
const TypeConverter *converter;
};
@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// uses.
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
- /// Replace the given block argument with the given values. The specified
+ /// Replace the uses of the given value with the given values. The specified
/// converter is used to build materializations (if necessary).
- void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
- const TypeConverter *converter);
+ void replaceAllUsesWith(Value from, ValueRange to,
+ const TypeConverter *converter);
/// Erase the given block and its contents.
void eraseBlock(Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
IRRewriter notifyingRewriter;
#ifndef NDEBUG
- /// A set of replaced block arguments. This set is for debugging purposes
- /// only and it is maintained only if `allowPatternRollback` is set to
- /// "true".
- DenseSet<BlockArgument> replacedArgs;
+ /// A set of replaced values. This set is for debugging purposes only and it
+ /// is maintained only if `allowPatternRollback` is set to "true".
+ DenseSet<Value> replacedValues;
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
- Value repl) {
+/// Replace all uses of `from` with `repl`.
+static void performReplaceValue(RewriterBase &rewriter, Value from,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
- rewriter.replaceAllUsesWith(arg, repl);
+ // `repl` is a block argument. Directly replace all uses.
+ rewriter.replaceAllUsesWith(from, repl);
return;
}
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = cast<OpResult>(repl).getOwner();
+ // If the replacement value is an operation, only replace those uses that:
+ // - are in a different block than the replacement operation, or
+ // - are in the same block but after the replacement operation.
+ //
+ // Example:
+ // ^bb0(%arg0: i32):
+ // %0 = "consumer"(%arg0) : (i32) -> (i32)
+ // "another_consumer"(%arg0) : (i32) -> ()
+ //
+ // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
+ // use in "another_consumer" but not the use in "consumer". When using the
+ // normal RewriterBase API, this would typically be done with
+ // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
+ // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
+ // it cannot be supported efficiently with `allowPatternRollback` set to
+ // "true". Therefore, the conversion driver is trying to be smart and replaces
+ // only those uses that do not lead to a dominance violation. E.g., the
+ // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
+ // behavior.
+ //
+ // TODO: As we move more and more towards `allowPatternRollback` set to
+ // "false", we should remove this special handling, in order to align the
+ // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
+ Operation *replOp = repl.getDefiningOp();
Block *replBlock = replOp->getBlock();
- rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
});
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
if (!repl)
return;
- performReplaceBlockArg(rewriter, arg, repl);
+ performReplaceValue(rewriter, value, repl);
}
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
+void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
/*isPureTypeConversion=*/false)
.front();
- replaceUsesOfBlockArgument(origArg, mat, converter);
+ replaceAllUsesWith(origArg, mat, converter);
continue;
}
@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
- converter);
+ replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
continue;
}
// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- replaceUsesOfBlockArgument(origArg, replArgs, converter);
+ replaceAllUsesWith(origArg, replArgs, converter);
}
if (config.allowPatternRollback)
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}
-void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
- BlockArgument from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+ Value from, ValueRange to, const TypeConverter *converter) {
if (!config.allowPatternRollback) {
SmallVector<Value> toConv = llvm::to_vector(to);
SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
if (!repl)
return;
- performReplaceBlockArg(r, from, repl);
+ performReplaceValue(r, from, repl);
return;
}
#ifndef NDEBUG
- // Make sure that a block argument is not replaced multiple times. In
- // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
- // uses of the given block argument, but also all future uses that may be
- // introduced by future pattern applications. Therefore, it does not make
- // sense to call `replaceUsesOfBlockArgument` multiple times with the same
- // block argument. Doing so would overwrite the mapping and mess with the
- // internal state of the dialect conversion driver.
- assert(!replacedArgs.contains(from) &&
- "attempting to replace a block argument that was already replaced");
- replacedArgs.insert(from);
+ // Make sure that a value is not replaced multiple times. In rollback mode,
+ // `replaceAllUsesWith` replaces not only all current uses of the given value,
+ // but also all future uses that may be introduced by future pattern
+ // applications. Therefore, it does not make sense to call
+ // `replaceAllUsesWith` multiple times with the same value. Doing so would
+ // overwrite the mapping and mess with the internal state of the dialect
+ // conversion driver.
+ assert(!replacedValues.contains(from) &&
+ "attempting to replace a value that was already replaced");
+ replacedValues.insert(from);
#endif // NDEBUG
- appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
+ appendRewrite<ReplaceValueRewrite>(from, converter);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(region, converter, entryConversion);
}
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
- ValueRange to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
LLVM_DEBUG({
- impl->logger.startLine() << "** Replace Argument : '" << from << "'";
- if (Operation *parentOp = from.getOwner()->getParentOp()) {
- impl->logger.getOStream() << " (in region of '" << parentOp->getName()
- << "' (" << parentOp << ")\n";
- } else {
- impl->logger.getOStream() << " (unlinked block)\n";
+ impl->logger.startLine() << "** Replace Value : '" << from << "'";
+ if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+ if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
+ impl->logger.getOStream() << " (in region of '" << parentOp->getName()
+ << "' (" << parentOp << ")\n";
+ } else {
+ impl->logger.getOStream() << " (unlinked block)\n";
+ }
}
});
- impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
+ impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// Replace all uses of block arguments.
for (auto it : llvm::zip(source->getArguments(), argValues))
- replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
if (fastPath) {
// Move all ops at once.
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 460911fd88ad1..71e11782e14b0 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -49,14 +49,16 @@ func.func @create_illegal_block() {
// expected-remark@+1{{applyPartialConversion failed}}
module {
func.func @undo_block_arg_replace() {
- // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
- "test.block_arg_replace"() ({
+ "test.legal_op"() ({
^bb0(%arg0: i32, %arg1: i16):
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+ // CHECK-NEXT: "test.value_replace"(%[[ARG0]], %[[ARG1]]) {trigger_rollback}
// CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
+ // expected-error@+1{{failed to legalize operation 'test.value_replace' that was explicitly marked illegal}}
+ "test.value_replace"(%arg0, %arg1) {trigger_rollback} : (i32, i16) -> ()
"test.return"(%arg0) : (i32) -> ()
- }) {trigger_rollback} : () -> ()
+ }) : () -> ()
return
}
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3fa42ff6b2757..94c5bb4e93b06 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -269,12 +269,14 @@ builtin.module {
// CHECK-LABEL: @replace_block_arg_1_to_n
func.func @replace_block_arg_1_to_n() {
- // CHECK: "test.block_arg_replace"
- "test.block_arg_replace"() ({
+ // CHECK: "test.legal_op"
+ "test.legal_op"() ({
^bb0(%arg0: i32, %arg1: i16):
- // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
- // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+ // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+ // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+ // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[ARG1]]) {is_legal} : (i32, i16) -> ()
// CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+ "test.value_replace"(%arg0, %arg1) : (i32, i16) -> ()
"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
"test.return"() : () -> ()
@@ -282,6 +284,22 @@ func.func @replace_block_arg_1_to_n() {
// -----
+// CHECK-LABEL: @replace_op_result_1_to_n
+func.func @replace_op_result_1_to_n() {
+ // CHECK: %[[orig:.*]] = "test.legal_op"() : () -> i32
+ // CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i16
+ %0 = "test.legal_op"() : () -> i32
+ %1 = "test.legal_op"() : () -> i16
+
+ // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[repl]], %[[repl]]) : (i16, i16) -> i32
+ // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[repl]]) {is_legal} : (i32, i16) -> ()
+ // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+ "test.value_replace"(%0, %1) : (i32, i16) -> ()
+ "test.return"(%0) : (i32) -> ()
+}
+
+// -----
+
// Check that a conversion pattern on `test.blackhole` can mark the producer
// for deletion.
// CHECK-LABEL: @blackhole
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381ec471d6..93b007c792ad9 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -952,19 +952,19 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};
-/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
-struct TestBlockArgReplace : public ConversionPattern {
- TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
- : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
- ctx) {}
+/// A simple pattern that tests the "replaceAllUsesWith"...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit generalizes
Note for LLVM integration: Replace You can temporarily reactivate the old behavior by calling Patch is 22.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155244.diff 7 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..7b0b9cef9c5bd 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
- void replaceAllUsesWith(Value from, Value to) {
+ virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf18836c6..fe48e45a9b98c 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -854,15 +854,26 @@ class ConversionPatternRewriter final : public PatternRewriter {
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
- /// Replace all the uses of the block argument `from` with `to`. This
- /// function supports both 1:1 and 1:N replacements.
+ /// Replace all the uses of `from` with `to`. The type of `from` and `to` is
+ /// allowed to differ. The conversion driver will try to reconcile all type
+ /// mismatches that still exist at the end of the conversion with
+ /// materializations. This function supports both 1:1 and 1:N replacements.
///
- /// Note: If `allowPatternRollback` is set to "true", this function replaces
- /// all current and future uses of the block argument. This same block
- /// block argument must not be replaced multiple times. Uses are not replaced
- /// immediately but in a delayed fashion. Patterns may still see the original
- /// uses when inspecting IR.
- void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+ /// Note: If `allowPatternRollback` is set to "true", this function behaves
+ /// slightly different:
+ ///
+ /// 1. All current and future uses of `from` are replaced. The same value must
+ /// not be replaced multiple times. That's an API violation.
+ /// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
+ /// may still see the original uses when inspecting IR.
+ /// 3. Uses within the same block that appear before the defining operation
+ /// of the replacement value are not replaced. This allows users to
+ /// perform certain replaceAllUsesExcept-style replacements, even though
+ /// such API is not directly supported.
+ void replaceAllUsesWith(Value from, ValueRange to);
+ void replaceAllUsesWith(Value from, Value to) override {
+ replaceAllUsesWith(from, ValueRange{to});
+ }
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 42c76ed475b4c..93fe2edad5274 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
- rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+ rewriter.replaceAllUsesWith(arg, valueArg);
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ba109d96cf13..d72429298754f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -277,13 +277,14 @@ class IRRewrite {
InlineBlock,
MoveBlock,
BlockTypeConversion,
- ReplaceBlockArg,
// Operation rewrites
MoveOperation,
ModifyOperation,
ReplaceOperation,
CreateOperation,
- UnresolvedMaterialization
+ UnresolvedMaterialization,
+ // Value rewrites
+ ReplaceValue
};
virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::ReplaceBlockArg;
+ rewrite->getKind() <= Kind::BlockTypeConversion;
}
protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
Block *block;
};
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+ /// Return the value that this rewrite operates on.
+ Value getValue() const { return value; }
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceValue;
+ }
+
+protected:
+ ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Value value)
+ : IRRewrite(kind, rewriterImpl), value(value) {}
+
+ // The value that this rewrite operates on.
+ Value value;
+};
+
/// Creation of a block. Block creations are immediately reflected in the IR.
/// There is no extra work to commit the rewrite. During rollback, the newly
/// created block is erased.
@@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite {
Block *newBlock;
};
-/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// Replacing a value. This rewrite is not immediately reflected in the
/// IR. An internal IR mapping is updated, but the actual replacement is delayed
/// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
+class ReplaceValueRewrite : public ValueRewrite {
public:
- ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, BlockArgument arg,
- const TypeConverter *converter)
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
+ ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
+ const TypeConverter *converter)
+ : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::ReplaceBlockArg;
+ return rewrite->getKind() == Kind::ReplaceValue;
}
void commit(RewriterBase &rewriter) override;
@@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
void rollback() override;
private:
- BlockArgument arg;
-
- /// The current type converter when the block argument was replaced.
+ /// The current type converter when the value was replaced.
const TypeConverter *converter;
};
@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// uses.
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
- /// Replace the given block argument with the given values. The specified
+ /// Replace the uses of the given value with the given values. The specified
/// converter is used to build materializations (if necessary).
- void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
- const TypeConverter *converter);
+ void replaceAllUsesWith(Value from, ValueRange to,
+ const TypeConverter *converter);
/// Erase the given block and its contents.
void eraseBlock(Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
IRRewriter notifyingRewriter;
#ifndef NDEBUG
- /// A set of replaced block arguments. This set is for debugging purposes
- /// only and it is maintained only if `allowPatternRollback` is set to
- /// "true".
- DenseSet<BlockArgument> replacedArgs;
+ /// A set of replaced values. This set is for debugging purposes only and it
+ /// is maintained only if `allowPatternRollback` is set to "true".
+ DenseSet<Value> replacedValues;
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
- Value repl) {
+/// Replace all uses of `from` with `repl`.
+static void performReplaceValue(RewriterBase &rewriter, Value from,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
- rewriter.replaceAllUsesWith(arg, repl);
+ // `repl` is a block argument. Directly replace all uses.
+ rewriter.replaceAllUsesWith(from, repl);
return;
}
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = cast<OpResult>(repl).getOwner();
+ // If the replacement value is an operation, only replace those uses that:
+ // - are in a different block than the replacement operation, or
+ // - are in the same block but after the replacement operation.
+ //
+ // Example:
+ // ^bb0(%arg0: i32):
+ // %0 = "consumer"(%arg0) : (i32) -> (i32)
+ // "another_consumer"(%arg0) : (i32) -> ()
+ //
+ // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
+ // use in "another_consumer" but not the use in "consumer". When using the
+ // normal RewriterBase API, this would typically be done with
+ // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
+ // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
+ // it cannot be supported efficiently with `allowPatternRollback` set to
+ // "true". Therefore, the conversion driver is trying to be smart and replaces
+ // only those uses that do not lead to a dominance violation. E.g., the
+ // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
+ // behavior.
+ //
+ // TODO: As we move more and more towards `allowPatternRollback` set to
+ // "false", we should remove this special handling, in order to align the
+ // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
+ Operation *replOp = repl.getDefiningOp();
Block *replBlock = replOp->getBlock();
- rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
});
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
if (!repl)
return;
- performReplaceBlockArg(rewriter, arg, repl);
+ performReplaceValue(rewriter, value, repl);
}
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
+void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
/*isPureTypeConversion=*/false)
.front();
- replaceUsesOfBlockArgument(origArg, mat, converter);
+ replaceAllUsesWith(origArg, mat, converter);
continue;
}
@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
- converter);
+ replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
continue;
}
// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- replaceUsesOfBlockArgument(origArg, replArgs, converter);
+ replaceAllUsesWith(origArg, replArgs, converter);
}
if (config.allowPatternRollback)
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}
-void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
- BlockArgument from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+ Value from, ValueRange to, const TypeConverter *converter) {
if (!config.allowPatternRollback) {
SmallVector<Value> toConv = llvm::to_vector(to);
SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
if (!repl)
return;
- performReplaceBlockArg(r, from, repl);
+ performReplaceValue(r, from, repl);
return;
}
#ifndef NDEBUG
- // Make sure that a block argument is not replaced multiple times. In
- // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
- // uses of the given block argument, but also all future uses that may be
- // introduced by future pattern applications. Therefore, it does not make
- // sense to call `replaceUsesOfBlockArgument` multiple times with the same
- // block argument. Doing so would overwrite the mapping and mess with the
- // internal state of the dialect conversion driver.
- assert(!replacedArgs.contains(from) &&
- "attempting to replace a block argument that was already replaced");
- replacedArgs.insert(from);
+ // Make sure that a value is not replaced multiple times. In rollback mode,
+ // `replaceAllUsesWith` replaces not only all current uses of the given value,
+ // but also all future uses that may be introduced by future pattern
+ // applications. Therefore, it does not make sense to call
+ // `replaceAllUsesWith` multiple times with the same value. Doing so would
+ // overwrite the mapping and mess with the internal state of the dialect
+ // conversion driver.
+ assert(!replacedValues.contains(from) &&
+ "attempting to replace a value that was already replaced");
+ replacedValues.insert(from);
#endif // NDEBUG
- appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
+ appendRewrite<ReplaceValueRewrite>(from, converter);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(region, converter, entryConversion);
}
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
- ValueRange to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
LLVM_DEBUG({
- impl->logger.startLine() << "** Replace Argument : '" << from << "'";
- if (Operation *parentOp = from.getOwner()->getParentOp()) {
- impl->logger.getOStream() << " (in region of '" << parentOp->getName()
- << "' (" << parentOp << ")\n";
- } else {
- impl->logger.getOStream() << " (unlinked block)\n";
+ impl->logger.startLine() << "** Replace Value : '" << from << "'";
+ if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+ if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
+ impl->logger.getOStream() << " (in region of '" << parentOp->getName()
+ << "' (" << parentOp << ")\n";
+ } else {
+ impl->logger.getOStream() << " (unlinked block)\n";
+ }
}
});
- impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
+ impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// Replace all uses of block arguments.
for (auto it : llvm::zip(source->getArguments(), argValues))
- replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
if (fastPath) {
// Move all ops at once.
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 460911fd88ad1..71e11782e14b0 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -49,14 +49,16 @@ func.func @create_illegal_block() {
// expected-remark@+1{{applyPartialConversion failed}}
module {
func.func @undo_block_arg_replace() {
- // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
- "test.block_arg_replace"() ({
+ "test.legal_op"() ({
^bb0(%arg0: i32, %arg1: i16):
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+ // CHECK-NEXT: "test.value_replace"(%[[ARG0]], %[[ARG1]]) {trigger_rollback}
// CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
+ // expected-error@+1{{failed to legalize operation 'test.value_replace' that was explicitly marked illegal}}
+ "test.value_replace"(%arg0, %arg1) {trigger_rollback} : (i32, i16) -> ()
"test.return"(%arg0) : (i32) -> ()
- }) {trigger_rollback} : () -> ()
+ }) : () -> ()
return
}
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3fa42ff6b2757..94c5bb4e93b06 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -269,12 +269,14 @@ builtin.module {
// CHECK-LABEL: @replace_block_arg_1_to_n
func.func @replace_block_arg_1_to_n() {
- // CHECK: "test.block_arg_replace"
- "test.block_arg_replace"() ({
+ // CHECK: "test.legal_op"
+ "test.legal_op"() ({
^bb0(%arg0: i32, %arg1: i16):
- // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
- // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+ // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+ // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+ // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[ARG1]]) {is_legal} : (i32, i16) -> ()
// CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+ "test.value_replace"(%arg0, %arg1) : (i32, i16) -> ()
"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
"test.return"() : () -> ()
@@ -282,6 +284,22 @@ func.func @replace_block_arg_1_to_n() {
// -----
+// CHECK-LABEL: @replace_op_result_1_to_n
+func.func @replace_op_result_1_to_n() {
+ // CHECK: %[[orig:.*]] = "test.legal_op"() : () -> i32
+ // CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i16
+ %0 = "test.legal_op"() : () -> i32
+ %1 = "test.legal_op"() : () -> i16
+
+ // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[repl]], %[[repl]]) : (i16, i16) -> i32
+ // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[repl]]) {is_legal} : (i32, i16) -> ()
+ // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+ "test.value_replace"(%0, %1) : (i32, i16) -> ()
+ "test.return"(%0) : (i32) -> ()
+}
+
+// -----
+
// Check that a conversion pattern on `test.blackhole` can mark the producer
// for deletion.
// CHECK-LABEL: @blackhole
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381ec471d6..93b007c792ad9 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -952,19 +952,19 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};
-/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
-struct TestBlockArgReplace : public ConversionPattern {
- TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
- : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
- ctx) {}
+/// A simple pattern that tests the "replaceAllUsesWith"...
[truncated]
|
// don't replace uses that are within the parent operation of the | ||
// replacement value. | ||
Operation *replOp = cast<OpResult>(repl).getOwner(); | ||
// If the replacement value is an operation, only replace those uses that: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: The implementation here has not changed. I just added more documentation that explains the behavior.
f4dbd0d
to
fee875b
Compare
1962033
to
2031bc5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flang changes LGTM. Thank you for the fixes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % 1 public API comment
2031bc5
to
e763834
Compare
e763834
to
af99610
Compare
This commit generalizes
replaceUsesOfBlockArgument
toreplaceAllUsesWith
. In rollback mode, the same restrictions keep applying: a value cannot be replaced multiple times and a call toreplaceAllUsesWith
will replace all current and future uses of thefrom
value.replaceAllUsesWith
is now fully supported and its behavior is consistent with the remaining dialect conversion API. Before this commit,replaceAllUsesWith
was immediately reflected in the IR when running in rollback mode. After this commit,replaceAllUsesWith
changes are materialized in a delayed fashion, at the end of the dialect conversion. This is consistent with thereplaceUsesOfBlockArgument
andreplaceOp
APIs.replaceAllUsesExcept
etc. are still not supported and will be deactivated on theConversionPatternRewriter
(when running in rollback mode) in a follow-up commit.Note for LLVM integration: Replace
replaceUsesOfBlockArgument
withreplaceAllUsesWith
. If you are seeing failures, you may have patterns that usereplaceAllUsesWith
incorrectly (e.g., being called multiple times on the same value) or bypass the rewriter API entirely. E.g., such failures were mitigated in Flang by switching to the walk-patterns driver (#156171).You can temporarily reactivate the old behavior by calling
RewriterBase::replaceAllUsesWith
. However, note that that behavior is faulty in a dialect conversion. E.g., the baseRewriterBase::replaceAllUsesWith
implementation does not see uses of thefrom
value that have not materialized yet and will, therefore, not replace them.