Skip to content

Conversation

PragmaTwice
Copy link
Contributor

@PragmaTwice PragmaTwice commented Aug 29, 2025

STATUS: This PR is ready-for-review but quite experimental :)

It tries to close #155996.

This PR exports a class mlir.passmanager.Pass for Python-side to use for defining new MLIR passes.

This is a simple example of a Python-defined pass.

from mlir.passmanager import PassManager

def demo_pass_1(op):
    # do something with op
    pass

class DemoPass:
    def __init__(self, ...):
        pass
    def __call__(op):
        # do something
        pass

demo_pass_2 = DemoPass(..)

pm = PassManager('any', ctx)
pm.add(demo_pass_1)
pm.add(demo_pass_2)
pm.add("registered-passes")
pm.run(..)

TODO list:

Comment on lines 304 to 306
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
MlirOperation op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure which name is suitable.. So the current name is just a placeholder. And modifying mlirApplyPatternsAndFoldGreedily seems a breaking change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a breaking change, but not sure how widely and also C API is really best effort stable (meaning, we try not to break it). I probably should have included a Module suffix on the original, and then this could have gone without. How many usages in the wild can you find of this call?

(Nit: ForOp suffix makes me think it's related to ForOp).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I perform a simple check on public repos via GitHub search (https://github.com/search?q=mlirApplyPatternsAndFoldGreedily+&type=code) and there seems to be some use cases of this API so I'm not sure if we should change it or not (maybe leave to maintainers for this decision : ).

Currently I'm going to rename it to mlirApplyPatternsAndFoldGreedilyWithOp to avoid such confusion with ForOp as you mentioned : )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please extract this change in an independent PR? This is largely unrelated (your test python pass could just do another kind of rewrite right now).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed in ca80408.

Copy link
Contributor Author

@PragmaTwice PragmaTwice Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please extract this change in an independent PR? This is largely unrelated (your test python pass could just do another kind of rewrite right now).

Ahh got it.

I think I can do it via op/block/region/walk/erase APIs directly, but it maybe a little ugly since the normal rewriter API is not ported to python yet AFAIK?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep - no rewriter API in Python but you can use relevant methods on ops directly, etc. (Here's one example: https://github.com/libxsmm/tpp-mlir/pull/1064/files#diff-7aa62724b21b998da9cf032da6e6b77bbdb664258a5dca850f9509a5459646f6R110-R118 - @makslevental has more.)

That should be enough to demonstrate that your test pass is actually running and, if kept simple, will not be that ugly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just use apply_patterns_and_fold_greedily here and then add apply_patterns_and_fold_greedily_with_op in a follow-up (the existing test passes just fine with apply_patterns_and_fold_greedily).

Copy link

github-actions bot commented Aug 29, 2025

✅ With the latest revision this PR passed the Python code formatter.

Copy link

github-actions bot commented Aug 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
mlirStringRefCreate(argument.data(), argument.length()),
mlirStringRefCreate(description.data(), description.length()),
mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dependent dialects are not yet supported.

Copy link
Contributor

@rolfmorel rolfmorel Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is relevant for the discussion on whether a function as a pass suffices. I think dependent dialects demonstrate that, at least when it comes to registering them, passes are not just functions. They also have a bit of metadata.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, as you can see we pass 0, nullptr here, and it can be supported if we change them to nDialects, dialectListPtr. So I don't think it is a proof of a problem in this design. (it is supported by ExternalPass)

A function in python is just an object with __call__ method implemented. So I think we can bind any information with this object by assigning them to attrs of this object, if needed. 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PragmaTwice and I already discussed this offline: to support dependent dialects we need to add nanobindings for MlirDialectHandle. Easy to do since it's just an opaque handle. If it's a high-priority I can quickly do that as a follow-up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With Python we can monkey patch arbitrary things arbitrarily. Doesn't mean that's a good idea / good to base your design on. As such, binding new attributes to an object is not the way to go.

If you and @makslevental still prefer just passing a function, then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass. Alternatively, have both a Pass class to inherit with a dependent_dialects property (and potentially with __call__ implemented as a call to run()) which the registration API automatically uses and have a mechanism for wrapping up a callback. The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass

Yes that's exactly the follow-up I had in mind.

The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.

It's the lifetime management that is the issue - the C++ APIs expect ownership of the Pass object. But there's simply no way to express "unique ownership" in Python. That's why I rewrote @PragmaTwice's original PR (which isn't very different from what you propose) to only manage the lifetime of a single Python object - the run callback.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay maybe we can use nanobind itself to get the right semantics - let me try again using https://nanobind.readthedocs.io/en/latest/ownership.html#unique-pointers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah this won't blend - you have to create the object in C++ in order for the unique_ptr magic to work. So that won't blend with "subclass and instantiate a nanobind class in Python".

Comment on lines 14 to 16

#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "nanobind/trampoline.h"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The empty line is to workaround clang-format: otherwise it will report that "mlir-c/Bindings/Python/Interop.h" should be put before "mlir/Bindings/Python/Nanobind.h" : )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should do

// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
// clang-format on

(the ON WINDOWS isn't currently there but it should be because I discovered recently that it's in fact only windows that this matters for).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh thank you for your suggestion. It is done in c8c2fae.

@PragmaTwice PragmaTwice marked this pull request as ready for review August 29, 2025 17:10
@llvmbot llvmbot added the mlir label Aug 29, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

STATUS: This PR is work-in-progress now :)

It tries to close #155996.

This PR exports a class mlir.passmanager.Pass for Python-side to use for defining new MLIR passes.

This is a simple example of a Python-defined pass.

from mlir.passmanager import Pass, PassManager

class DemoPass(Pass):
  def run(op):
    # do something with op
    pass

pm = PassManager('any', ctx)
pm.add(DemoPass())
pm.run(..)

TODO list:

  • tests for this change
  • interop with PDL rewriting
  • support to clone passes
  • use Python-native ref-count for lifetime

Full diff: https://github.com/llvm/llvm-project/pull/156000.diff

7 Files Affected:

  • (modified) mlir/include/mlir-c/Rewrite.h (+4)
  • (modified) mlir/lib/Bindings/Python/MainModule.cpp (+1)
  • (modified) mlir/lib/Bindings/Python/Pass.cpp (+103)
  • (modified) mlir/lib/Bindings/Python/Pass.h (+1)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+21-10)
  • (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+7)
  • (added) mlir/test/python/pass.py (+79)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..21ae236d6f73f 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
+    MlirOperation op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig);
+
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..590e862a8d358 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
+  populatePassSubmodule(passModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..3fd27fed34587 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,7 +11,9 @@
 #include "IRModule.h"
 #include "mlir-c/Pass.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -20,6 +22,81 @@ using namespace mlir::python;
 
 namespace {
 
+// A base class for defining passes in Python
+// Users are expected to subclass this and implement the `run` method, e.g.
+// ```
+// class MyPass(mlir.passmanager.Pass):
+//   def __init__(self):
+//     super().__init__("MyPass", ..)
+//     # other init stuff..
+//   def run(self, operation):
+//     # do something with operation..
+//     pass
+// ```
+class PyPassBase {
+public:
+  PyPassBase(std::string name, std::string argument, std::string description,
+             std::string opName)
+      : name(std::move(name)), argument(std::move(argument)),
+        description(std::move(description)), opName(std::move(opName)) {
+    callbacks.construct = [](void *obj) {};
+    callbacks.destruct = [](void *obj) {
+      nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+    };
+    callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
+      auto handle = nb::handle(static_cast<PyObject *>(obj));
+      nb::cast<PyPassBase *>(handle)->run(op);
+    };
+    callbacks.clone = [](void *obj) -> void * {
+      nb::object copy = nb::module_::import_("copy");
+      nb::object deepcopy = copy.attr("deepcopy");
+      return deepcopy(obj).release().ptr();
+    };
+    callbacks.initialize = nullptr;
+  }
+
+  // this method should be overridden by subclasses in Python.
+  virtual void run(MlirOperation op) = 0;
+
+  virtual ~PyPassBase() = default;
+
+  // Make an MlirPass instance on-the-fly that wraps this object.
+  // Note that passmanager will take the ownership of the returned
+  // object and release it when appropriate.
+  MlirPass make() {
+    auto *obj = nb::find(this).release().ptr();
+    return mlirCreateExternalPass(
+        mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
+        mlirStringRefCreate(argument.data(), argument.length()),
+        mlirStringRefCreate(description.data(), description.length()),
+        mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
+        callbacks, obj);
+  }
+
+  const std::string &getName() const { return name; }
+  const std::string &getArgument() const { return argument; }
+  const std::string &getDescription() const { return description; }
+  const std::string &getOpName() const { return opName; }
+
+private:
+  MlirExternalPassCallbacks callbacks;
+
+  std::string name;
+  std::string argument;
+  std::string description;
+  std::string opName;
+};
+
+// A trampoline class upon PyPassBase.
+// Refer to
+// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
+class PyPass : PyPassBase {
+public:
+  NB_TRAMPOLINE(PyPassBase, 1);
+
+  void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
+};
+
 /// Owning Wrapper around a PassManager.
 class PyPassManager {
 public:
@@ -52,6 +129,26 @@ class PyPassManager {
 
 } // namespace
 
+void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the Python-defined Pass interface
+  //----------------------------------------------------------------------------
+  nb::class_<PyPassBase, PyPass>(m, "Pass")
+      .def(nb::init<std::string, std::string, std::string, std::string>(),
+           "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
+           "op_name"_a = "", "Create a new Pass.")
+      .def("run", &PyPassBase::run, "operation"_a,
+           "Run the pass on the provided operation.")
+      .def_prop_ro("name",
+                   [](const PyPassBase &self) { return self.getName(); })
+      .def_prop_ro("argument",
+                   [](const PyPassBase &self) { return self.getArgument(); })
+      .def_prop_ro("description",
+                   [](const PyPassBase &self) { return self.getDescription(); })
+      .def_prop_ro("op_name",
+                   [](const PyPassBase &self) { return self.getOpName(); });
+}
+
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
@@ -157,6 +254,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
+      .def(
+          "add",
+          [](PyPassManager &passManager, PyPassBase &pass) {
+            mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
+          },
+          "pass"_a, "Add a python-defined pass to the pass manager.")
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op,
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..ba3fbb707fed7 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -15,6 +15,7 @@ namespace mlir {
 namespace python {
 
 void populatePassManagerSubmodule(nanobind::module_ &m);
+void populatePassSubmodule(nanobind::module_ &m);
 
 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..675bd685ec2db 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](MlirModule module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+         if (mlirLogicalResultIsFailure(status))
+           // FIXME: Not sure this is the right error to throw here.
+           throw nb::value_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applys the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily_for_op",
+          [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
+            if (mlirLogicalResultIsFailure(status))
+              // FIXME: Not sure this is the right error to throw here.
+              throw nb::value_error("pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applys the given patterns to the given op greedily while folding "
+          "results.");
 }
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..d606445cfad31 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyForOp(MlirOperation op,
+                                      MlirFrozenRewritePatternSet patterns,
+                                      MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
new file mode 100644
index 0000000000000..1c338e4bc6b49
--- /dev/null
+++ b/mlir/test/python/pass.py
@@ -0,0 +1,79 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def log(*args):
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+
+def make_pdl_module():
+    with Location.unknown():
+        pdl_module = Module.create()
+        with InsertionPoint(pdl_module.body):
+            # Change all arith.addi with index types to arith.muli.
+            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+            def pat():
+                # Match arith.addi with index types.
+                index_type = pdl.TypeOp(IndexType.get())
+                operand0 = pdl.OperandOp(index_type)
+                operand1 = pdl.OperandOp(index_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[index_type]
+                )
+
+                # Replace the matched op with arith.muli.
+                @pdl.rewrite()
+                def rew():
+                    newOp = pdl.OperationOp(
+                        name="arith.muli", args=[operand0, operand1], types=[index_type]
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
+
+        return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+    with Context():
+        pdl_module = make_pdl_module()
+
+        class CustomPass(Pass):
+            def __init__(self):
+                super().__init__("CustomPass", op_name="builtin.module")
+
+            def run(self, m):
+                frozen = PDLModule(pdl_module).freeze()
+                apply_patterns_and_fold_greedily_for_op(m, frozen)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: index, %b: index) -> index {
+                %sum = arith.addi %a, %b : index
+                return %sum : index
+              }
+            }
+        """
+        )
+
+        # CHECK-LABEL: Dump After CustomPass
+        # CHECK: arith.muli
+        pm = PassManager("any")
+        pm.enable_ir_printing()
+        pm.add(CustomPass())
+        pm.run(module)

nb::object deepcopy = copy.attr("deepcopy");
return deepcopy(obj).release().ptr();
};
callbacks.initialize = nullptr;
Copy link
Contributor Author

@PragmaTwice PragmaTwice Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initialize is not ported to python side yet.

Comment on lines 139 to 142
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passModule);
populatePassSubmodule(passModule);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that pass is a keyword in python so that name like mlir.pass.Pass doesn't work. 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, pass vs pass :-) I don't have a good suggestion given how established that is MLIR/compilers and Python side ... MlirPass or PyPass or pydefpass would be most obvious.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently the full name of the base class is mlir.passmanager.Pass. Is it good enough or we'd better to rename it with another module name?

Copy link
Contributor

@rolfmorel rolfmorel Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that PEP8 recommends appending a underscore in case of names of arguments / attributes clashing with reserved keywords, I feel mlir.pass_.Pass is an option as it is not too surprising for me personally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think putting a trailing underscore in a namespace is a little awkward. How about mlir.passes? or mlir.passinfra? or I dunno something like that that I'm failing to come up with right now...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. For now, I will rename it to mlir.passes.Pass. Let me know if anyone think of a better name : )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 01e68c5.

// Note that passmanager will take the ownership of the returned
// object and release it when appropriate.
MlirPass make() {
auto *obj = nb::find(this).release().ptr();
Copy link
Contributor Author

@PragmaTwice PragmaTwice Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About lifetime of the python object: here we increase the ref count by nb::find (and by release() we avoid decrease the count here), and when the ExternalPass is destructed, callback.destructor is called so that dec_ref() is called for this object.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to take another pass (not pun intended) but here are some initial non-nit comments.

Comment on lines 139 to 142
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passModule);
populatePassSubmodule(passModule);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think putting a trailing underscore in a namespace is a little awkward. How about mlir.passes? or mlir.passinfra? or I dunno something like that that I'm failing to come up with right now...

Comment on lines 14 to 16

#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "nanobind/trampoline.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should do

// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
// clang-format on

(the ON WINDOWS isn't currently there but it should be because I discovered recently that it's in fact only windows that this matters for).

//----------------------------------------------------------------------------
// Mapping of the Python-defined Pass interface
//----------------------------------------------------------------------------
nb::class_<PyPassBase, PyPass>(m, "Pass")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this backwards? cf <PyOperation, PyOperationBase>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh this is a quite interesting case that I just did some research on : )

Refer to the nanobind documentation (https://nanobind.readthedocs.io/en/latest/api_core.html#_CPPv4I0DpEN8nanobind6class_E):

template<typename T, typename ...Ts>
class class_ : public object

The variable length parameter Ts is optional and can be used to specify the base class of T and/or an alias needed to realize trampoline classes.

So in the case you mentioned (<PyOperation, PyOperationBase>), the parameter Ts (PyOperationBase) is a base class of T (PyOperation); and for the case here (<PyPassBase, PyPass>), the Ts (PyPass) is a trampoline class of T (PyPassBase). So I think both of them is correct here.

For example, in the documentation of trampoline classes (https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python), we can see such an instance of nb::class_:

nb::class_<Dog, PyDog /* <-- trampoline */>(m, "Dog")

And here PyDog is a derived class of Dog but also a trampoline class of Dog : )

@makslevental
Copy link
Contributor

makslevental commented Sep 7, 2025

Sorry for taking so long to review!

Firstly, I want to repeat/reiterate/state that I'm fully in support of this functionality (and I believe many others are as well). Coincidentally, at $DAYJOB we want to provide/support something very similar to our users very soon. So thank you very much for the implementation!

Secondly, I'd like to propose a slightly different alternative: #157369. The core of the difference being that instead of providing an API based on classes (PyPass, PyPassBase) we just provide an API that expects the user to supply a run callback. IMHO this greatly simplifies the design/API, but more importantly it simplifies the lifetime concerns; the pass itself is wholly owned by the passmanager (as is expected on the C++ side) and only the lifetime of the nb::callable needs matching inc_ref/dec_refs.

@PragmaTwice take a look - let me know what you think. If you agree/like the alternative, feel free to just copy-paste into this PR (no need to merge or whatever).

makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 7, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 7, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 7, 2025
@PragmaTwice
Copy link
Contributor Author

PragmaTwice commented Sep 8, 2025

@PragmaTwice take a look - let me know what you think. If you agree/like the alternative, feel free to just copy-paste into this PR (no need to merge or whatever).

Thanks for the patch! It really helps clarify what the pass API should look like. I’ll try to merge it into this PR. I’ve left a few minor comments over there (#157369).

makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 8, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 8, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 8, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 8, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 8, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 8, 2025
makslevental added a commit to makslevental/llvm-project that referenced this pull request Sep 8, 2025
PragmaTwice and others added 3 commits September 8, 2025 15:49
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
@PragmaTwice
Copy link
Contributor Author

PragmaTwice commented Sep 8, 2025

Hi @makslevental, merging is done via e565ffb. Thank you for your review and code : ) Now changes of this PR is mostly the same as #157369 except:

  • in the test python file, we demonstrate two ways of constructing a pass callable: both def some_pass(op) and class SomePass: def __call__(op).
  • pass.py is renamed to python_pass.py to align with the method name pm.add_python_pass(..).
  • callback.clone remains required than optional (can be nullptr) since the cloned userData can be an uninitialized pointer if clone function is not provided.

@joker-eph
Copy link
Collaborator

We should preserve the ability of the system to generate reproducer: that is a pass is more than a callable and should be registered in a way that we can print a pass pipeline and replay it (that implies the ability to register it and a stable mnemnonic). Registering a pass is something you need to do independently from the pass manager, since you can build a pass manager from a pipeline.
This is one of the important aspect of fitting into the pass manager, otherwise if you're just running functions on operations without more structure, you can just write it all in python without hooking into the MLIR pass management infrastructure.

@PragmaTwice
Copy link
Contributor Author

PragmaTwice commented Sep 8, 2025

We should preserve the ability of the system to generate reproducer: that is a pass is more than a callable and should be registered in a way that we can print a pass pipeline and replay it (that implies the ability to register it and a stable mnemnonic). Registering a pass is something you need to do independently from the pass manager, since you can build a pass manager from a pipeline. This is one of the important aspect of fitting into the pass manager, otherwise if you're just running functions on operations without more structure, you can just write it all in python without hooking into the MLIR pass management infrastructure.

Yeah I see your concern — a pass needs to be registered with a stable mnemonic and interoperate with the pass pipeline (for replay, etc.), rather than just being a bare function. That said, in Python a callable can just as well be an object of a class type implementing the __call__ method, which makes it very easy to extend. For example, such a class could also provide other methods/attributes (e.g. initialize(ctx)/can_schedule_on(..)/.. as optional methods, and description/argument/.. as optional attributes) to integrate with the pass infrastructure (as optional callbacks in ExternalPass). What we currently have is just a minimal working entity, but it’s not difficult to extend it further.

Already, we can mix Python passes with C++ passes in the same pipeline (as shown in test/python/python_pass.py), and use enable_ir_printing to dump the IR before and after each pass. We’ve also had some discussions around pass cloning; it’s not implemented yet, but it’s certainly possible to add in Python.

@makslevental
Copy link
Contributor

We should preserve the ability of the system to generate reproducer

Preserve where? PassRegistry is not bound in Python so this ability does not exist and hence cannot be preserved. Do you mean "We should create the ability..."? My point is these two things are independent - we can land this PR and then also work on binding PassRegistry in order to support what you want.

@joker-eph
Copy link
Collaborator

joker-eph commented Sep 8, 2025

Preserve where?

Preserve in the concept of what is a "pass" in the context of MLIR.

PassRegistry is not bound in Python

What do you mean? We can't parse a pass pipeline from python?

My point is these two things are independent - we can land this PR and then also work on binding PassRegistry in order to support what you want.

I don't quite follow your point, can we create passes in python today? If not then we're adding the way to create passes, and my point is that "create a pass" comes with expectations in the context of MLIR.

@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
.def(
"add_python_pass",
Copy link
Contributor

@rolfmorel rolfmorel Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just "add_pass" (similar to C++-API)? Or even just reusing the current "add" and dispatching on arg types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! Done in 46b833d : )

@makslevental
Copy link
Contributor

makslevental commented Sep 8, 2025

PassRegistry is not bound in Python

What do you mean? We can't parse a pass pipeline from python?

You can't create re-usable pipelines in Python at all (except as strings). So you cannot do anything in Python except add pass.

I don't quite follow your point, can we create passes in python today? If not then we're adding the way to create passes, and my point is that "create a pass" comes with expectations in the context of MLIR.

This PR just maps to Python already existing functionality available for C users (you realize that right...?).

@rolfmorel
Copy link
Contributor

rolfmorel commented Sep 8, 2025

My understanding now is that MLIR's external pass mechanism indeed does not register such passes in the pass registry. Such passes are meant to just directly be added to pass managers (which do get a name for these external passes).

As this is a general C-API facility of MLIR, I don't see how/why the Python bindings should be doing anything differently here.

In general happy for this to go in. The dependent dialects aspect could come as a follow-up (as you could just make sure the dialects your pass needs are loaded before running the pass manager).

Comment on lines 102 to 110
"apply_patterns_and_fold_greedily",
[](MlirModule module, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"Applys the given patterns to the given module greedily while folding "
"results.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some kind of explanation of how this works would help a lot. Doesn't have to be in this file, but I'm struggling to understand what this pass does from the description you've provided.

Copy link
Contributor

@makslevental makslevental Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the pass itself or this method? the pass itself is from the original PR for exposing FrozenPatternRewriter but you can take a look at #157487 which just landed and slightly refactored.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent work! Bravo!

Copy link
Contributor

@rolfmorel rolfmorel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine to me. We can iterate on it when we get to supporting dependent dialects.

const std::string &description, const std::string &opName) {
if (!name.has_value()) {
name = nb::cast<std::string>(
nb::borrow<nb::str>(run.attr("__name__")));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: what happens on lambdas?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> x = lambda: "bob"
>>> x.__name__
'<lambda>'

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the best thing ever but it doesn't blow up (and also I don't see anyone using a lambda here...).

@joker-eph
Copy link
Collaborator

You can't create re-usable pipelines in Python at all (except as strings).

Sure, but strings count... A lot of things are going through strings as shortcut in python instead of bindings APIs today (for better or worse).

This PR just maps to Python already existing functionality available for C users (you realize that right...?).

I didn't actually. I see this as a deficiency in the C API though, I don't remember the history there though so I may be missing something, but off-hands I don't see a reason why this should be intrinsically different than the invariants we preserve in the system in general.
I need to look into fixing the C API.

As this is a general C-API facility of MLIR, I don't see how/why the Python bindings should be doing anything differently here.

The reason why is that in general if we find something that is broken, we aim at addressing it instead of piling more broken things on top of it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR][Python] Support defining MLIR passes in Python
7 participants