diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/FunctionExtras.h"
+#include "llvm/Support/TypeName.h"
 
 namespace mlir {
 
@@ -132,6 +133,14 @@
     return contextAndHasBoundedRecursion.getPointer();
   }
 
+  /// Return readable pattern name. Should only be used for debugging purposes.
+  /// Can be empty.
+  StringRef getDebugName() const { return debugName; }
+
+  /// Set readable pattern name. Should only be used for debugging purposes.
+  /// The string is not copied, so its storage must outlive this pattern.
+  void setDebugName(StringRef name) { debugName = name; }
+
 protected:
   /// This class acts as a special tag that makes the desire to match "any"
   /// operation type explicit. This helps to avoid unnecessary usages of this
@@ -202,6 +211,9 @@
   /// A list of the potential operations that may be generated when rewriting
   /// an op with this pattern.
   SmallVector<OperationName, 2> generatedOps;
+
+  /// Readable pattern name. Can be empty.
+  StringRef debugName;
 };
 
 //===----------------------------------------------------------------------===//
@@ -959,7 +971,9 @@
     struct FnPattern final : public OpRewritePattern<OpType> {
       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
                 MLIRContext *context)
-          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+          : OpRewritePattern<OpType>(context), implFn(implFn) {
+        this->setDebugName(llvm::getTypeName<FnPattern>());
+      }
 
       LogicalResult matchAndRewrite(OpType op,
                                     PatternRewriter &rewriter) const override {
@@ -979,8 +993,13 @@
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
   addImpl(Args &&... args) {
-    nativePatterns.emplace_back(
-        std::make_unique<T>(std::forward<Args>(args)...));
+    auto pattern = std::make_unique<T>(std::forward<Args>(args)...);
+
+    // Pattern can potentially set name in ctor. Preserve old name if present.
+    if (pattern->getDebugName().empty())
+      pattern->setDebugName(llvm::getTypeName<T>());
+
+    nativePatterns.emplace_back(std::move(pattern));
   }
   template <typename T, typename... Args>
   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -195,7 +195,13 @@
         result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
       } else {
         const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
+
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Trying to match \"" << pattern->getDebugName() << "\"\n");
         result = pattern->matchAndRewrite(op, rewriter);
+        LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
+                                << succeeded(result) << "\n");
+
         if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
           result = failure();
       }