diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -52,6 +52,22 @@
   bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
 };
 
+/// The matcher that matches operations that have the specified op name.
+struct NameOpMatcher {
+  NameOpMatcher(StringRef name) : name(name) {}
+  bool match(Operation *op) { return op->getName().getStringRef() == name; }
+
+  StringRef name;
+};
+
+/// The matcher that matches operations that have the specified attribute name.
+struct AttrOpMatcher {
+  AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
+  bool match(Operation *op) { return op->hasAttr(attrName); }
+
+  StringRef attrName;
+};
+
 /// The matcher that matches operations that have the `ConstantLike` trait, and
 /// binds the folded attribute value.
 template <typename AttrT>
@@ -83,6 +99,29 @@
   }
 };
 
+/// The matcher that matches operations that have the specified attribute
+/// name, and binds the attribute value.
+template <typename AttrT>
+struct AttrOpBinder {
+  /// Creates a matcher instance that binds the attribute value to
+  /// `bindValue` if match succeeds.
+  AttrOpBinder(StringRef attrName, AttrT *bindValue)
+      : attrName(attrName), bindValue(bindValue) {}
+  /// Creates a matcher instance that doesn't bind if match succeeds.
+  AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {}
+
+  bool match(Operation *op) {
+    if (auto attr = op->getAttrOfType<AttrT>(attrName)) {
+      if (bindValue)
+        *bindValue = attr;
+      return true;
+    }
+    return false;
+  }
+  StringRef attrName;
+  AttrT *bindValue;
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// float operation and binds the constant float value.
 struct constant_float_op_binder {
@@ -249,6 +288,16 @@
   return detail::constant_op_matcher();
 }
 
+/// Matches an operation that has an attribute with the given name.
+inline detail::AttrOpMatcher m_Attr(StringRef attrName) {
+  return detail::AttrOpMatcher(attrName);
+}
+
+/// Matches an operation with the given operation name.
+inline detail::NameOpMatcher m_Op(StringRef opName) {
+  return detail::NameOpMatcher(opName);
+}
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template <typename AttrT>
@@ -256,6 +305,13 @@
   return detail::constant_op_binder<AttrT>(bind_value);
 }
 
+/// Matches an operation with the named attribute and writes it to bindValue.
+template <typename AttrT>
+inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
+                                          AttrT *bindValue) {
+  return detail::AttrOpBinder<AttrT>(attrName, bindValue);
+}
+
 /// Matches a constant scalar / vector splat / tensor splat float (both positive
 /// and negative) zero.
 inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -41,3 +41,14 @@
 // CHECK-LABEL: test2
 //       CHECK:   Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
 //       CHECK:   Pattern add(add(a, constant), a) matched
+
+func.func @test3(%a: f32) -> f32 {
+  %0 = "test.name"() {value = 1.0 : f32} : () -> f32
+  %1 = arith.addf %a, %0: f32
+  %2 = arith.mulf %a, %1 fastmath<fast>: f32
+  return %2: f32
+}
+
+// CHECK-LABEL: test3
+//       CHECK:   Pattern mul(*, add(*, m_Op("test.name"))) matched
+//       CHECK:   Pattern m_Attr("fastmath") matched and bound value to: fast
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -148,6 +148,21 @@
     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
+void test3(FunctionOpInterface f) {
+  arith::FastMathFlagsAttr fastMathAttr;
+  auto p = m_Op<arith::MulFOp>(m_Any(),
+                               m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
+  auto p1 = m_Attr("fastmath", &fastMathAttr);
+
+  // Last operation that is not the terminator.
+  Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
+  if (p.match(lastOp))
+    llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
+  if (p1.match(lastOp))
+    llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
+                 << fastMathAttr.getValue() << "\n";
+}
+
 void TestMatchers::runOnOperation() {
   auto f = getOperation();
   llvm::outs() << f.getName() << "\n";
@@ -155,6 +170,8 @@
     test1(f);
   if (f.getName() == "test2")
     test2(f);
+  if (f.getName() == "test3")
+    test3(f);
 }
 
 namespace mlir {