-
Notifications
You must be signed in to change notification settings - Fork 15k
[MLIR][Python] Add a python function to apply patterns with MlirOperation #157487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesIn #94714, we add a python function So here we provide a new python function Full diff: https://github.com/llvm/llvm-project/pull/157487.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig);
+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..feb22485c5609 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,24 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
- "apply_patterns_and_fold_greedily",
- [](MlirModule module, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
- if (mlirLogicalResultIsFailure(status))
- // FIXME: Not sure this is the right error to throw here.
- throw nb::value_error("pattern application failed to converge");
- },
- "module"_a, "set"_a,
- "Applys the given patterns to the given module greedily while folding "
- "results.");
+ "apply_patterns_and_fold_greedily",
+ [](MlirModule module, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error("pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ "Applys the given patterns to the given module greedily while folding "
+ "results.")
+ .def(
+ "apply_patterns_and_fold_greedily_with_op",
+ [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error(
+ "pattern application failed to converge");
+ },
+ "op"_a, "set"_a,
+ "Applys the given patterns to the given op greedily while folding "
+ "results.");
}
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig) {
+ return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 923af29a71ad7..42d3707017e17 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,20 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f
-
-# CHECK-LABEL: TEST: test_add_to_mul
-# CHECK: arith.muli
-@construct_and_print_in_module
-def test_add_to_mul(module_):
- index_type = IndexType.get()
-
- # Create a test case.
- @module(sym_name="ir")
- def ir():
- @func.func(index_type, index_type)
- def add_func(a, b):
- return arith.addi(a, b)
-
+def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
# - operands are index types.
@@ -61,7 +48,39 @@ def rew():
# not yet captured Python side/has sharp edges. So best to construct the
# module and PDL module in same scope.
# FIXME: This should be made more robust.
- frozen = PDLModule(m).freeze()
+ return PDLModule(m).freeze()
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul(module_):
+ index_type = IndexType.get()
+
+ # Create a test case.
+ @module(sym_name="ir")
+ def ir():
+ @func.func(index_type, index_type)
+ def add_func(a, b):
+ return arith.addi(a, b)
+
+ frozen = get_pdl_patterns()
# Could apply frozen pattern set multiple times.
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+# CHECK-LABEL: TEST: test_add_to_mul_with_op
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul_with_op(module_):
+ index_type = IndexType.get()
+
+ # Create a test case.
+ @module(sym_name="ir")
+ def ir():
+ @func.func(index_type, index_type)
+ def add_func(a, b):
+ return arith.addi(a, b)
+
+ frozen = get_pdl_patterns()
+ apply_patterns_and_fold_greedily_with_op(module_.operation, frozen)
+ return module_
|
✅ With the latest revision this PR passed the Python code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Let me know when you're ready to merge
Ready to merge now : ) |
In #94714, we add a python function
apply_patterns_and_fold_greedily
which accepts anMlirModule
as the argument type. However, sometimes we want to apply patterns with anMlirOperation
argument, and there is currently no python API to convert anMlirOperation
toMlirModule
.So here we overload this function
apply_patterns_and_fold_greedily
to do this (also a corresponding new C APImlirApplyPatternsAndFoldGreedilyWithOp
)