Index: mlir/include/mlir/IR/Matchers.h =================================================================== --- mlir/include/mlir/IR/Matchers.h +++ mlir/include/mlir/IR/Matchers.h @@ -50,7 +50,8 @@ /// The matcher that matches a constant foldable operation that has no side /// effect, no operands and produces a single result. -template struct constant_op_binder { +template +struct constant_op_binder { AttrT *bind_value; /// Creates a matcher instance that binds the constant attribute value to @@ -103,7 +104,8 @@ /// The matcher that matches a given target constant scalar / vector splat / /// tensor splat integer value. -template struct constant_int_value_matcher { +template +struct constant_int_value_matcher { bool match(Operation *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetValue == value; @@ -112,7 +114,8 @@ /// The matcher that matches anything except the given target constant scalar / /// vector splat / tensor splat integer value. -template struct constant_int_not_value_matcher { +template +struct constant_int_not_value_matcher { bool match(Operation *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetNotValue != value; @@ -120,8 +123,18 @@ }; /// The matcher that matches a certain kind of op. -template struct op_matcher { - bool match(Operation *op) { return isa(op); } +class op_matcher { +public: + op_matcher() = delete; + op_matcher(llvm::StringRef operation) : operationToMatch(operation){}; + bool match(Operation *op) { + if (operationToMatch == op->getName().getStringRef()) + return true; + return false; + } + +private: + llvm::StringRef operationToMatch; }; /// Trait to check whether T provides a 'match' method with type @@ -178,12 +191,14 @@ } /// RecursivePatternMatcher that composes. -template -struct RecursivePatternMatcher { - RecursivePatternMatcher(OperandMatchers... matchers) - : operandMatchers(matchers...) {} +template +class RecursivePatternMatcher { +public: + RecursivePatternMatcher(llvm::StringRef operation, + OperandMatchers... matchers) + : operationToMatch(operation), operandMatchers(matchers...) {} bool match(Operation *op) { - if (!isa(op) || op->getNumOperands() != sizeof...(OperandMatchers)) + if (!areSame(op) || op->getNumOperands() != sizeof...(OperandMatchers)) return false; bool res = true; enumerate(operandMatchers, [&](size_t index, auto &matcher) { @@ -191,7 +206,15 @@ }); return res; } + +private: + llvm::StringRef operationToMatch; std::tuple operandMatchers; + bool areSame(Operation *op) { + if (operationToMatch == op->getName().getStringRef()) + return true; + return false; + } }; } // end namespace detail @@ -209,8 +232,9 @@ } /// Matches the given OpClass. -template inline detail::op_matcher m_Op() { - return detail::op_matcher(); +template +inline detail::op_matcher m_Op() { + return detail::op_matcher(OpClass::getOperationName()); } /// Matches a constant scalar / vector splat / tensor splat integer zero. @@ -248,7 +272,8 @@ template auto m_Op(Matchers... matchers) { - return detail::RecursivePatternMatcher(matchers...); + return detail::RecursivePatternMatcher( + OpType::getOperationName(), matchers...); } namespace matchers { Index: mlir/lib/Dialect/StandardOps/Ops.cpp =================================================================== --- mlir/lib/Dialect/StandardOps/Ops.cpp +++ mlir/lib/Dialect/StandardOps/Ops.cpp @@ -135,7 +135,8 @@ } /// A custom cast operation verifier. -template static LogicalResult verifyCastOp(T op) { +template +static LogicalResult verifyCastOp(T op) { auto opType = op.getOperand()->getType(); auto resType = op.getType(); if (!T::areCastCompatible(opType, resType)) @@ -195,8 +196,8 @@ /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses m_Constant /// and checks the operation for an index type. -static detail::op_matcher m_ConstantIndex() { - return detail::op_matcher(); +static detail::op_matcher m_ConstantIndex() { + return detail::op_matcher(ConstantIndexOp::getOperationName()); } //===----------------------------------------------------------------------===//