NOTE(review): this copy of the patch had its line breaks collapsed and its
angle-bracketed text stripped. The layout below is rebuilt from the hunk
headers (line counts re-checked after the doc-comment edits), and
`<typename AttrT>` is restored on the constant_op_binder declarations and,
presumably, on the m_Constant(bind_value) overload. Other stripped template
arguments -- the SmallVector element type in constant_op_binder::match and
the op types passed to m_Op in TestMatchers.cpp (p1 must use the same ones
as p) -- cannot be recovered from this copy; confirm them against the tree
before applying.

Index: mlir/include/mlir/IR/Matchers.h
===================================================================
--- mlir/include/mlir/IR/Matchers.h
+++ mlir/include/mlir/IR/Matchers.h
@@ -48,9 +48,24 @@
   }
 };
 
+/// The matcher that matches an operation that has no side effect, no operands
+/// and produces a single result. This is a structural check only: unlike
+/// constant_op_binder it does not fold the operation, so it does not confirm
+/// that the result is actually a constant attribute.
+struct constant_op {
+  bool match(Operation *op) {
+    if (op->getNumOperands() > 0 || op->getNumResults() != 1)
+      return false;
+    if (!op->hasNoSideEffect())
+      return false;
+    return true;
+  }
+};
+
 /// The matcher that matches a constant foldable operation that has no side
 /// effect, no operands and produces a single result.
-template <typename AttrT> struct constant_op_binder {
+template <typename AttrT>
+struct constant_op_binder : private constant_op {
   AttrT *bind_value;
 
   /// Creates a matcher instance that binds the constant attribute value to
@@ -58,9 +73,7 @@
   constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
 
   bool match(Operation *op) {
-    if (op->getNumOperands() > 0 || op->getNumResults() != 1)
-      return false;
-    if (!op->hasNoSideEffect())
+    if (!constant_op::match(op))
       return false;
 
     SmallVector foldedOp;
@@ -196,6 +209,10 @@
 
 } // end namespace detail
 
+/// Matches a value from a constant operation without binding it. This applies
+/// only the structural checks of detail::constant_op and does not fold the op.
+inline detail::constant_op m_Constant() { return detail::constant_op(); }
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template <typename AttrT>
Index: mlir/lib/IR/Builders.cpp
===================================================================
--- mlir/lib/IR/Builders.cpp
+++ mlir/lib/IR/Builders.cpp
@@ -342,8 +342,7 @@
   };
 
   // If this operation is already a constant, there is nothing to do.
-  Attribute unused;
-  if (matchPattern(op, m_Constant(&unused)))
+  if (matchPattern(op, m_Constant()))
     return cleanupFailure();
 
   // Check to see if any operands to the operation is constant and whether
Index: mlir/test/IR/test-matchers.mlir
===================================================================
--- mlir/test/IR/test-matchers.mlir
+++ mlir/test/IR/test-matchers.mlir
@@ -40,3 +40,4 @@
 
 // CHECK-LABEL: test2
 // CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
+// CHECK-NEXT: Pattern add(add(a, constant), a) matched
Index: mlir/test/lib/IR/TestMatchers.cpp
===================================================================
--- mlir/test/lib/IR/TestMatchers.cpp
+++ mlir/test/lib/IR/TestMatchers.cpp
@@ -126,12 +126,15 @@
   auto a = m_Val(f.getArgument(0));
   FloatAttr floatAttr;
   auto p = m_Op(a, m_Op(a, m_Constant(&floatAttr)));
+  auto p1 = m_Op(a, m_Op(a, m_Constant()));
   // Last operation that is not the terminator.
   Operation *lastOp = f.getBody().front().back().getPrevNode();
   if (p.match(lastOp))
     llvm::outs()
         << "Pattern add(add(a, constant), a) matched and bound constant to: "
         << floatAttr.getValueAsDouble() << "\n";
+  if (p1.match(lastOp))
+    llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
 void TestMatchers::runOnFunction() {