diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -15,7 +15,7 @@
 #include "ByteCode.h"
 #include "llvm/Support/Debug.h"
 
-#define DEBUG_TYPE "pattern-match"
+#define DEBUG_TYPE "pattern-application"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -30,15 +30,40 @@
 }
 PatternApplicator::~PatternApplicator() {}
 
+#ifndef NDEBUG
 /// Log a message for a pattern that is impossible to match.
 static void logImpossibleToMatch(const Pattern &pattern) {
-  LLVM_DEBUG({
-    llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
-                 << "' because it is impossible to match or cannot lead "
-                    "to legal IR (by cost model)\n";
-  });
+  llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
+               << "' because it is impossible to match or cannot lead "
+                  "to legal IR (by cost model)\n";
 }
+
+/// Return the module enclosing `op` so the full IR can be dumped after a
+/// pattern is applied. Returns None when multithreading is enabled, as other
+/// threads may be concurrently rewriting other parts of the module, or when
+/// `op` is not nested within a module.
+static Optional<ModuleOp> tryToGetModuleOp(Operation *op) {
+  if (op->getContext()->isMultithreadingEnabled())
+    return llvm::None;
+  if (ModuleOp module = op->getParentOfType<ModuleOp>())
+    return module;
+  return llvm::None;
+}
+
+/// Log the IR of `moduleOp` after a successful pattern application, or a note
+/// explaining why it could not be dumped.
+static void logSucessfulPatternApplication(Optional<ModuleOp> moduleOp) {
+  llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
+  if (moduleOp) {
+    moduleOp->dump();
+  } else {
+    llvm::dbgs() << "Failed to dump module IR. Multithreading must be disabled "
+                    "and the operation must be nested within a module.";
+  }
+  llvm::dbgs() << "\n\n";
+}
+#endif
 
 void PatternApplicator::applyCostModel(CostModel model) {
   // Apply the cost model to the bytecode patterns first, and then the native
   // patterns.
@@ -53,7 +78,7 @@
   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
     for (const RewritePattern *pattern : it.second) {
       if (pattern->getBenefit().isImpossibleToMatch())
-        logImpossibleToMatch(*pattern);
+        LLVM_DEBUG(logImpossibleToMatch(*pattern));
       else
         patterns[it.first].push_back(pattern);
     }
@@ -62,7 +87,7 @@
   for (const RewritePattern &pattern :
        frozenPatternList.getMatchAnyOpNativePatterns()) {
     if (pattern.getBenefit().isImpossibleToMatch())
-      logImpossibleToMatch(pattern);
+      LLVM_DEBUG(logImpossibleToMatch(pattern));
     else
       anyOpPatterns.push_back(&pattern);
   }
@@ -76,7 +101,7 @@
     // Special case for one pattern in the list, which is the most common case.
     if (list.size() == 1) {
       if (model(*list.front()).isImpossibleToMatch()) {
-        logImpossibleToMatch(*list.front());
+        LLVM_DEBUG(logImpossibleToMatch(*list.front()));
         list.clear();
       }
       return;
@@ -90,8 +115,10 @@
     // Sort patterns with highest benefit first, and remove those that are
     // impossible to match.
     std::stable_sort(list.begin(), list.end(), cmp);
-    while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
-      logImpossibleToMatch(*list.pop_back_val());
+    while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
+      LLVM_DEBUG(logImpossibleToMatch(*list.back()));
+      list.pop_back();
+    }
   };
   for (auto &it : patterns)
     processPatternList(it.second);
@@ -174,18 +201,25 @@
     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
     // match has already been performed, we just need to rewrite.
     rewriter.setInsertionPoint(op);
+#ifndef NDEBUG
+    // Capture the enclosing module before rewriting, as a successful rewrite
+    // may erase `op`. The lookup is skipped unless debug output is enabled.
+    Optional<ModuleOp> moduleOp;
+    LLVM_DEBUG(moduleOp = tryToGetModuleOp(op));
+#endif
     if (pdlMatch) {
       bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
       result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
-
     } else {
       const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
       result = pattern->matchAndRewrite(op, rewriter);
       if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
         result = failure();
     }
-    if (succeeded(result))
+    if (succeeded(result)) {
+      LLVM_DEBUG(logSucessfulPatternApplication(moduleOp));
       break;
+    }
 
     // Perform any necessary cleanups.
     if (onFailure)