diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -99,6 +99,34 @@
   }
 };
 
+/// The matcher that matches operations that have the specified attribute
+/// name, and binds the attribute value.
+template <typename AttrT>
+struct attr_op_binder {
+  StringRef attrName;
+  AttrT *bindValue;
+
+  /// Creates a matcher instance that binds the attribute value to
+  /// bindValue if match succeeds.
+  attr_op_binder(StringRef attrName, AttrT *bindValue)
+      : attrName(attrName), bindValue(bindValue) {}
+  /// Creates a matcher instance that doesn't bind if match succeeds.
+  explicit attr_op_binder(StringRef attrName)
+      : attrName(attrName), bindValue(nullptr) {}
+
+  /// Succeeds iff `op` has an attribute `attrName` of type `AttrT`.
+  bool match(Operation *op) {
+    // getAttrOfType returns null if the attribute is missing or has another
+    // type, so a single lookup covers both failure cases.
+    auto attr = op->getAttrOfType<AttrT>(attrName);
+    if (!attr)
+      return false;
+    if (bindValue)
+      *bindValue = attr;
+    return true;
+  }
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// float operation and binds the constant float value.
 struct constant_float_op_binder {
@@ -266,13 +294,13 @@
 }
 
 /// Matches a named attribute operation.
-inline detail::attr_op_matcher m_AttrName(StringRef attrN) {
-  return detail::attr_op_matcher(attrN);
+inline detail::attr_op_matcher m_Attr(StringRef attrName) {
+  return detail::attr_op_matcher(attrName);
 }
 
 /// Matches a named operation.
-inline detail::name_op_matcher m_Name(StringRef opN) {
-  return detail::name_op_matcher(opN);
+inline detail::name_op_matcher m_Op(StringRef opName) {
+  return detail::name_op_matcher(opName);
 }
 
 /// Matches a value from a constant foldable operation and writes the value to
@@ -282,6 +310,14 @@
   return detail::constant_op_binder<AttrT>(bind_value);
 }
 
+/// Matches an operation that has an attribute named `attrName` of type `AttrT`
+/// and, on success, writes the attribute to `*bindValue` (if non-null).
+template <typename AttrT>
+inline detail::attr_op_binder<AttrT> m_Attr(StringRef attrName,
+                                            AttrT *bindValue) {
+  return detail::attr_op_binder<AttrT>(attrName, bindValue);
+}
+
 /// Matches a constant scalar / vector splat / tensor splat float (both positive
 /// and negative) zero.
 inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -50,5 +50,5 @@
 }
 
 // CHECK-LABEL: test3
-// CHECK: Pattern mul(*, add(*, m_Name("test.name"))) matched
-// CHECK: Pattern m_AttrName("fastmath") matched
+// CHECK: Pattern mul(*, add(*, m_Op("test.name"))) matched
+// CHECK: Pattern m_Attr("fastmath") matched and bound value to: fast
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -149,15 +149,17 @@
 }
 
 void test3(FunctionOpInterface f) {
-  auto p = m_Op<arith::MulFOp>(
-      m_Any(), m_Op<arith::AddFOp>(m_Any(), m_Name("test.name")));
-  auto p1 = m_AttrName("fastmath");
+  arith::FastMathFlagsAttr fastMathAttr;
+  auto p = m_Op<arith::MulFOp>(m_Any(),
+                               m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
+  auto p1 = m_Attr("fastmath", &fastMathAttr);
   // Last operation that is not the terminator.
   Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
   if (p.match(lastOp))
-    llvm::outs() << "Pattern mul(*, add(*, m_Name(\"test.name\"))) matched\n";
+    llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
   if (p1.match(lastOp))
-    llvm::outs() << "Pattern m_AttrName(\"fastmath\") matched\n";
+    llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
+                 << fastMathAttr.getValue() << "\n";
 }
 
 void TestMatchers::runOnOperation() {