diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -342,6 +342,39 @@ Note: This driver is the one used by the [canonicalization](Canonicalization.md) [pass](Passes.md/#-canonicalize-canonicalize-operations) in MLIR. +### Debugging + +To debug the execution of the greedy pattern rewrite driver, +`-debug-only=greedy-rewriter` may be used. This command-line flag activates +LLVM's debug logging infrastructure solely for the greedy pattern rewriter. The +output is formatted as a tree structure, mirroring the structure of the pattern +application process. This output contains all of the actions performed by the +rewriter, how operations get processed and patterns are applied, and why they +fail. This logging is only available in builds with assertions enabled. + +Example output is shown below: + +``` +//===-------------------------------------------===// +Processing operation : 'std.cond_br'(0x60f000001120) { + "std.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () + + * Pattern SimplifyConstCondBranchPred : 'std.cond_br -> ()' { + } -> failure : pattern failed to match + + * Pattern SimplifyCondBranchIdenticalSuccessors : 'std.cond_br -> ()' { + ** Insert : 'std.br'(0x60b000003690) + ** Replace : 'std.cond_br'(0x60f000001120) + } -> success : pattern applied successfully +} -> success : pattern matched +//===-------------------------------------------===// +``` + +This output describes the processing of a `std.cond_br` operation. We first +try to apply the `SimplifyConstCondBranchPred` pattern, which fails to match. +Another pattern (`SimplifyCondBranchIdenticalSuccessors`) then matches the +`std.cond_br` and replaces it with a `std.br`. 
+ ## Debugging ### Pattern Filtering diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -18,11 +18,12 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; -#define DEBUG_TYPE "pattern-matcher" +#define DEBUG_TYPE "greedy-rewriter" //===----------------------------------------------------------------------===// // GreedyPatternRewriteDriver @@ -70,6 +71,14 @@ // before the root is changed. void notifyRootReplaced(Operation *op) override; + /// PatternRewriter hook for erasing a dead operation. + void eraseOp(Operation *op) override; + + /// PatternRewriter hook for notifying match failure reasons. + LogicalResult + notifyMatchFailure(Operation *op, + function_ref<void(Diagnostic &)> reasonCallback) override; + /// The low-level pattern applicator. PatternApplicator matcher; @@ -86,6 +95,11 @@ private: /// Configuration information for how to simplify. GreedyRewriteConfig config; + +#ifndef NDEBUG + /// A logger used to emit information during the application process. + llvm::ScopedPrinter logger{llvm::dbgs()}; +#endif }; } // end anonymous namespace @@ -100,6 +114,24 @@ } bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) { +#ifndef NDEBUG + const char *logLineComment = + "//===-------------------------------------------===//\n"; + + /// A utility function to log a process result for the given reason. 
+ auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) { + logger.unindent(); + logger.startLine() << "} -> " << result; + if (!msg.isTriviallyEmpty()) + logger.getOStream() << " : " << msg; + logger.getOStream() << "\n"; + }; + auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) { + logResult(result, msg); + logger.startLine() << logLineComment; + }; +#endif + bool changed = false; unsigned iteration = 0; do { @@ -135,11 +167,29 @@ if (op == nullptr) continue; + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << logLineComment; + logger.startLine() << "Processing operation : '" << op->getName() + << "'(" << op << ") {\n"; + logger.indent(); + + // If the operation has no regions, just print it here. + if (op->getNumRegions() == 0) { + op->print( + logger.startLine(), + OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); + logger.getOStream() << "\n\n"; + } + }); + // If the operation is trivially dead - remove it. if (isOpTriviallyDead(op)) { notifyOperationRemoved(op); op->erase(); changed = true; + + LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); continue; } @@ -166,6 +216,8 @@ bool inPlaceUpdate; if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, &inPlaceUpdate)))) { + if (!inPlaceUpdate) // In-place folds keep the scope open for patterns. + LLVM_DEBUG(logResultWithLine("success", "operation was folded")); changed = true; if (!inPlaceUpdate) continue; @@ -174,7 +226,37 @@ // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do // here. 
- changed |= succeeded(matcher.matchAndRewrite(op, *this)); +#ifndef NDEBUG + auto canApply = [&](const Pattern &pattern) { + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" + << op->getName() << " -> ("; + llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); + logger.getOStream() << ")' {\n"; + logger.indent(); + }); + return true; + }; + auto onFailure = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("failure", "pattern failed to match")); + }; + auto onSuccess = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("success", "pattern applied successfully")); + return success(); + }; + + LogicalResult matchResult = + matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); + if (succeeded(matchResult)) + LLVM_DEBUG(logResultWithLine("success", "pattern matched")); + else + LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); +#else + LogicalResult matchResult = matcher.matchAndRewrite(op, *this); +#endif + + changed |= succeeded(matchResult); } // After applying patterns, make sure that the CFG of each of the regions @@ -218,6 +300,10 @@ } void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) { + LLVM_DEBUG({ + logger.startLine() << "** Insert : '" << op->getName() << "'(" << op + << ")\n"; + }); addToWorklist(op); } @@ -245,11 +331,33 @@ } void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) { + LLVM_DEBUG({ + logger.startLine() << "** Replace : '" << op->getName() << "'(" << op + << ")\n"; + }); for (auto result : op->getResults()) for (auto *user : result.getUsers()) addToWorklist(user); } +void GreedyPatternRewriteDriver::eraseOp(Operation *op) { + LLVM_DEBUG({ + logger.startLine() << "** Erase : '" << op->getName() << "'(" << op + << ")\n"; + }); + PatternRewriter::eraseOp(op); +} + +LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure( + Operation *op, function_ref<void(Diagnostic &)> 
reasonCallback) { + LLVM_DEBUG({ + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); + reasonCallback(diag); + logger.startLine() << "** Failure : " << diag.str() << "\n"; + }); + return failure(); +} + /// Rewrite the regions of the specified operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy /// work-list driven manner. Return success if no more patterns can be matched