Skip to content

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Aug 25, 2025

This commit generalizes replaceUsesOfBlockArgument to replaceAllUsesWith. In rollback mode, the same restrictions keep applying: a value cannot be replaced multiple times and a call to replaceAllUsesWith will replace all current and future uses of the from value.

replaceAllUsesWith is now fully supported and its behavior is consistent with the remaining dialect conversion API. Before this commit, replaceAllUsesWith was immediately reflected in the IR when running in rollback mode. After this commit, replaceAllUsesWith changes are materialized in a delayed fashion, at the end of the dialect conversion. This is consistent with the replaceUsesOfBlockArgument and replaceOp APIs.

replaceAllUsesExcept etc. are still not supported and will be deactivated on the ConversionPatternRewriter (when running in rollback mode) in a follow-up commit.

Note for LLVM integration: Replace replaceUsesOfBlockArgument with replaceAllUsesWith. If you are seeing failures, you may have patterns that use replaceAllUsesWith incorrectly (e.g., being called multiple times on the same value) or bypass the rewriter API entirely. E.g., such failures were mitigated in Flang by switching to the walk-patterns driver (#156171).

You can temporarily reactivate the old behavior by calling RewriterBase::replaceAllUsesWith. However, note that that behavior is faulty in a dialect conversion. E.g., the base RewriterBase::replaceAllUsesWith implementation does not see uses of the from value that have not materialized yet and will, therefore, not replace them.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/scf_openmp branch from 5205679 to a8719ed Compare August 26, 2025 09:34
Base automatically changed from users/matthias-springer/scf_openmp to main August 26, 2025 09:52
@matthias-springer matthias-springer force-pushed the users/matthias-springer/repl_all_uses_with branch from b217bce to 88d9afc Compare August 26, 2025 12:30
@matthias-springer matthias-springer force-pushed the users/matthias-springer/repl_all_uses_with branch from 88d9afc to 93ad0c4 Compare August 30, 2025 16:22
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/fix_flang_2 August 30, 2025 16:23
@matthias-springer matthias-springer force-pushed the users/matthias-springer/repl_all_uses_with branch from 93ad0c4 to 1962033 Compare August 30, 2025 17:06
@matthias-springer matthias-springer marked this pull request as ready for review August 30, 2025 17:16
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 30, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 30, 2025

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit generalizes replaceUsesOfBlockArgument to replaceAllUsesWith. In rollback mode, the same restrictions keep applying: a value cannot be replaced multiple times and a call to replaceAllUsesWith will replace all current and future uses of the from value.

replaceAllUsesWith is now fully supported and its behavior is consistent with the remaining dialect conversion API. Before this commit, replaceAllUsesWith was immediately reflected in the IR when running in rollback mode. After this commit, replaceAllUsesWith changes are materialized in a delayed fashion, at the end of the dialect conversion. This is consistent with the replaceUsesOfBlockArgument and replaceOp APIs.

replaceAllUsesExcept etc. are still not supported and will be deactivated on the ConversionPatternRewriter (when running in rollback mode) in a follow-up commit.

Note for LLVM integration: Replace replaceUsesOfBlockArgument with replaceAllUsesWith. If you are seeing failures, you may have patterns that use replaceAllUsesWith incorrectly (e.g., being called multiple times on the same value) or bypass the rewriter API entirely. E.g., such failures were mitigated in Flang by switching to the walk-patterns driver (#156171).

You can temporarily reactivate the old behavior by calling RewriterBase::replaceAllUsesWith. However, note that that behavior is faulty in a dialect conversion. E.g., the base RewriterBase::replaceAllUsesWith implementation does not see uses of the from value that have not materialized yet and will, therefore, not replace them.


Patch is 22.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155244.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+19-8)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+98-60)
  • (modified) mlir/test/Transforms/test-legalizer-rollback.mlir (+5-3)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+22-4)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+12-12)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..7b0b9cef9c5bd 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf18836c6..fe48e45a9b98c 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -854,15 +854,26 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Replace all the uses of the block argument `from` with `to`. This
-  /// function supports both 1:1 and 1:N replacements.
+  /// Replace all the uses of `from` with `to`. The type of `from` and `to` is
+  /// allowed to differ. The conversion driver will try to reconcile all type
+  /// mismatches that still exist at the end of the conversion with
+  /// materializations. This function supports both 1:1 and 1:N replacements.
   ///
-  /// Note: If `allowPatternRollback` is set to "true", this function replaces
-  /// all current and future uses of the block argument. This same block
-  /// block argument must not be replaced multiple times. Uses are not replaced
-  /// immediately but in a delayed fashion. Patterns may still see the original
-  /// uses when inspecting IR.
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+  /// Note: If `allowPatternRollback` is set to "true", this function behaves
+  /// slightly different:
+  ///
+  /// 1. All current and future uses of `from` are replaced. The same value must
+  ///    not be replaced multiple times. That's an API violation.
+  /// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
+  ///    may still see the original uses when inspecting IR.
+  /// 3. Uses within the same block that appear before the defining operation
+  ///    of the replacement value are not replaced. This allows users to
+  ///    perform certain replaceAllUsesExcept-style replacements, even though
+  ///    such API is not directly supported.
+  void replaceAllUsesWith(Value from, ValueRange to);
+  void replaceAllUsesWith(Value from, Value to) override {
+    replaceAllUsesWith(from, ValueRange{to});
+  }
 
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 42c76ed475b4c..93fe2edad5274 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
     Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
-    rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+    rewriter.replaceAllUsesWith(arg, valueArg);
   }
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ba109d96cf13..d72429298754f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -277,13 +277,14 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     BlockTypeConversion,
-    ReplaceBlockArg,
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
     CreateOperation,
-    UnresolvedMaterialization
+    UnresolvedMaterialization,
+    // Value rewrites
+    ReplaceValue
   };
 
   virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::CreateBlock &&
-           rewrite->getKind() <= Kind::ReplaceBlockArg;
+           rewrite->getKind() <= Kind::BlockTypeConversion;
   }
 
 protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
   Block *block;
 };
 
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+  /// Return the value that this rewrite operates on.
+  Value getValue() const { return value; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceValue;
+  }
+
+protected:
+  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Value value)
+      : IRRewrite(kind, rewriterImpl), value(value) {}
+
+  // The value that this rewrite operates on.
+  Value value;
+};
+
 /// Creation of a block. Block creations are immediately reflected in the IR.
 /// There is no extra work to commit the rewrite. During rollback, the newly
 /// created block is erased.
@@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   Block *newBlock;
 };
 
-/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// Replacing a value. This rewrite is not immediately reflected in the
 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
 /// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
+class ReplaceValueRewrite : public ValueRewrite {
 public:
-  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                         Block *block, BlockArgument arg,
-                         const TypeConverter *converter)
-      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
+  ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
+                      const TypeConverter *converter)
+      : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
         converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::ReplaceBlockArg;
+    return rewrite->getKind() == Kind::ReplaceValue;
   }
 
   void commit(RewriterBase &rewriter) override;
@@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
   void rollback() override;
 
 private:
-  BlockArgument arg;
-
-  /// The current type converter when the block argument was replaced.
+  /// The current type converter when the value was replaced.
   const TypeConverter *converter;
 };
 
@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// uses.
   void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
 
-  /// Replace the given block argument with the given values. The specified
+  /// Replace the uses of the given value with the given values. The specified
   /// converter is used to build materializations (if necessary).
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
-                                  const TypeConverter *converter);
+  void replaceAllUsesWith(Value from, ValueRange to,
+                          const TypeConverter *converter);
 
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   IRRewriter notifyingRewriter;
 
 #ifndef NDEBUG
-  /// A set of replaced block arguments. This set is for debugging purposes
-  /// only and it is maintained only if `allowPatternRollback` is set to
-  /// "true".
-  DenseSet<BlockArgument> replacedArgs;
+  /// A set of replaced values. This set is for debugging purposes only and it
+  /// is maintained only if `allowPatternRollback` is set to "true".
+  DenseSet<Value> replacedValues;
 
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
-                                   Value repl) {
+/// Replace all uses of `from` with `repl`.
+static void performReplaceValue(RewriterBase &rewriter, Value from,
+                                Value repl) {
   if (isa<BlockArgument>(repl)) {
-    rewriter.replaceAllUsesWith(arg, repl);
+    // `repl` is a block argument. Directly replace all uses.
+    rewriter.replaceAllUsesWith(from, repl);
     return;
   }
 
-  // If the replacement value is an operation, we check to make sure that we
-  // don't replace uses that are within the parent operation of the
-  // replacement value.
-  Operation *replOp = cast<OpResult>(repl).getOwner();
+  // If the replacement value is an operation, only replace those uses that:
+  // - are in a different block than the replacement operation, or
+  // - are in the same block but after the replacement operation.
+  //
+  // Example:
+  // ^bb0(%arg0: i32):
+  // %0 = "consumer"(%arg0) : (i32) -> (i32)
+  // "another_consumer"(%arg0) : (i32) -> ()
+  //
+  // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
+  // use in "another_consumer" but not the use in "consumer". When using the
+  // normal RewriterBase API, this would typically be done with
+  // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
+  // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
+  // it cannot be supported efficiently with `allowPatternRollback` set to
+  // "true". Therefore, the conversion driver is trying to be smart and replaces
+  // only those uses that do not lead to a dominance violation. E.g., the
+  // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
+  // behavior.
+  //
+  // TODO: As we move more and more towards `allowPatternRollback` set to
+  // "false", we should remove this special handling, in order to align the
+  // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
+  Operation *replOp = repl.getDefiningOp();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
   if (!repl)
     return;
-  performReplaceBlockArg(rewriter, arg, repl);
+  performReplaceValue(rewriter, value, repl);
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
+void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
               /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
               /*isPureTypeConversion=*/false)
               .front();
-      replaceUsesOfBlockArgument(origArg, mat, converter);
+      replaceAllUsesWith(origArg, mat, converter);
       continue;
     }
 
@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
-                                 converter);
+      replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
       continue;
     }
 
     // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    replaceUsesOfBlockArgument(origArg, replArgs, converter);
+    replaceAllUsesWith(origArg, replArgs, converter);
   }
 
   if (config.allowPatternRollback)
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
-void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
-    BlockArgument from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+    Value from, ValueRange to, const TypeConverter *converter) {
   if (!config.allowPatternRollback) {
     SmallVector<Value> toConv = llvm::to_vector(to);
     SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
     if (!repl)
       return;
 
-    performReplaceBlockArg(r, from, repl);
+    performReplaceValue(r, from, repl);
     return;
   }
 
 #ifndef NDEBUG
-  // Make sure that a block argument is not replaced multiple times. In
-  // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
-  // uses of the given block argument, but also all future uses that may be
-  // introduced by future pattern applications. Therefore, it does not make
-  // sense to call `replaceUsesOfBlockArgument` multiple times with the same
-  // block argument. Doing so would overwrite the mapping and mess with the
-  // internal state of the dialect conversion driver.
-  assert(!replacedArgs.contains(from) &&
-         "attempting to replace a block argument that was already replaced");
-  replacedArgs.insert(from);
+  // Make sure that a value is not replaced multiple times. In rollback mode,
+  // `replaceAllUsesWith` replaces not only all current uses of the given value,
+  // but also all future uses that may be introduced by future pattern
+  // applications. Therefore, it does not make sense to call
+  // `replaceAllUsesWith` multiple times with the same value. Doing so would
+  // overwrite the mapping and mess with the internal state of the dialect
+  // conversion driver.
+  assert(!replacedValues.contains(from) &&
+         "attempting to replace a value that was already replaced");
+  replacedValues.insert(from);
 #endif // NDEBUG
 
-  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
   mapping.map(from, to);
+  appendRewrite<ReplaceValueRewrite>(from, converter);
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           ValueRange to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
   LLVM_DEBUG({
-    impl->logger.startLine() << "** Replace Argument : '" << from << "'";
-    if (Operation *parentOp = from.getOwner()->getParentOp()) {
-      impl->logger.getOStream() << " (in region of '" << parentOp->getName()
-                                << "' (" << parentOp << ")\n";
-    } else {
-      impl->logger.getOStream() << " (unlinked block)\n";
+    impl->logger.startLine() << "** Replace Value : '" << from << "'";
+    if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+      if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
+        impl->logger.getOStream() << " (in region of '" << parentOp->getName()
+                                  << "' (" << parentOp << ")\n";
+      } else {
+        impl->logger.getOStream() << " (unlinked block)\n";
+      }
     }
   });
-  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
+  impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 
   // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
-    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   if (fastPath) {
     // Move all ops at once.
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 460911fd88ad1..71e11782e14b0 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -49,14 +49,16 @@ func.func @create_illegal_block() {
 // expected-remark@+1{{applyPartialConversion failed}}
 module {
 func.func @undo_block_arg_replace() {
-  // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
-  "test.block_arg_replace"() ({
+  "test.legal_op"() ({
   ^bb0(%arg0: i32, %arg1: i16):
     // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: "test.value_replace"(%[[ARG0]], %[[ARG1]]) {trigger_rollback}
     // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
 
+    // expected-error@+1{{failed to legalize operation 'test.value_replace' that was explicitly marked illegal}}
+    "test.value_replace"(%arg0, %arg1) {trigger_rollback} : (i32, i16) -> ()
     "test.return"(%arg0) : (i32) -> ()
-  }) {trigger_rollback} : () -> ()
+  }) : () -> ()
   return
 }
 }
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3fa42ff6b2757..94c5bb4e93b06 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -269,12 +269,14 @@ builtin.module {
 
 // CHECK-LABEL: @replace_block_arg_1_to_n
 func.func @replace_block_arg_1_to_n() {
-  // CHECK: "test.block_arg_replace"
-  "test.block_arg_replace"() ({
+  // CHECK: "test.legal_op"
+  "test.legal_op"() ({
   ^bb0(%arg0: i32, %arg1: i16):
-    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
-    // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[ARG1]]) {is_legal} : (i32, i16) -> ()
     // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+    "test.value_replace"(%arg0, %arg1) : (i32, i16) -> ()
     "test.return"(%arg0) : (i32) -> ()
   }) : () -> ()
   "test.return"() : () -> ()
@@ -282,6 +284,22 @@ func.func @replace_block_arg_1_to_n() {
 
 // -----
 
+// CHECK-LABEL: @replace_op_result_1_to_n
+func.func @replace_op_result_1_to_n() {
+  // CHECK: %[[orig:.*]] = "test.legal_op"() : () -> i32
+  // CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i16
+  %0 = "test.legal_op"() : () -> i32
+  %1 = "test.legal_op"() : () -> i16
+
+  // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[repl]], %[[repl]]) : (i16, i16) -> i32
+  // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[repl]]) {is_legal} : (i32, i16) -> ()
+  // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+  "test.value_replace"(%0, %1) : (i32, i16) -> ()
+  "test.return"(%0) : (i32) -> ()
+}
+
+// -----
+
 // Check that a conversion pattern on `test.blackhole` can mark the producer
 // for deletion.
 // CHECK-LABEL: @blackhole
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381ec471d6..93b007c792ad9 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -952,19 +952,19 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
-/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
-struct TestBlockArgReplace : public ConversionPattern {
-  TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
-      : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
-                          ctx) {}
+/// A simple pattern that tests the "replaceAllUsesWith"...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 30, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit generalizes replaceUsesOfBlockArgument to replaceAllUsesWith. In rollback mode, the same restrictions keep applying: a value cannot be replaced multiple times and a call to replaceAllUsesWith will replace all current and future uses of the from value.

replaceAllUsesWith is now fully supported and its behavior is consistent with the remaining dialect conversion API. Before this commit, replaceAllUsesWith was immediately reflected in the IR when running in rollback mode. After this commit, replaceAllUsesWith changes are materialized in a delayed fashion, at the end of the dialect conversion. This is consistent with the replaceUsesOfBlockArgument and replaceOp APIs.

replaceAllUsesExcept etc. are still not supported and will be deactivated on the ConversionPatternRewriter (when running in rollback mode) in a follow-up commit.

Note for LLVM integration: Replace replaceUsesOfBlockArgument with replaceAllUsesWith. If you are seeing failures, you may have patterns that use replaceAllUsesWith incorrectly (e.g., being called multiple times on the same value) or bypass the rewriter API entirely. E.g., such failures were mitigated in Flang by switching to the walk-patterns driver (#156171).

You can temporarily reactivate the old behavior by calling RewriterBase::replaceAllUsesWith. However, note that that behavior is faulty in a dialect conversion. E.g., the base RewriterBase::replaceAllUsesWith implementation does not see uses of the from value that have not materialized yet and will, therefore, not replace them.


Patch is 22.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155244.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+19-8)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+98-60)
  • (modified) mlir/test/Transforms/test-legalizer-rollback.mlir (+5-3)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+22-4)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+12-12)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..7b0b9cef9c5bd 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf18836c6..fe48e45a9b98c 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -854,15 +854,26 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Replace all the uses of the block argument `from` with `to`. This
-  /// function supports both 1:1 and 1:N replacements.
+  /// Replace all the uses of `from` with `to`. The type of `from` and `to` is
+  /// allowed to differ. The conversion driver will try to reconcile all type
+  /// mismatches that still exist at the end of the conversion with
+  /// materializations. This function supports both 1:1 and 1:N replacements.
   ///
-  /// Note: If `allowPatternRollback` is set to "true", this function replaces
-  /// all current and future uses of the block argument. This same block
-  /// block argument must not be replaced multiple times. Uses are not replaced
-  /// immediately but in a delayed fashion. Patterns may still see the original
-  /// uses when inspecting IR.
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+  /// Note: If `allowPatternRollback` is set to "true", this function behaves
+  /// slightly different:
+  ///
+  /// 1. All current and future uses of `from` are replaced. The same value must
+  ///    not be replaced multiple times. That's an API violation.
+  /// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
+  ///    may still see the original uses when inspecting IR.
+  /// 3. Uses within the same block that appear before the defining operation
+  ///    of the replacement value are not replaced. This allows users to
+  ///    perform certain replaceAllUsesExcept-style replacements, even though
+  ///    such API is not directly supported.
+  void replaceAllUsesWith(Value from, ValueRange to);
+  void replaceAllUsesWith(Value from, Value to) override {
+    replaceAllUsesWith(from, ValueRange{to});
+  }
 
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 42c76ed475b4c..93fe2edad5274 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
     Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
-    rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+    rewriter.replaceAllUsesWith(arg, valueArg);
   }
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ba109d96cf13..d72429298754f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -277,13 +277,14 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     BlockTypeConversion,
-    ReplaceBlockArg,
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
     CreateOperation,
-    UnresolvedMaterialization
+    UnresolvedMaterialization,
+    // Value rewrites
+    ReplaceValue
   };
 
   virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::CreateBlock &&
-           rewrite->getKind() <= Kind::ReplaceBlockArg;
+           rewrite->getKind() <= Kind::BlockTypeConversion;
   }
 
 protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
   Block *block;
 };
 
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+  /// Return the value that this rewrite operates on.
+  Value getValue() const { return value; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceValue;
+  }
+
+protected:
+  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Value value)
+      : IRRewrite(kind, rewriterImpl), value(value) {}
+
+  // The value that this rewrite operates on.
+  Value value;
+};
+
 /// Creation of a block. Block creations are immediately reflected in the IR.
 /// There is no extra work to commit the rewrite. During rollback, the newly
 /// created block is erased.
@@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   Block *newBlock;
 };
 
-/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// Replacing a value. This rewrite is not immediately reflected in the
 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
 /// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
+class ReplaceValueRewrite : public ValueRewrite {
 public:
-  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                         Block *block, BlockArgument arg,
-                         const TypeConverter *converter)
-      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
+  ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
+                      const TypeConverter *converter)
+      : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
         converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::ReplaceBlockArg;
+    return rewrite->getKind() == Kind::ReplaceValue;
   }
 
   void commit(RewriterBase &rewriter) override;
@@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
   void rollback() override;
 
 private:
-  BlockArgument arg;
-
-  /// The current type converter when the block argument was replaced.
+  /// The current type converter when the value was replaced.
   const TypeConverter *converter;
 };
 
@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// uses.
   void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
 
-  /// Replace the given block argument with the given values. The specified
+  /// Replace the uses of the given value with the given values. The specified
   /// converter is used to build materializations (if necessary).
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
-                                  const TypeConverter *converter);
+  void replaceAllUsesWith(Value from, ValueRange to,
+                          const TypeConverter *converter);
 
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   IRRewriter notifyingRewriter;
 
 #ifndef NDEBUG
-  /// A set of replaced block arguments. This set is for debugging purposes
-  /// only and it is maintained only if `allowPatternRollback` is set to
-  /// "true".
-  DenseSet<BlockArgument> replacedArgs;
+  /// A set of replaced values. This set is for debugging purposes only and it
+  /// is maintained only if `allowPatternRollback` is set to "true".
+  DenseSet<Value> replacedValues;
 
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
-                                   Value repl) {
+/// Replace all uses of `from` with `repl`.
+static void performReplaceValue(RewriterBase &rewriter, Value from,
+                                Value repl) {
   if (isa<BlockArgument>(repl)) {
-    rewriter.replaceAllUsesWith(arg, repl);
+    // `repl` is a block argument. Directly replace all uses.
+    rewriter.replaceAllUsesWith(from, repl);
     return;
   }
 
-  // If the replacement value is an operation, we check to make sure that we
-  // don't replace uses that are within the parent operation of the
-  // replacement value.
-  Operation *replOp = cast<OpResult>(repl).getOwner();
+  // If the replacement value is an operation, only replace those uses that:
+  // - are in a different block than the replacement operation, or
+  // - are in the same block but after the replacement operation.
+  //
+  // Example:
+  // ^bb0(%arg0: i32):
+  // %0 = "consumer"(%arg0) : (i32) -> (i32)
+  // "another_consumer"(%arg0) : (i32) -> ()
+  //
+  // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
+  // use in "another_consumer" but not the use in "consumer". When using the
+  // normal RewriterBase API, this would typically be done with
+  // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
+  // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
+  // it cannot be supported efficiently with `allowPatternRollback` set to
+  // "true". Therefore, the conversion driver is trying to be smart and replaces
+  // only those uses that do not lead to a dominance violation. E.g., the
+  // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
+  // behavior.
+  //
+  // TODO: As we move more and more towards `allowPatternRollback` set to
+  // "false", we should remove this special handling, in order to align the
+  // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
+  Operation *replOp = repl.getDefiningOp();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
   if (!repl)
     return;
-  performReplaceBlockArg(rewriter, arg, repl);
+  performReplaceValue(rewriter, value, repl);
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
+void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
               /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
               /*isPureTypeConversion=*/false)
               .front();
-      replaceUsesOfBlockArgument(origArg, mat, converter);
+      replaceAllUsesWith(origArg, mat, converter);
       continue;
     }
 
@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
-                                 converter);
+      replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
       continue;
     }
 
     // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    replaceUsesOfBlockArgument(origArg, replArgs, converter);
+    replaceAllUsesWith(origArg, replArgs, converter);
   }
 
   if (config.allowPatternRollback)
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
-void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
-    BlockArgument from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+    Value from, ValueRange to, const TypeConverter *converter) {
   if (!config.allowPatternRollback) {
     SmallVector<Value> toConv = llvm::to_vector(to);
     SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
     if (!repl)
       return;
 
-    performReplaceBlockArg(r, from, repl);
+    performReplaceValue(r, from, repl);
     return;
   }
 
 #ifndef NDEBUG
-  // Make sure that a block argument is not replaced multiple times. In
-  // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
-  // uses of the given block argument, but also all future uses that may be
-  // introduced by future pattern applications. Therefore, it does not make
-  // sense to call `replaceUsesOfBlockArgument` multiple times with the same
-  // block argument. Doing so would overwrite the mapping and mess with the
-  // internal state of the dialect conversion driver.
-  assert(!replacedArgs.contains(from) &&
-         "attempting to replace a block argument that was already replaced");
-  replacedArgs.insert(from);
+  // Make sure that a value is not replaced multiple times. In rollback mode,
+  // `replaceAllUsesWith` replaces not only all current uses of the given value,
+  // but also all future uses that may be introduced by future pattern
+  // applications. Therefore, it does not make sense to call
+  // `replaceAllUsesWith` multiple times with the same value. Doing so would
+  // overwrite the mapping and mess with the internal state of the dialect
+  // conversion driver.
+  assert(!replacedValues.contains(from) &&
+         "attempting to replace a value that was already replaced");
+  replacedValues.insert(from);
 #endif // NDEBUG
 
-  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
   mapping.map(from, to);
+  appendRewrite<ReplaceValueRewrite>(from, converter);
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           ValueRange to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
   LLVM_DEBUG({
-    impl->logger.startLine() << "** Replace Argument : '" << from << "'";
-    if (Operation *parentOp = from.getOwner()->getParentOp()) {
-      impl->logger.getOStream() << " (in region of '" << parentOp->getName()
-                                << "' (" << parentOp << ")\n";
-    } else {
-      impl->logger.getOStream() << " (unlinked block)\n";
+    impl->logger.startLine() << "** Replace Value : '" << from << "'";
+    if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+      if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
+        impl->logger.getOStream() << " (in region of '" << parentOp->getName()
+                                  << "' (" << parentOp << ")\n";
+      } else {
+        impl->logger.getOStream() << " (unlinked block)\n";
+      }
     }
   });
-  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
+  impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 
   // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
-    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   if (fastPath) {
     // Move all ops at once.
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 460911fd88ad1..71e11782e14b0 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -49,14 +49,16 @@ func.func @create_illegal_block() {
 // expected-remark@+1{{applyPartialConversion failed}}
 module {
 func.func @undo_block_arg_replace() {
-  // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
-  "test.block_arg_replace"() ({
+  "test.legal_op"() ({
   ^bb0(%arg0: i32, %arg1: i16):
     // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: "test.value_replace"(%[[ARG0]], %[[ARG1]]) {trigger_rollback}
     // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
 
+    // expected-error@+1{{failed to legalize operation 'test.value_replace' that was explicitly marked illegal}}
+    "test.value_replace"(%arg0, %arg1) {trigger_rollback} : (i32, i16) -> ()
     "test.return"(%arg0) : (i32) -> ()
-  }) {trigger_rollback} : () -> ()
+  }) : () -> ()
   return
 }
 }
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3fa42ff6b2757..94c5bb4e93b06 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -269,12 +269,14 @@ builtin.module {
 
 // CHECK-LABEL: @replace_block_arg_1_to_n
 func.func @replace_block_arg_1_to_n() {
-  // CHECK: "test.block_arg_replace"
-  "test.block_arg_replace"() ({
+  // CHECK: "test.legal_op"
+  "test.legal_op"() ({
   ^bb0(%arg0: i32, %arg1: i16):
-    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
-    // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[ARG1]]) {is_legal} : (i32, i16) -> ()
     // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+    "test.value_replace"(%arg0, %arg1) : (i32, i16) -> ()
     "test.return"(%arg0) : (i32) -> ()
   }) : () -> ()
   "test.return"() : () -> ()
@@ -282,6 +284,22 @@ func.func @replace_block_arg_1_to_n() {
 
 // -----
 
+// CHECK-LABEL: @replace_op_result_1_to_n
+func.func @replace_op_result_1_to_n() {
+  // CHECK: %[[orig:.*]] = "test.legal_op"() : () -> i32
+  // CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i16
+  %0 = "test.legal_op"() : () -> i32
+  %1 = "test.legal_op"() : () -> i16
+
+  // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[repl]], %[[repl]]) : (i16, i16) -> i32
+  // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[repl]]) {is_legal} : (i32, i16) -> ()
+  // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+  "test.value_replace"(%0, %1) : (i32, i16) -> ()
+  "test.return"(%0) : (i32) -> ()
+}
+
+// -----
+
 // Check that a conversion pattern on `test.blackhole` can mark the producer
 // for deletion.
 // CHECK-LABEL: @blackhole
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381ec471d6..93b007c792ad9 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -952,19 +952,19 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
-/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
-struct TestBlockArgReplace : public ConversionPattern {
-  TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
-      : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
-                          ctx) {}
+/// A simple pattern that tests the "replaceAllUsesWith"...
[truncated]

// don't replace uses that are within the parent operation of the
// replacement value.
Operation *replOp = cast<OpResult>(repl).getOwner();
// If the replacement value is an operation, only replace those uses that:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: The implementation here has not changed. I just added more documentation that explains the behavior.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/fix_flang_2 branch from f4dbd0d to fee875b Compare August 30, 2025 17:25
Base automatically changed from users/matthias-springer/fix_flang_2 to main August 30, 2025 17:50
matthias-springer added a commit that referenced this pull request Aug 30, 2025
…56171)

This pass uses the rewriter API incorrectly: it calls
`replaceAllUsesWith`. This will start failing with #155244.

Instead of a dialect conversion, use the walk-patterns driver, which is
also more efficient.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/repl_all_uses_with branch from 1962033 to 2031bc5 Compare August 30, 2025 17:55
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Aug 30, 2025
Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flang changes LGTM. Thank you for the fixes.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % 1 public API comment

@matthias-springer matthias-springer force-pushed the users/matthias-springer/repl_all_uses_with branch from 2031bc5 to e763834 Compare September 3, 2025 13:17
@matthias-springer matthias-springer force-pushed the users/matthias-springer/repl_all_uses_with branch from e763834 to af99610 Compare September 3, 2025 13:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants