diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -121,6 +121,18 @@
 
   /// The low-level pattern applicator.
   PatternApplicator matcher;
+
+#ifndef NDEBUG
+  /// Fingerprints of `fingerprintTopLevel` and all ops nested in it, taken
+  /// right before a pattern is applied. Rewriter notifications erase the
+  /// entries of the affected ops and of their ancestors strictly below
+  /// `fingerprintTopLevel`; every entry that survives a rewrite must still
+  /// match the op's current fingerprint.
+  DenseMap<Operation *, OperationFingerPrint> fingerprints;
+
+  /// The op whose IR is currently fingerprinted, or nullptr if none.
+  Operation *fingerprintTopLevel = nullptr;
+#endif // NDEBUG
 };
 } // namespace
 
@@ -250,12 +262,54 @@
     return success();
   };
 
+  // Fingerprint the IR before applying the pattern, so that patterns that
+  // misreport success/failure, or that modify ops without notifying the
+  // rewriter, are caught. Only the parent op of `config.scope` (and the IR
+  // nested in it) is fingerprinted. Without a scope nothing is fingerprinted:
+  // `op` itself may be erased by the pattern, so fingerprinting it afterwards
+  // would be a use-after-free.
+  fingerprintTopLevel = config.scope ? config.scope->getParentOp() : nullptr;
+  if (fingerprintTopLevel) {
+    fingerprintTopLevel->walk([&](Operation *nestedOp) {
+      fingerprints.try_emplace(nestedOp, nestedOp);
+    });
+  }
+
   LogicalResult matchResult =
       matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
   if (succeeded(matchResult))
     LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
   else
     LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+  if (fingerprintTopLevel) {
+    // The top-level entry is never erased by rewriter notifications, so it
+    // still holds the fingerprint taken before the pattern was applied.
+    const OperationFingerPrint &beforeFingerPrint =
+        fingerprints.find(fingerprintTopLevel)->second;
+    OperationFingerPrint afterFingerPrint(fingerprintTopLevel);
+    if (succeeded(matchResult)) {
+      assert(beforeFingerPrint != afterFingerPrint &&
+             "pattern returned success but IR did not change");
+      // Ops that still have an entry were not reported as changed (nor were
+      // any of their nested ops), so their fingerprints must be unchanged.
+      fingerprintTopLevel->walk([&](Operation *nestedOp) {
+        if (nestedOp == fingerprintTopLevel)
+          return;
+        auto it = fingerprints.find(nestedOp);
+        if (it != fingerprints.end()) {
+          assert(it->second == OperationFingerPrint(nestedOp) &&
+                 "operation finger print changed");
+        }
+      });
+    } else {
+      assert(beforeFingerPrint == afterFingerPrint &&
+             "pattern returned failure but IR did change");
+    }
+  }
+  // Drop the fingerprints so that rewriter notifications arriving outside of
+  // pattern application operate on an empty map instead of stale state.
+  fingerprints.clear();
+  fingerprintTopLevel = nullptr;
 #else
   LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
 #endif
@@ -270,6 +324,7 @@
 }
 
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+  assert(op && "expected valid op");
   // Gather potential ancestors while looking for a "scope" parent region.
   SmallVector<Operation *, 8> ancestors;
   Region *region = nullptr;
@@ -328,6 +383,13 @@
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+#ifndef NDEBUG
+  Operation *fpOp = op->getParentOp();
+  while (fpOp && fpOp != fingerprintTopLevel) {
+    fingerprints.erase(fpOp);
+    fpOp = fpOp->getParentOp();
+  }
+#endif // NDEBUG
   if (config.listener)
     config.listener->notifyOperationInserted(op);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -340,6 +402,13 @@
     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
                        << ")\n";
   });
+#ifndef NDEBUG
+  Operation *fpOp = op;
+  while (fpOp && fpOp != fingerprintTopLevel) {
+    fingerprints.erase(fpOp);
+    fpOp = fpOp->getParentOp();
+  }
+#endif // NDEBUG
   addToWorklist(op);
 }
 
@@ -369,6 +438,12 @@
   op->walk([this](Operation *operation) {
     removeFromWorklist(operation);
     folder.notifyRemoval(operation);
+#ifndef NDEBUG
+    while (operation && operation != fingerprintTopLevel) {
+      fingerprints.erase(operation);
+      operation = operation->getParentOp();
+    }
+#endif // NDEBUG
   });
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)