diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 61d3446317550..374d2fb78de88 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op); MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); +MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp( + MlirOperation op, MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig); + MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 278847e7ac7f5..94604a567858a 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -136,7 +136,10 @@ NB_MODULE(_mlir, m) { populateRewriteSubmodule(rewriteModule); // Define and populate PassManager submodule. - auto passModule = + auto passManagerModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); - populatePassManagerSubmodule(passModule); + populatePassManagerSubmodule(passManagerModule); + auto passesModule = + m.def_submodule("passes", "MLIR Pass Infrastructure Bindings"); + populatePassSubmodule(passesModule); } diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 1030dea7f364c..73d6c556e5181 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -10,8 +10,11 @@ #include "IRModule.h" #include "mlir-c/Pass.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on +#include "nanobind/trampoline.h" namespace nb = nanobind; using namespace nb::literals; @@ -20,6 +23,81 @@ using namespace mlir::python; namespace { +// A base class for defining passes in Python +// Users are expected to subclass this and implement the `run` method, e.g. +// ``` +// class MyPass(Pass): +// def __init__(self): +// super().__init__("MyPass", ..) +// # other init stuff.. +// def run(self, operation): +// # do something with operation.. +// pass +// ``` +class PyPassBase { +public: + PyPassBase(std::string name, std::string argument, std::string description, + std::string opName) + : name(std::move(name)), argument(std::move(argument)), + description(std::move(description)), opName(std::move(opName)) { + callbacks.construct = [](void *obj) {}; + callbacks.destruct = [](void *obj) { + nb::handle(static_cast(obj)).dec_ref(); + }; + callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) { + auto handle = nb::handle(static_cast(obj)); + nb::cast(handle)->run(op); + }; + callbacks.clone = [](void *obj) -> void * { + nb::object copy = nb::module_::import_("copy"); + nb::object deepcopy = copy.attr("deepcopy"); + return deepcopy(obj).release().ptr(); + }; + callbacks.initialize = nullptr; + } + + // this method should be overridden by subclasses in Python. + virtual void run(MlirOperation op) = 0; + + virtual ~PyPassBase() = default; + + // Make an MlirPass instance on-the-fly that wraps this object. + // Note that passmanager will take the ownership of the returned + // object and release it when appropriate. + MlirPass make() { + auto *obj = nb::find(this).release().ptr(); + return mlirCreateExternalPass( + mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()), + mlirStringRefCreate(argument.data(), argument.length()), + mlirStringRefCreate(description.data(), description.length()), + mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr, + callbacks, obj); + } + + const std::string &getName() const { return name; } + const std::string &getArgument() const { return argument; } + const std::string &getDescription() const { return description; } + const std::string &getOpName() const { return opName; } + +private: + MlirExternalPassCallbacks callbacks; + + std::string name; + std::string argument; + std::string description; + std::string opName; +}; + +// A trampoline class upon PyPassBase. +// Refer to +// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python +class PyPass : PyPassBase { +public: + NB_TRAMPOLINE(PyPassBase, 1); + + void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); } +}; + /// Owning Wrapper around a PassManager. class PyPassManager { public: @@ -52,6 +130,26 @@ class PyPassManager { } // namespace +void mlir::python::populatePassSubmodule(nanobind::module_ &m) { + //---------------------------------------------------------------------------- + // Mapping of the Python-defined Pass interface + //---------------------------------------------------------------------------- + nb::class_(m, "Pass") + .def(nb::init(), + "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "", + "op_name"_a = "", "Create a new Pass.") + .def("run", &PyPassBase::run, "operation"_a, + "Run the pass on the provided operation.") + .def_prop_ro("name", + [](const PyPassBase &self) { return self.getName(); }) + .def_prop_ro("argument", + [](const PyPassBase &self) { return self.getArgument(); }) + .def_prop_ro("description", + [](const PyPassBase &self) { return self.getDescription(); }) + .def_prop_ro("op_name", + [](const PyPassBase &self) { return self.getOpName(); }); +} + /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- @@ -157,6 +255,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "pipeline"_a, "Add textual pipeline elements to the pass manager. Throws a " "ValueError if the pipeline can't be parsed.") + .def( + "add", + [](PyPassManager &passManager, PyPassBase &pass) { + mlirPassManagerAddOwnedPass(passManager.get(), pass.make()); + }, + "pass"_a, "Add a python-defined pass to the pass manager.") .def( "run", [](PyPassManager &passManager, PyOperationBase &op, diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index bc40943521829..ba3fbb707fed7 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -15,6 +15,7 @@ namespace mlir { namespace python { void populatePassManagerSubmodule(nanobind::module_ &m); +void populatePassSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 0373f9c7affe9..e764535c1a4c0 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( - "apply_patterns_and_fold_greedily", - [](MlirModule module, MlirFrozenRewritePatternSet set) { - auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); - if (mlirLogicalResultIsFailure(status)) - // FIXME: Not sure this is the right error to throw here. - throw nb::value_error("pattern application failed to converge"); - }, - "module"_a, "set"_a, - "Applys the given patterns to the given module greedily while folding " - "results."); + "apply_patterns_and_fold_greedily", + [](MlirModule module, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); + if (mlirLogicalResultIsFailure(status)) + // FIXME: Not sure this is the right error to throw here. + throw nb::value_error("pattern application failed to converge"); + }, + "module"_a, "set"_a, + "Applys the given patterns to the given module greedily while folding " + "results.") + .def( + "apply_patterns_and_fold_greedily_with_op", + [](MlirOperation op, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {}); + if (mlirLogicalResultIsFailure(status)) + // FIXME: Not sure this is the right error to throw here. + throw nb::value_error("pattern application failed to converge"); + }, + "op"_a, "set"_a, + "Applys the given patterns to the given op greedily while folding " + "results."); } diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index a4df97f7beace..6f85357a14a18 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op, return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } +MlirLogicalResult +mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, + MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig) { + return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); +} + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 7a0c95ebb8200..fde53a4d64d1c 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -20,6 +20,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python SOURCES _mlir_libs/__init__.py ir.py + passes.py passmanager.py rewrite.py dialects/_ods_common.py diff --git a/mlir/python/mlir/passes.py b/mlir/python/mlir/passes.py new file mode 100644 index 0000000000000..aab9d6b252bbc --- /dev/null +++ b/mlir/python/mlir/passes.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._mlir_libs._mlir.passes import * diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py new file mode 100644 index 0000000000000..0f6d818e850dc --- /dev/null +++ b/mlir/test/python/pass.py @@ -0,0 +1,84 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import gc, sys +from mlir.ir import * +from mlir.passmanager import * +from mlir.passes import * +from mlir.dialects.builtin import ModuleOp +from mlir.dialects import pdl +from mlir.rewrite import * + + +def log(*args): + print(*args, file=sys.stderr) + sys.stderr.flush() + + +def run(f): + log("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +def make_pdl_module(): + with Location.unknown(): + pdl_module = Module.create() + with InsertionPoint(pdl_module.body): + # Change all arith.addi with index types to arith.muli. + @pdl.pattern(benefit=1, sym_name="addi_to_mul") + def pat(): + # Match arith.addi with index types. + i64_type = pdl.TypeOp(IntegerType.get_signless(64)) + operand0 = pdl.OperandOp(i64_type) + operand1 = pdl.OperandOp(i64_type) + op0 = pdl.OperationOp( + name="arith.addi", args=[operand0, operand1], types=[i64_type] + ) + + # Replace the matched op with arith.muli. + @pdl.rewrite() + def rew(): + newOp = pdl.OperationOp( + name="arith.muli", args=[operand0, operand1], types=[i64_type] + ) + pdl.ReplaceOp(op0, with_op=newOp) + + return pdl_module + + +# CHECK-LABEL: TEST: testCustomPass +@run +def testCustomPass(): + with Context(): + pdl_module = make_pdl_module() + frozen = PDLModule(pdl_module).freeze() + + class CustomPass(Pass): + def __init__(self): + super().__init__("CustomPass", op_name="builtin.module") + + def run(self, m): + apply_patterns_and_fold_greedily_with_op(m, frozen) + + module = ModuleOp.parse( + r""" + module { + func.func @add(%a: i64, %b: i64) -> i64 { + %sum = arith.addi %a, %b : i64 + return %sum : i64 + } + } + """ + ) + + pm = PassManager("any") + pm.enable_ir_printing() + + # CHECK-LABEL: Dump After CustomPass + # CHECK: arith.muli + pm.add(CustomPass()) + # CHECK-LABEL: Dump After ArithToLLVMConversionPass + # CHECK: llvm.mul + pm.add("convert-arith-to-llvm") + pm.run(module)