// Patch overview: unified diff against mlir/lib/Transforms/DialectConversion.cpp.
// These '//' lines sit ahead of the "diff --git" header and are not part of any
// hunk; everything from "diff --git" onwards is unchanged.
//
// What the hunks show:
//  * Adds the includes llvm/Support/FormatVariadic.h and
//    llvm/Support/ScopedPrinter.h.
//  * Adds two static helpers under `#ifndef NDEBUG`, logSuccess and logFailure.
//    Inside LLVM_DEBUG each one un-indents the given llvm::ScopedPrinter and
//    prints "} -> SUCCESS" / "} -> FAILURE : " followed by a reason formatted
//    with llvm::formatv. logSuccess omits the " : <reason>" part when `fmt` is
//    empty.
//  * Adds an llvm::ScopedPrinter member `logger`, initialised with
//    llvm::dbgs(), immediately before an `#endif` in a class that closes just
//    ahead of "// end namespace detail".
//  * ConversionPatternRewriter::replaceOp / eraseOp / insert now print their
//    "** Replace" / "** Erase" / "** Insert" lines through impl->logger and
//    additionally stream the Operation pointer `op`.
//  * OperationLegalizer::legalize, legalizeWithFold and legalizePattern open an
//    indented "{" scope in the log and close it through logSuccess/logFailure,
//    replacing the flat "-- Success" / "-- FAIL" lines written to llvm::dbgs().
//    legalize also brackets each operation with a "//===---...===//" line.
//
// NOTE(review): logSuccess/logFailure exist only under `#ifndef NDEBUG`, and
//   `logger` is added right before an `#endif` (the opening directive is
//   outside the hunk, presumably `#ifndef NDEBUG` -- confirm). The calls to
//   logSuccess/logFailure in legalize, legalizeWithFold and legalizePattern are
//   NOT wrapped in LLVM_DEBUG, so an NDEBUG build looks like it would fail to
//   compile. Confirm, and consider moving those calls inside LLVM_DEBUG.
// NOTE(review): in the rootUpdates loop of legalizePattern the operation being
//   legalized is state.getOperation(), but the failure message formats
//   op->getName() (the pattern's root op). The removed lines did the same;
//   confirm which name is intended.
// NOTE(review): the "* Pattern" header writes the generated-ops list with
//   interleaveComma(..., llvm::dbgs()) while the surrounding text goes through
//   os.getOStream(). `logger` is constructed over llvm::dbgs(), so presumably
//   this is the same stream -- confirm.
// NOTE(review): legalize binds `rewriterImpl` but still calls
//   rewriter.getImpl() for markNestedOpsIgnored and isOpIgnored.
// NOTE(review): formatv is handed fmt.data(); this assumes the StringRef is
//   null-terminated. Every call site visible here passes a string literal.
// NOTE(review): this copy of the patch has its line breaks collapsed, and
//   template argument lists appear to be missing (`template` with no parameter
//   list, `std::forward(args)`, `SmallPtrSet pendingRootUpdates`,
//   `SmallVector nullRepls`) -- presumably lost in extraction. Restore from the
//   upstream change before trying to apply it.
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -16,6 +16,8 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ScopedPrinter.h" using namespace mlir; using namespace mlir::detail; @@ -67,6 +69,34 @@ return success(); } +#ifndef NDEBUG +/// A utility function to log a successful result for the given reason. +template +static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, + Args &&... args) { + LLVM_DEBUG({ + os.unindent(); + os.startLine() << "} -> SUCCESS"; + if (!fmt.empty()) + os.getOStream() << " : " + << llvm::formatv(fmt.data(), std::forward(args)...); + os.getOStream() << "\n"; + }); +} + +/// A utility function to log a failure result for the given reason. +template +static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, + Args &&... args) { + LLVM_DEBUG({ + os.unindent(); + os.startLine() << "} -> FAILURE : " + << llvm::formatv(fmt.data(), std::forward(args)...) + << "\n"; + }); +} +#endif + //===----------------------------------------------------------------------===// // Multi-Level Value Mapper //===----------------------------------------------------------------------===// @@ -612,6 +642,9 @@ /// strictly necessary, and is thus only active during debug builds for extra /// verification. SmallPtrSet pendingRootUpdates; + + /// A logger used to emit diagnostics during the conversion process. + llvm::ScopedPrinter logger{llvm::dbgs()}; #endif }; } // end namespace detail @@ -836,10 +869,11 @@ ConversionPatternRewriter::~ConversionPatternRewriter() {} /// PatternRewriter hook for replacing the results of an operation. 
-void ConversionPatternRewriter::replaceOp(Operation *op, - ValueRange newValues) { - LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName() - << "\n"); +void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { + LLVM_DEBUG({ + impl->logger.startLine() + << "** Replace : '" << op->getName() << "'(" << op << ")\n"; + }); impl->replaceOp(op, newValues); } @@ -847,8 +881,10 @@ /// operation *must* be made dead by the end of the conversion process, /// otherwise an assert will be issued. void ConversionPatternRewriter::eraseOp(Operation *op) { - LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName() - << "\n"); + LLVM_DEBUG({ + impl->logger.startLine() + << "** Erase : '" << op->getName() << "'(" << op << ")\n"; + }); SmallVector nullRepls(op->getNumResults(), nullptr); impl->replaceOp(op, nullRepls); } @@ -915,8 +951,10 @@ /// PatternRewriter hook for creating a new operation. Operation *ConversionPatternRewriter::insert(Operation *op) { - LLVM_DEBUG(llvm::dbgs() << "** Inserting operation : " << op->getName() - << "\n"); + LLVM_DEBUG({ + impl->logger.startLine() + << "** Insert : '" << op->getName() << "'(" << op << ")\n"; + }); impl->createdOps.push_back(op); return OpBuilder::insert(op); } @@ -1073,27 +1111,40 @@ LogicalResult OperationLegalizer::legalize(Operation *op, ConversionPatternRewriter &rewriter) { - LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() - << "\n"); + const char *logLineComment = + "//===-------------------------------------------===//\n"; + + auto &rewriterImpl = rewriter.getImpl(); + LLVM_DEBUG({ + auto &os = rewriterImpl.logger; + os.getOStream() << "\n"; + os.startLine() << logLineComment; + os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op + << ") {\n"; + os.indent(); + }); // Check if this operation is legal on the target. 
if (auto legalityInfo = target.isLegal(op)) { - LLVM_DEBUG(llvm::dbgs() - << "-- Success : Operation marked legal by the target\n"); + logSuccess( + rewriterImpl.logger, "operation marked legal by the target{0}", + legalityInfo->isRecursivelyLegal + ? "; NOTE: operation is recursively legal; skipping internals" + : ""); + LLVM_DEBUG(rewriterImpl.logger.startLine() << logLineComment); + // If this operation is recursively legal, mark its children as ignored so // that we don't consider them for legalization. - if (legalityInfo->isRecursivelyLegal) { - LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation is recursively legal; " - "Skipping internals\n"); + if (legalityInfo->isRecursivelyLegal) rewriter.getImpl().markNestedOpsIgnored(op); - } return success(); } // Check to see if the operation is ignored and doesn't need to be converted. if (rewriter.getImpl().isOpIgnored(op)) { - LLVM_DEBUG(llvm::dbgs() - << "-- Success : Operation marked ignored during conversion\n"); + logSuccess(rewriterImpl.logger, + "operation marked 'ignored' during conversion"); + LLVM_DEBUG(rewriterImpl.logger.startLine() << logLineComment); return success(); } @@ -1101,23 +1152,30 @@ // TODO(riverriddle) Should we always try to do this, even if the op is // already legal? if (succeeded(legalizeWithFold(op, rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n"); + logSuccess(rewriterImpl.logger, "operation was folded"); + LLVM_DEBUG(rewriterImpl.logger.startLine() << logLineComment); return success(); } // Otherwise, we need to apply a legalization pattern to this operation. auto it = legalizerPatterns.find(op->getName()); if (it == legalizerPatterns.end()) { - LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n"); + logFailure(rewriterImpl.logger, "no known legalization path"); + LLVM_DEBUG(rewriterImpl.logger.startLine() << logLineComment); return failure(); } // The patterns are sorted by expected benefit, so try to apply each in-order. 
- for (auto *pattern : it->second) - if (succeeded(legalizePattern(op, pattern, rewriter))) + for (auto *pattern : it->second) { + if (succeeded(legalizePattern(op, pattern, rewriter))) { + logSuccess(rewriterImpl.logger, ""); + LLVM_DEBUG(rewriterImpl.logger.startLine() << logLineComment); return success(); + } + } - LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n"); + logFailure(rewriterImpl.logger, "no matched legalization pattern"); + LLVM_DEBUG(rewriterImpl.logger.startLine() << logLineComment); return failure(); } @@ -1127,11 +1185,18 @@ auto &rewriterImpl = rewriter.getImpl(); RewriterState curState = rewriterImpl.getCurrentState(); + LLVM_DEBUG({ + rewriterImpl.logger.startLine() << "* Fold {\n"; + rewriterImpl.logger.indent(); + }); + // Try to fold the operation. SmallVector replacementValues; rewriter.setInsertionPoint(op); - if (failed(rewriter.tryFold(op, replacementValues))) + if (failed(rewriter.tryFold(op, replacementValues))) { + logFailure(rewriterImpl.logger, "unable to fold"); return failure(); + } // Insert a replacement for 'op' with the folded replacement values. 
rewriter.replaceOp(op, replacementValues); @@ -1141,22 +1206,28 @@ i != e; ++i) { Operation *cstOp = rewriterImpl.createdOps[i]; if (failed(legalize(cstOp, rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '" - << cstOp->getName() << "' was illegal.\n"); + logFailure(rewriterImpl.logger, "generated constant '{0}' was illegal", + cstOp->getName()); rewriterImpl.resetState(curState); return failure(); } } + + logSuccess(rewriterImpl.logger, ""); return success(); } LogicalResult OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, ConversionPatternRewriter &rewriter) { + auto &rewriterImpl = rewriter.getImpl(); LLVM_DEBUG({ - llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> ("; + auto &os = rewriterImpl.logger; + os.getOStream() << "\n"; + os.startLine() << "* Pattern : '" << pattern->getRootKind() << " -> ("; interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); - llvm::dbgs() << ")'.\n"; + os.getOStream() << ")' {\n"; + os.indent(); }); // Ensure that we don't cycle by not allowing the same pattern to be @@ -1164,11 +1235,10 @@ // TODO(riverriddle) We could eventually converge, but that requires more // complicated analysis. if (!appliedPatterns.insert(pattern).second) { - LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n"); + logFailure(rewriterImpl.logger, "pattern was already applied"); return failure(); } - auto &rewriterImpl = rewriter.getImpl(); RewriterState curState = rewriterImpl.getCurrentState(); auto cleanupFailure = [&] { // Reset the rewriter state and pop this pattern. @@ -1185,7 +1255,7 @@ #endif if (!matchedPattern) { - LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n"); + logFailure(rewriterImpl.logger, "pattern failed to match"); return cleanupFailure(); } @@ -1202,8 +1272,7 @@ // Convert the block signature. 
if (failed(rewriterImpl.convertBlockSignature(action.block))) { - LLVM_DEBUG(llvm::dbgs() - << "-- FAIL: failed to convert types of moved block.\n"); + logFailure(rewriterImpl.logger, "failed to convert types of moved block"); return cleanupFailure(); } } @@ -1239,8 +1308,8 @@ i != e; ++i) { auto &state = rewriterImpl.rootUpdates[i]; if (failed(legalize(state.getOperation(), rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '" - << op->getName() << "' was illegal.\n"); + logFailure(rewriterImpl.logger, + "operation updated in-place '{0}' was illegal", op->getName()); return cleanupFailure(); } } @@ -1250,12 +1319,14 @@ i != e; ++i) { Operation *op = rewriterImpl.createdOps[i]; if (failed(legalize(op, rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation '" - << op->getName() << "' was illegal.\n"); + logFailure(rewriterImpl.logger, + "generated operation '{0}'({1}) was illegal", op->getName(), + op); return cleanupFailure(); } } + logSuccess(rewriterImpl.logger, "pattern applied successfully"); appliedPatterns.erase(pattern); return success(); }