Skip to content

Commit 81e7922

Browse files
cheliniftynse
authored andcommitted
[mlir] m_Constant()
Summary: Introduce m_Constant() which allows matching a constant operation without forcing the user also to capture the attribute value. Differential Revision: https://reviews.llvm.org/D72397
1 parent 202ab27 commit 81e7922

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ template <typename AttrT> struct constant_op_binder {
5656
/// Creates a matcher instance that binds the constant attribute value to
5757
/// bind_value if match succeeds.
5858
constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
59+
/// Creates a matcher instance that doesn't bind if match succeeds.
60+
constant_op_binder() : bind_value(nullptr) {}
5961

6062
bool match(Operation *op) {
6163
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
@@ -66,8 +68,11 @@ template <typename AttrT> struct constant_op_binder {
6668
SmallVector<OpFoldResult, 1> foldedOp;
6769
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
6870
if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
69-
if ((*bind_value = attr.dyn_cast<AttrT>()))
71+
if (auto attrT = attr.dyn_cast<AttrT>()) {
72+
if (bind_value)
73+
*bind_value = attrT;
7074
return true;
75+
}
7176
}
7277
}
7378
return false;
@@ -196,6 +201,11 @@ struct RecursivePatternMatcher {
196201

197202
} // end namespace detail
198203

204+
/// Matches a constant foldable operation.
205+
inline detail::constant_op_binder<Attribute> m_Constant() {
206+
return detail::constant_op_binder<Attribute>();
207+
}
208+
199209
/// Matches a value from a constant foldable operation and writes the value to
200210
/// bind_value.
201211
template <typename AttrT>

mlir/lib/IR/Builders.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
342342
};
343343

344344
// If this operation is already a constant, there is nothing to do.
345-
Attribute unused;
346-
if (matchPattern(op, m_Constant(&unused)))
345+
if (matchPattern(op, m_Constant()))
347346
return cleanupFailure();
348347

349348
// Check to see if any operands to the operation is constant and whether

mlir/test/IR/test-matchers.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ func @test2(%a: f32) -> f32 {
4040

4141
// CHECK-LABEL: test2
4242
// CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
43+
// CHECK: Pattern add(add(a, constant), a) matched

mlir/test/lib/IR/TestMatchers.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,15 @@ void test2(FuncOp f) {
126126
auto a = m_Val(f.getArgument(0));
127127
FloatAttr floatAttr;
128128
auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
129+
auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant()));
129130
// Last operation that is not the terminator.
130131
Operation *lastOp = f.getBody().front().back().getPrevNode();
131132
if (p.match(lastOp))
132133
llvm::outs()
133134
<< "Pattern add(add(a, constant), a) matched and bound constant to: "
134135
<< floatAttr.getValueAsDouble() << "\n";
136+
if (p1.match(lastOp))
137+
llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
135138
}
136139

137140
void TestMatchers::runOnFunction() {

0 commit comments

Comments
 (0)