Skip to content

Conversation

PragmaTwice
Copy link
Contributor

@PragmaTwice PragmaTwice commented Sep 8, 2025

In #94714, we add a python function apply_patterns_and_fold_greedily which accepts an MlirModule as the argument type. However, sometimes we want to apply patterns with an MlirOperation argument, and there is currently no python API to convert an MlirOperation to MlirModule.

So here we overload this function apply_patterns_and_fold_greedily to do this (also a corresponding new C API mlirApplyPatternsAndFoldGreedilyWithOp)

@PragmaTwice PragmaTwice marked this pull request as ready for review September 8, 2025 15:14
@llvmbot llvmbot added the mlir label Sep 8, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 8, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

In #94714, we add a python function apply_patterns_and_fold_greedily which accepts an MlirModule as the argument type. However, sometimes we want to apply patterns with an MlirOperation argument, and there is currently no python API to convert an MlirOperation to MlirModule.

So here we provide a new python function apply_patterns_and_fold_greedily_with_op to do this (also a corresponding new C API mlirApplyPatternsAndFoldGreedilyWithOp).


Full diff: https://github.com/llvm/llvm-project/pull/157487.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Rewrite.h (+4)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+20-10)
  • (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+7)
  • (modified) mlir/test/python/integration/dialects/pdl.py (+34-15)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+    MlirOperation op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig);
+
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..feb22485c5609 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,24 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](MlirModule module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+         if (mlirLogicalResultIsFailure(status))
+           throw std::runtime_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applys the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily_with_op",
+          [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
+            if (mlirLogicalResultIsFailure(status))
+              throw std::runtime_error(
+                  "pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applys the given patterns to the given op greedily while folding "
+          "results.");
 }
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+                                       MlirFrozenRewritePatternSet patterns,
+                                       MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 923af29a71ad7..42d3707017e17 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,20 +16,7 @@ def construct_and_print_in_module(f):
             print(module)
     return f
 
-
-# CHECK-LABEL: TEST: test_add_to_mul
-# CHECK: arith.muli
-@construct_and_print_in_module
-def test_add_to_mul(module_):
-    index_type = IndexType.get()
-
-    # Create a test case.
-    @module(sym_name="ir")
-    def ir():
-        @func.func(index_type, index_type)
-        def add_func(a, b):
-            return arith.addi(a, b)
-
+def get_pdl_patterns():
     # Create a rewrite from add to mul. This will match
     # - operation name is arith.addi
     # - operands are index types.
@@ -61,7 +48,39 @@ def rew():
     # not yet captured Python side/has sharp edges. So best to construct the
     # module and PDL module in same scope.
     # FIXME: This should be made more robust.
-    frozen = PDLModule(m).freeze()
+    return PDLModule(m).freeze()
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul(module_):
+    index_type = IndexType.get()
+
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    frozen = get_pdl_patterns()
     # Could apply frozen pattern set multiple times.
     apply_patterns_and_fold_greedily(module_, frozen)
     return module_
+
+# CHECK-LABEL: TEST: test_add_to_mul_with_op
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul_with_op(module_):
+    index_type = IndexType.get()
+
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    frozen = get_pdl_patterns()
+    apply_patterns_and_fold_greedily_with_op(module_.operation, frozen)
+    return module_

Copy link

github-actions bot commented Sep 8, 2025

✅ With the latest revision this PR passed the Python code formatter.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Let me know when you're ready to merge

@PragmaTwice
Copy link
Contributor Author

Ready to merge now : )

@makslevental makslevental enabled auto-merge (squash) September 8, 2025 16:02
@makslevental makslevental merged commit aac4eb5 into llvm:main Sep 8, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants