diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -52,6 +52,31 @@
   bool match(Operation *op) { return op->hasTrait(); }
 };
 
+/// The matcher that matches operations that have the specified op name.
+///
+/// The name is copied into the matcher, so a matcher built from a temporary
+/// string remains valid for as long as the matcher itself is alive.
+struct name_op_matcher {
+  std::string opName;
+  name_op_matcher(StringRef opN) : opName(opN.str()) {}
+
+  /// Returns true if the name of `op` compares equal to `opName`.
+  bool match(Operation *op) { return op->getName().getStringRef() == opName; }
+};
+
+/// The matcher that matches operations that have the specified attribute name.
+///
+/// The attribute name is copied into the matcher, so a matcher built from a
+/// temporary string remains valid for as long as the matcher itself is alive.
+struct attr_op_matcher {
+  std::string opAttrName;
+  attr_op_matcher(StringRef attrN) : opAttrName(attrN.str()) {}
+
+  /// Returns true if `op` has an attribute named `opAttrName`; the attribute's
+  /// value is not inspected.
+  bool match(Operation *op) { return op->hasAttr(opAttrName); }
+};
+
 /// The matcher that matches operations that have the `ConstantLike` trait, and
 /// binds the folded attribute value.
 template
@@ -249,6 +274,17 @@
   return detail::constant_op_matcher();
 }
 
+/// Matches any operation that carries an attribute named `attrN`, regardless
+/// of the attribute's value.
+inline detail::attr_op_matcher m_AttrName(StringRef attrN) {
+  return detail::attr_op_matcher(attrN);
+}
+
+/// Matches any operation whose name is `opN`, e.g. m_Name("test.name").
+inline detail::name_op_matcher m_Name(StringRef opN) {
+  return detail::name_op_matcher(opN);
+}
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -41,3 +41,15 @@
 // CHECK-LABEL: test2
 // CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
 // CHECK: Pattern add(add(a, constant), a) matched
+
+// Input for test3: a "test.name" op feeding addf/mulf, with fastmath on mulf.
+func.func @test3(%a: f32) -> f32 {
+  %0 = "test.name"() {value = 1.0 : f32} : () -> f32
+  %1 = arith.addf %a, %0: f32
+  %2 = arith.mulf %a, %1 fastmath<fast>: f32
+  return %2: f32
+}
+
+// CHECK-LABEL: test3
+// CHECK: Pattern mul(*, add(*, m_Name("test.name"))) matched
+// CHECK: Pattern m_AttrName("fastmath") matched
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -148,6 +148,23 @@
     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
+/// Exercises the name- and attribute-based matchers on the last
+/// non-terminator operation of `f`.
+void test3(FunctionOpInterface f) {
+  auto p = m_Op<arith::MulFOp>(
+      m_Any(), m_Op<arith::AddFOp>(m_Any(), m_Name("test.name")));
+  auto p1 = m_AttrName("fastmath");
+  // Last operation that is not the terminator.
+  Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
+  // A body holding only its terminator has nothing to match against.
+  if (!lastOp)
+    return;
+  if (p.match(lastOp))
+    llvm::outs() << "Pattern mul(*, add(*, m_Name(\"test.name\"))) matched\n";
+  if (p1.match(lastOp))
+    llvm::outs() << "Pattern m_AttrName(\"fastmath\") matched\n";
+}
+
 void TestMatchers::runOnOperation() {
   auto f = getOperation();
   llvm::outs() << f.getName() << "\n";
@@ -155,6 +172,8 @@
     test1(f);
   if (f.getName() == "test2")
     test2(f);
+  if (f.getName() == "test3")
+    test3(f);
 }
 
 namespace mlir {