diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/FunctionExtras.h"
+#include "llvm/Support/TypeName.h"
 
 namespace mlir {
 
@@ -132,6 +133,15 @@
     return contextAndHasBoundedRecursion.getPointer();
   }
 
+  /// Return readable pattern name. Should only be used for debugging purposes.
+  /// Can be empty.
+  StringRef getDebugName() const { return debugName; }
+
+  /// Set readable pattern name. Should only be used for debugging purposes.
+  /// The name is not copied, so the storage it references must outlive this
+  /// pattern (e.g. a string literal or the result of `llvm::getTypeName`).
+  void setDebugName(StringRef name) { debugName = name; }
+
 protected:
   /// This class acts as a special tag that makes the desire to match "any"
   /// operation type explicit. This helps to avoid unnecessary usages of this
@@ -202,6 +212,9 @@
   /// A list of the potential operations that may be generated when rewriting
   /// an op with this pattern.
   SmallVector<OperationName, 2> generatedOps;
+
+  /// Readable pattern name. Can be empty. Not owned; see `setDebugName`.
+  StringRef debugName;
 };
 
 //===----------------------------------------------------------------------===//
@@ -959,7 +972,9 @@
     struct FnPattern final : public OpRewritePattern<OpType> {
       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
                 MLIRContext *context)
-          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+          : OpRewritePattern<OpType>(context), implFn(implFn) {
+        // `this->` is required: the base class is dependent on `OpType`, so
+        // unqualified lookup would not find `setDebugName` in it.
+        this->setDebugName(llvm::getTypeName<FnPattern>());
+      }
 
       LogicalResult matchAndRewrite(OpType op,
                                     PatternRewriter &rewriter) const override {
@@ -981,6 +996,12 @@
   addImpl(Args &&... args) {
     nativePatterns.emplace_back(
         std::make_unique<T>(std::forward<Args>(args)...));
+    auto *pattern = nativePatterns.back().get();
+
+    // The pattern may already have set a debug name in its constructor; only
+    // fall back to the C++ type name when it did not.
+    if (pattern->getDebugName().empty())
+      pattern->setDebugName(llvm::getTypeName<T>());
   }
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -195,7 +195,21 @@
       result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
     } else {
       const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
+      auto getPatternName = [&]() -> StringRef {
+        StringRef name = pattern->getDebugName();
+        return !name.empty() ? name : "UNKNOWN";
+      };
+      // Only used by LLVM_DEBUG, which expands to nothing in NDEBUG builds;
+      // silence the resulting unused-variable warning there.
+      (void)getPatternName;
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Trying to match \"" << getPatternName() << "\"\n");
       result = pattern->matchAndRewrite(op, rewriter);
+      LLVM_DEBUG(llvm::dbgs() << "\"" << getPatternName() << "\" result "
+                              << (succeeded(result) ? "success" : "failure")
+                              << "\n");
+
       if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
         result = failure();
     }