-
Notifications
You must be signed in to change notification settings - Fork 15k
[MLIR][Python] Support Python-defined passes in MLIR #156000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
mlir/include/mlir-c/Rewrite.h
Outdated
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp( | ||
MlirOperation op, MlirFrozenRewritePatternSet patterns, | ||
MlirGreedyRewriteDriverConfig); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure which name is suitable.. So the current name is just a placeholder. And modifying mlirApplyPatternsAndFoldGreedily
seems a breaking change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a breaking change, but not sure how widely and also C API is really best effort stable (meaning, we try not to break it). I probably should have included a Module suffix on the original, and then this could have gone without. How many usages in the wild can you find of this call?
(Nit: ForOp suffix makes me think it's related to ForOp).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, I perform a simple check on public repos via GitHub search (https://github.com/search?q=mlirApplyPatternsAndFoldGreedily+&type=code) and there seems to be some use cases of this API so I'm not sure if we should change it or not (maybe leave to maintainers for this decision : ).
Currently I'm going to rename it to mlirApplyPatternsAndFoldGreedilyWithOp
to avoid such confusion with ForOp
as you mentioned : )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please extract this change in an independent PR? This is largely unrelated (your test python pass could just do another kind of rewrite right now).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed in ca80408.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please extract this change in an independent PR? This is largely unrelated (your test python pass could just do another kind of rewrite right now).
Ahh got it.
I think I can do it via op/block/region/walk/erase APIs directly, but it maybe a little ugly since the normal rewriter API is not ported to python yet AFAIK?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep - no rewriter API in Python but you can use relevant methods on ops directly, etc. (Here's one example: https://github.com/libxsmm/tpp-mlir/pull/1064/files#diff-7aa62724b21b998da9cf032da6e6b77bbdb664258a5dca850f9509a5459646f6R110-R118 - @makslevental has more.)
That should be enough to demonstrate that your test pass is actually running and, if kept simple, will not be that ugly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can just use apply_patterns_and_fold_greedily
here and then add apply_patterns_and_fold_greedily_with_op
in a follow-up (the existing test passes just fine with apply_patterns_and_fold_greedily
).
✅ With the latest revision this PR passed the Python code formatter. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
mlir/lib/Bindings/Python/Pass.cpp
Outdated
mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()), | ||
mlirStringRefCreate(argument.data(), argument.length()), | ||
mlirStringRefCreate(description.data(), description.length()), | ||
mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dependent dialects are not yet supported.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is relevant for the discussion on whether a function as a pass suffices. I think dependent dialects demonstrate that, at least when it comes to registering them, passes are not just functions. They also have a bit of metadata.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, as you can see we pass 0, nullptr
here, and it can be supported if we change them to nDialects, dialectListPtr
. So I don't think it is a proof of a problem in this design. (it is supported by ExternalPass
)
A function in python is just an object with __call__
method implemented. So I think we can bind any information with this object by assigning them to attrs of this object, if needed. 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@PragmaTwice and I already discussed this offline: to support dependent dialects we need to add nanobindings for MlirDialectHandle
. Easy to do since it's just an opaque handle. If it's a high-priority I can quickly do that as a follow-up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With Python we can monkey patch arbitrary things arbitrarily. Doesn't mean that's a good idea / good to base your design on. As such, binding new attributes to an object is not the way to go.
If you and @makslevental still prefer just passing a function, then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass. Alternatively, have both a Pass
class to inherit with a dependent_dialects
property (and potentially with __call__
implemented as a call to run()
) which the registration API automatically uses and have a mechanism for wrapping up a callback. The wrapping-up mechanism could also be a factory method on Pass
though maybe that doesn't make things simpler when it comes to lifetime management.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass
Yes that's exactly the follow-up I had in mind.
The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.
It's the lifetime management that is the issue - the C++ APIs expect ownership of the Pass
object. But there's simply no way to express "unique ownership" in Python. That's why I rewrote @PragmaTwice's original PR (which isn't very different from what you propose) to only manage the lifetime of a single Python object - the run
callback.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay maybe we can use nanobind
itself to get the right semantics - let me try again using https://nanobind.readthedocs.io/en/latest/ownership.html#unique-pointers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nah this won't blend - you have to create the object in C++ in order for the unique_ptr
magic to work. So that won't blend with "subclass and instantiate a nanobind class in Python".
mlir/lib/Bindings/Python/Pass.cpp
Outdated
|
||
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. | ||
#include "nanobind/trampoline.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The empty line is to workaround clang-format: otherwise it will report that "mlir-c/Bindings/Python/Interop.h" should be put before "mlir/Bindings/Python/Nanobind.h" : )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should do
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
// clang-format on
(the ON WINDOWS
isn't currently there but it should be because I discovered recently that it's in fact only windows that this matters for).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahhh thank you for your suggestion. It is done in c8c2fae.
@llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesSTATUS: This PR is work-in-progress now :) It tries to close #155996. This PR exports a class This is a simple example of a Python-defined pass. from mlir.passmanager import Pass, PassManager
class DemoPass(Pass):
def run(op):
# do something with op
pass
pm = PassManager('any', ctx)
pm.add(DemoPass())
pm.run(..) TODO list:
Full diff: https://github.com/llvm/llvm-project/pull/156000.diff 7 Files Affected:
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..21ae236d6f73f 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig);
+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..590e862a8d358 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passModule);
+ populatePassSubmodule(passModule);
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..3fd27fed34587 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,7 +11,9 @@
#include "IRModule.h"
#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Nanobind.h"
+
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
namespace nb = nanobind;
using namespace nb::literals;
@@ -20,6 +22,81 @@ using namespace mlir::python;
namespace {
+// A base class for defining passes in Python
+// Users are expected to subclass this and implement the `run` method, e.g.
+// ```
+// class MyPass(mlir.passmanager.Pass):
+// def __init__(self):
+// super().__init__("MyPass", ..)
+// # other init stuff..
+// def run(self, operation):
+// # do something with operation..
+// pass
+// ```
+class PyPassBase {
+public:
+ PyPassBase(std::string name, std::string argument, std::string description,
+ std::string opName)
+ : name(std::move(name)), argument(std::move(argument)),
+ description(std::move(description)), opName(std::move(opName)) {
+ callbacks.construct = [](void *obj) {};
+ callbacks.destruct = [](void *obj) {
+ nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+ };
+ callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
+ auto handle = nb::handle(static_cast<PyObject *>(obj));
+ nb::cast<PyPassBase *>(handle)->run(op);
+ };
+ callbacks.clone = [](void *obj) -> void * {
+ nb::object copy = nb::module_::import_("copy");
+ nb::object deepcopy = copy.attr("deepcopy");
+ return deepcopy(obj).release().ptr();
+ };
+ callbacks.initialize = nullptr;
+ }
+
+ // this method should be overridden by subclasses in Python.
+ virtual void run(MlirOperation op) = 0;
+
+ virtual ~PyPassBase() = default;
+
+ // Make an MlirPass instance on-the-fly that wraps this object.
+ // Note that passmanager will take the ownership of the returned
+ // object and release it when appropriate.
+ MlirPass make() {
+ auto *obj = nb::find(this).release().ptr();
+ return mlirCreateExternalPass(
+ mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
+ mlirStringRefCreate(argument.data(), argument.length()),
+ mlirStringRefCreate(description.data(), description.length()),
+ mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
+ callbacks, obj);
+ }
+
+ const std::string &getName() const { return name; }
+ const std::string &getArgument() const { return argument; }
+ const std::string &getDescription() const { return description; }
+ const std::string &getOpName() const { return opName; }
+
+private:
+ MlirExternalPassCallbacks callbacks;
+
+ std::string name;
+ std::string argument;
+ std::string description;
+ std::string opName;
+};
+
+// A trampoline class upon PyPassBase.
+// Refer to
+// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
+class PyPass : PyPassBase {
+public:
+ NB_TRAMPOLINE(PyPassBase, 1);
+
+ void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
+};
+
/// Owning Wrapper around a PassManager.
class PyPassManager {
public:
@@ -52,6 +129,26 @@ class PyPassManager {
} // namespace
+void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
+ //----------------------------------------------------------------------------
+ // Mapping of the Python-defined Pass interface
+ //----------------------------------------------------------------------------
+ nb::class_<PyPassBase, PyPass>(m, "Pass")
+ .def(nb::init<std::string, std::string, std::string, std::string>(),
+ "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
+ "op_name"_a = "", "Create a new Pass.")
+ .def("run", &PyPassBase::run, "operation"_a,
+ "Run the pass on the provided operation.")
+ .def_prop_ro("name",
+ [](const PyPassBase &self) { return self.getName(); })
+ .def_prop_ro("argument",
+ [](const PyPassBase &self) { return self.getArgument(); })
+ .def_prop_ro("description",
+ [](const PyPassBase &self) { return self.getDescription(); })
+ .def_prop_ro("op_name",
+ [](const PyPassBase &self) { return self.getOpName(); });
+}
+
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
@@ -157,6 +254,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
+ .def(
+ "add",
+ [](PyPassManager &passManager, PyPassBase &pass) {
+ mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
+ },
+ "pass"_a, "Add a python-defined pass to the pass manager.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op,
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..ba3fbb707fed7 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -15,6 +15,7 @@ namespace mlir {
namespace python {
void populatePassManagerSubmodule(nanobind::module_ &m);
+void populatePassSubmodule(nanobind::module_ &m);
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..675bd685ec2db 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
- "apply_patterns_and_fold_greedily",
- [](MlirModule module, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
- if (mlirLogicalResultIsFailure(status))
- // FIXME: Not sure this is the right error to throw here.
- throw nb::value_error("pattern application failed to converge");
- },
- "module"_a, "set"_a,
- "Applys the given patterns to the given module greedily while folding "
- "results.");
+ "apply_patterns_and_fold_greedily",
+ [](MlirModule module, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ // FIXME: Not sure this is the right error to throw here.
+ throw nb::value_error("pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ "Applys the given patterns to the given module greedily while folding "
+ "results.")
+ .def(
+ "apply_patterns_and_fold_greedily_for_op",
+ [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ // FIXME: Not sure this is the right error to throw here.
+ throw nb::value_error("pattern application failed to converge");
+ },
+ "op"_a, "set"_a,
+ "Applys the given patterns to the given op greedily while folding "
+ "results.");
}
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..d606445cfad31 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyForOp(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig) {
+ return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
new file mode 100644
index 0000000000000..1c338e4bc6b49
--- /dev/null
+++ b/mlir/test/python/pass.py
@@ -0,0 +1,79 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def log(*args):
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
+
+
+def run(f):
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+
+def make_pdl_module():
+ with Location.unknown():
+ pdl_module = Module.create()
+ with InsertionPoint(pdl_module.body):
+ # Change all arith.addi with index types to arith.muli.
+ @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+ def pat():
+ # Match arith.addi with index types.
+ index_type = pdl.TypeOp(IndexType.get())
+ operand0 = pdl.OperandOp(index_type)
+ operand1 = pdl.OperandOp(index_type)
+ op0 = pdl.OperationOp(
+ name="arith.addi", args=[operand0, operand1], types=[index_type]
+ )
+
+ # Replace the matched op with arith.muli.
+ @pdl.rewrite()
+ def rew():
+ newOp = pdl.OperationOp(
+ name="arith.muli", args=[operand0, operand1], types=[index_type]
+ )
+ pdl.ReplaceOp(op0, with_op=newOp)
+
+ return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+ with Context():
+ pdl_module = make_pdl_module()
+
+ class CustomPass(Pass):
+ def __init__(self):
+ super().__init__("CustomPass", op_name="builtin.module")
+
+ def run(self, m):
+ frozen = PDLModule(pdl_module).freeze()
+ apply_patterns_and_fold_greedily_for_op(m, frozen)
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: index, %b: index) -> index {
+ %sum = arith.addi %a, %b : index
+ return %sum : index
+ }
+ }
+ """
+ )
+
+ # CHECK-LABEL: Dump After CustomPass
+ # CHECK: arith.muli
+ pm = PassManager("any")
+ pm.enable_ir_printing()
+ pm.add(CustomPass())
+ pm.run(module)
|
mlir/lib/Bindings/Python/Pass.cpp
Outdated
nb::object deepcopy = copy.attr("deepcopy"); | ||
return deepcopy(obj).release().ptr(); | ||
}; | ||
callbacks.initialize = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
initialize
is not ported to python side yet.
auto passModule = | ||
m.def_submodule("passmanager", "MLIR Pass Management Bindings"); | ||
populatePassManagerSubmodule(passModule); | ||
populatePassSubmodule(passModule); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that pass
is a keyword in python so that name like mlir.pass.Pass
doesn't work. 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes, pass vs pass :-) I don't have a good suggestion given how established that is MLIR/compilers and Python side ... MlirPass or PyPass or pydefpass would be most obvious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently the full name of the base class is mlir.passmanager.Pass
. Is it good enough or we'd better to rename it with another module name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that PEP8 recommends appending a underscore in case of names of arguments / attributes clashing with reserved keywords, I feel mlir.pass_.Pass
is an option as it is not too surprising for me personally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think putting a trailing underscore in a namespace is a little awkward. How about mlir.passes
? or mlir.passinfra
? or I dunno something like that that I'm failing to come up with right now...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. For now, I will rename it to mlir.passes.Pass
. Let me know if anyone think of a better name : )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 01e68c5.
mlir/lib/Bindings/Python/Pass.cpp
Outdated
// Note that passmanager will take the ownership of the returned | ||
// object and release it when appropriate. | ||
MlirPass make() { | ||
auto *obj = nb::find(this).release().ptr(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
About lifetime of the python object: here we increase the ref count by nb::find
(and by release()
we avoid decrease the count here), and when the ExternalPass
is destructed, callback.destructor
is called so that dec_ref()
is called for this object.
…dFoldGreedilyWithOp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to take another pass (not pun intended) but here are some initial non-nit comments.
auto passModule = | ||
m.def_submodule("passmanager", "MLIR Pass Management Bindings"); | ||
populatePassManagerSubmodule(passModule); | ||
populatePassSubmodule(passModule); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think putting a trailing underscore in a namespace is a little awkward. How about mlir.passes
? or mlir.passinfra
? or I dunno something like that that I'm failing to come up with right now...
mlir/lib/Bindings/Python/Pass.cpp
Outdated
|
||
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. | ||
#include "nanobind/trampoline.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should do
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
// clang-format on
(the ON WINDOWS
isn't currently there but it should be because I discovered recently that it's in fact only windows that this matters for).
mlir/lib/Bindings/Python/Pass.cpp
Outdated
//---------------------------------------------------------------------------- | ||
// Mapping of the Python-defined Pass interface | ||
//---------------------------------------------------------------------------- | ||
nb::class_<PyPassBase, PyPass>(m, "Pass") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't this backwards? cf <PyOperation, PyOperationBase>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh this is a quite interesting case that I just did some research on : )
Refer to the nanobind documentation (https://nanobind.readthedocs.io/en/latest/api_core.html#_CPPv4I0DpEN8nanobind6class_E):
template<typename T, typename ...Ts>
class class_ : public objectThe variable length parameter Ts is optional and can be used to specify the base class of T and/or an alias needed to realize trampoline classes.
So in the case you mentioned (<PyOperation, PyOperationBase>
), the parameter Ts
(PyOperationBase
) is a base class of T
(PyOperation
); and for the case here (<PyPassBase, PyPass>
), the Ts
(PyPass
) is a trampoline class of T
(PyPassBase
). So I think both of them is correct here.
For example, in the documentation of trampoline classes (https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python), we can see such an instance of nb::class_
:
nb::class_<Dog, PyDog /* <-- trampoline */>(m, "Dog")
And here PyDog
is a derived class of Dog
but also a trampoline class of Dog
: )
Sorry for taking so long to review! Firstly, I want to repeat/reiterate/state that I'm fully in support of this functionality (and I believe many others are as well). Coincidentally, at $DAYJOB we want to provide/support something very similar to our users very soon. So thank you very much for the implementation! Secondly, I'd like to propose a slightly different alternative: #157369. The core of the difference being that instead of providing an API based on classes ( @PragmaTwice take a look - let me know what you think. If you agree/like the alternative, feel free to just copy-paste into this PR (no need to merge or whatever). |
based heavily on llvm#156000
based heavily on llvm#156000
based heavily on llvm#156000
Thanks for the patch! It really helps clarify what the pass API should look like. I’ll try to merge it into this PR. I’ve left a few minor comments over there (#157369). |
based heavily on llvm#156000
based heavily on llvm#156000
based heavily on llvm#156000
based heavily on llvm#156000
based heavily on llvm#156000
based heavily on llvm#156000
based heavily on llvm#156000
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Hi @makslevental, merging is done via e565ffb. Thank you for your review and code : ) Now changes of this PR is mostly the same as #157369 except:
|
We should preserve the ability of the system to generate reproducer: that is a pass is more than a callable and should be registered in a way that we can print a pass pipeline and replay it (that implies the ability to register it and a stable mnemnonic). Registering a pass is something you need to do independently from the pass manager, since you can build a pass manager from a pipeline. |
Yeah I see your concern — a pass needs to be registered with a stable mnemonic and interoperate with the pass pipeline (for replay, etc.), rather than just being a bare function. That said, in Python a callable can just as well be an object of a class type implementing the Already, we can mix Python passes with C++ passes in the same pipeline (as shown in |
Preserve where? |
Preserve in the concept of what is a "pass" in the context of MLIR.
What do you mean? We can't parse a pass pipeline from python?
I don't quite follow your point, can we create passes in python today? If not then we're adding the way to create passes, and my point is that "create a pass" comes with expectations in the context of MLIR. |
mlir/lib/Bindings/Python/Pass.cpp
Outdated
@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | |||
"pipeline"_a, | |||
"Add textual pipeline elements to the pass manager. Throws a " | |||
"ValueError if the pipeline can't be parsed.") | |||
.def( | |||
"add_python_pass", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about just "add_pass"
(similar to C++-API)? Or even just reusing the current "add"
and dispatching on arg types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me! Done in 46b833d : )
You can't create re-usable pipelines in Python at all (except as strings). So you cannot do anything in Python except
This PR just maps to Python already existing functionality available for C users (you realize that right...?). |
My understanding now is that MLIR's external pass mechanism indeed does not register such passes in the pass registry. Such passes are meant to just directly be added to pass managers (which do get a name for these external passes). As this is a general C-API facility of MLIR, I don't see how/why the Python bindings should be doing anything differently here. In general happy for this to go in. The dependent dialects aspect could come as a follow-up (as you could just make sure the dialects your pass needs are loaded before |
mlir/lib/Bindings/Python/Rewrite.cpp
Outdated
"apply_patterns_and_fold_greedily", | ||
[](MlirModule module, MlirFrozenRewritePatternSet set) { | ||
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); | ||
if (mlirLogicalResultIsFailure(status)) | ||
throw std::runtime_error("pattern application failed to converge"); | ||
}, | ||
"module"_a, "set"_a, | ||
"Applys the given patterns to the given module greedily while folding " | ||
"results.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some kind of explanation of how this works would help a lot. Doesn't have to be in this file, but I'm struggling to understand what this pass does from the description you've provided.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the pass itself or this method? the pass itself is from the original PR for exposing FrozenPatternRewriter but you can take a look at #157487 which just landed and slightly refactored.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work! Bravo!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine to me. We can iterate on it when we get to supporting dependent dialects.
const std::string &description, const std::string &opName) { | ||
if (!name.has_value()) { | ||
name = nb::cast<std::string>( | ||
nb::borrow<nb::str>(run.attr("__name__"))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: what happens on lambdas?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
>>> x = lambda: "bob"
>>> x.__name__
'<lambda>'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not the best thing ever but it doesn't blow up (and also I don't see anyone using a lambda here...).
Sure, but strings count... A lot of things are going through strings as shortcut in python instead of bindings APIs today (for better or worse).
I didn't actually. I see this as a deficiency in the C API though, I don't remember the history there though so I may be missing something, but off-hands I don't see a reason why this should be intrinsically different than the invariants we preserve in the system in general.
The reason why is that in general if we find something that is broken, we aim at addressing it instead of piling more broken things on top of it. |
STATUS: This PR is ready-for-review but quite experimental :)
It tries to close #155996.
This PR exports a class
mlir.passmanager.Pass
for Python-side to use for defining new MLIR passes.This is a simple example of a Python-defined pass.
TODO list: