Add a non-binding m_Constant() matcher and use it in OpBuilder::tryFold.

constant_op_binder gains a default constructor that leaves bind_value null,
and match() now only writes through bind_value when it is non-null. The
tryFold "already a constant" check no longer needs a throwaway Attribute.

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -56,6 +56,8 @@
   /// Creates a matcher instance that binds the constant attribute value to
   /// bind_value if match succeeds.
   constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
+  /// Creates a matcher instance that doesn't bind if match succeeds.
+  constant_op_binder() : bind_value(nullptr) {}
 
   bool match(Operation *op) {
     if (op->getNumOperands() > 0 || op->getNumResults() != 1)
@@ -66,8 +68,11 @@
     SmallVector<OpFoldResult, 1> foldedOp;
     if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
       if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
-        if ((*bind_value = attr.dyn_cast<AttrT>()))
+        if (auto attrT = attr.dyn_cast<AttrT>()) {
+          if (bind_value)
+            *bind_value = attrT;
           return true;
+        }
       }
     }
     return false;
@@ -196,6 +201,11 @@
 
 } // end namespace detail
 
+/// Matches a constant foldable operation.
+inline detail::constant_op_binder<Attribute> m_Constant() {
+  return detail::constant_op_binder<Attribute>();
+}
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template <typename AttrT>
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -342,8 +342,7 @@
   };
 
   // If this operation is already a constant, there is nothing to do.
-  Attribute unused;
-  if (matchPattern(op, m_Constant(&unused)))
+  if (matchPattern(op, m_Constant()))
     return cleanupFailure();
 
   // Check to see if any operands to the operation is constant and whether
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -40,3 +40,4 @@
 
 // CHECK-LABEL: test2
 // CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
+// CHECK: Pattern add(add(a, constant), a) matched
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -126,12 +126,15 @@
   auto a = m_Val(f.getArgument(0));
   FloatAttr floatAttr;
   auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
+  auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant()));
   // Last operation that is not the terminator.
   Operation *lastOp = f.getBody().front().back().getPrevNode();
   if (p.match(lastOp))
     llvm::outs()
         << "Pattern add(add(a, constant), a) matched and bound constant to: "
         << floatAttr.getValueAsDouble() << "\n";
+  if (p1.match(lastOp))
+    llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
 void TestMatchers::runOnFunction() {