diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -52,6 +52,22 @@
   bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
 };
 
+/// The matcher that matches operations whose name is `name`. Only a StringRef
+/// is stored, so the referenced string must outlive the matcher.
+struct name_op_matcher {
+  StringRef name;
+  explicit name_op_matcher(StringRef name) : name(name) {}
+  bool match(Operation *op) { return op->getName().getStringRef() == name; }
+};
+
+/// The matcher that matches operations that have an attribute named
+/// `attr_name`, of any kind. The StringRef must outlive the matcher.
+struct attr_op_matcher {
+  StringRef attr_name;
+  explicit attr_op_matcher(StringRef attr_name) : attr_name(attr_name) {}
+  bool match(Operation *op) { return op->hasAttr(attr_name); }
+};
+
 /// The matcher that matches operations that have the `ConstantLike` trait, and
 /// binds the folded attribute value.
 template <typename AttrT>
@@ -83,6 +99,34 @@
   }
 };
 
+/// The matcher that matches operations that have the specified attribute
+/// name with a value of kind `AttrT`, and binds the attribute value.
+template <typename AttrT>
+struct attr_op_binder {
+  StringRef attr_name;
+  AttrT *bind_value;
+
+  /// Creates a matcher instance that binds the attribute value to
+  /// bind_value if match succeeds.
+  attr_op_binder(StringRef attr_name, AttrT *bind_value)
+      : attr_name(attr_name), bind_value(bind_value) {}
+  /// Creates a matcher instance that doesn't bind if match succeeds.
+  explicit attr_op_binder(StringRef attr_name)
+      : attr_name(attr_name), bind_value(nullptr) {}
+
+  bool match(Operation *op) {
+    // A single lookup: getAttrOfType yields a null AttrT both when the
+    // attribute is absent and when it is of another kind; either way the
+    // op does not match.
+    if (auto attr = op->getAttrOfType<AttrT>(attr_name)) {
+      if (bind_value)
+        *bind_value = attr;
+      return true;
+    }
+    return false;
+  }
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// float operation and binds the constant float value.
 struct constant_float_op_binder {
@@ -249,6 +293,16 @@
   return detail::constant_op_matcher();
 }
 
+/// Matches an operation that has an attribute named `attr_name`.
+inline detail::attr_op_matcher m_Attr(StringRef attr_name) {
+  return detail::attr_op_matcher(attr_name);
+}
+
+/// Matches an operation whose name is `op_name` (e.g. "arith.addf").
+inline detail::name_op_matcher m_Op(StringRef op_name) {
+  return detail::name_op_matcher(op_name);
+}
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template <typename AttrT>
@@ -256,6 +310,14 @@
   return detail::constant_op_binder<AttrT>(bind_value);
 }
 
+/// Matches an operation that has an attribute named `attr_name` of kind
+/// `AttrT`, and writes that attribute to `*bind_value`.
+template <typename AttrT>
+inline detail::attr_op_binder<AttrT> m_Attr(StringRef attr_name,
+                                             AttrT *bind_value) {
+  return detail::attr_op_binder<AttrT>(attr_name, bind_value);
+}
+
 /// Matches a constant scalar / vector splat / tensor splat float (both positive
 /// and negative) zero.
 inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -41,3 +41,14 @@
 // CHECK-LABEL: test2
 // CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
 // CHECK: Pattern add(add(a, constant), a) matched
+
+func.func @test3(%a: f32) -> f32 {
+  %0 = "test.name"() {value = 1.0 : f32} : () -> f32
+  %1 = arith.addf %a, %0: f32
+  %2 = arith.mulf %a, %1 fastmath<fast>: f32
+  return %2: f32
+}
+
+// CHECK-LABEL: test3
+// CHECK: Pattern mul(*, add(*, m_Op("test.name"))) matched
+// CHECK: Pattern m_Attr("fastmath") matched and bound constant to: fast
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -148,6 +148,21 @@
     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
+void test3(FunctionOpInterface f) {
+  arith::FastMathFlagsAttr fastMathAttr;
+  auto p = m_Op<arith::MulFOp>(m_Any(),
+                               m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
+  auto p1 = m_Attr("fastmath", &fastMathAttr);
+  // Last operation that is not the terminator.
+  Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
+  if (p.match(lastOp))
+    llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
+  if (p1.match(lastOp))
+    llvm::outs()
+        << "Pattern m_Attr(\"fastmath\") matched and bound constant to: "
+        << fastMathAttr.getValue() << "\n";
+}
+
 void TestMatchers::runOnOperation() {
   auto f = getOperation();
   llvm::outs() << f.getName() << "\n";
@@ -155,6 +170,8 @@
     test1(f);
   if (f.getName() == "test2")
     test2(f);
+  if (f.getName() == "test3")
+    test3(f);
 }
 
 namespace mlir {