diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -12,10 +12,12 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/FunctionExtras.h" +#include "llvm/Support/TypeName.h" namespace mlir { class PatternRewriter; +class RewritePatternSet; //===----------------------------------------------------------------------===// // PatternBenefit class @@ -132,6 +134,9 @@ return contextAndHasBoundedRecursion.getPointer(); } + /// Return a human-readable name for this pattern, used in debug output. May be empty if no name was recorded. + StringRef getPatternName() const { return patternName; } + protected: /// This class acts as a special tag that makes the desire to match "any" /// operation type explicit. This helps to avoid unnecessary usages of this @@ -184,10 +189,15 @@ } private: + /// RewritePatternSet needs access to setPatternName to record the pattern's type name when the pattern is added. + friend class RewritePatternSet; + Pattern(const void *rootValue, RootKind rootKind, ArrayRef generatedNames, PatternBenefit benefit, MLIRContext *context); + void setPatternName(StringRef name) { patternName = name; } + /// The value used to match the root operation of the pattern. const void *rootValue; RootKind rootKind; @@ -202,6 +212,9 @@ /// A list of the potential operations that may be generated when rewriting /// an op with this pattern. SmallVector generatedOps; + + /// Human-readable pattern name (non-owning StringRef; the referenced string must outlive the pattern). Can be empty. + StringRef patternName; }; //===----------------------------------------------------------------------===// @@ -981,6 +994,7 @@ addImpl(Args &&...
args) { nativePatterns.emplace_back( std::make_unique(std::forward(args)...)); + nativePatterns.back()->setPatternName(llvm::getTypeName()); } template std::enable_if_t::value> diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -195,7 +195,17 @@ result = success(!onSuccess || succeeded(onSuccess(*bestPattern))); } else { const auto *pattern = static_cast(bestPattern); + auto getPatternName = [&]() { + StringRef name = pattern->getPatternName(); + return (!name.empty() ? name : "UNKNOWN"); + }; + + LLVM_DEBUG(llvm::dbgs() + << "Trying to match \"" << getPatternName() << "\"\n"); result = pattern->matchAndRewrite(op, rewriter); + LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" << getPatternName() + << "\" result " << succeeded(result) << "\n"); + if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) result = failure(); }