diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -564,6 +564,20 @@
   /// These are also the keys for "mappings".
   SmallVector<Region *> regionStack;
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
+#ifndef NDEBUG
+  /// This cache stores operation names for operations that are tracked in the
+  /// transform dialect state. It is used to detect missing memory side effects
+  /// and op tracking.
+  ///
+  /// All tracked ops are added to this cache before a transform op is applied.
+  /// After the application of the transform op, the names of all tracked ops
+  /// are compared with the names in the cache. If there is a mismatch (or a
+  /// crash), op tracking is missing somewhere. This is typically a missing
+  /// "consumesHandle" side effect or a pattern that removes an op without
+  /// notifying a TrackingListener.
+  DenseMap<Operation *, OperationName> cachedNames;
+#endif // NDEBUG
 };
 
 /// Local mapping between values defined by a specific op implementing the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1764,7 +1764,8 @@
                                        transform::TransformState &state) {
   tensor::PadOp hoistedPadOp;
   SmallVector<GenericOp> transposeOps;
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state);
+  IRRewriter rewriter(target->getContext(), &listener);
   FailureOr<Value> result =
       hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
                             hoistedPadOp, transposeOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -77,7 +77,8 @@
 /// 8. There is no enclosing scf::ForOp that indexes the padded data.
 /// Other cases succeed and will trigger hoisting of the pad op.
 struct HoistingAnalysis {
-  HoistingAnalysis(tensor::PadOp padOp, int numLoops);
+  HoistingAnalysis(tensor::PadOp padOp, int numLoops,
+                   OpBuilder::Listener *listener);
 
   bool isValid() { return valid; }
 
@@ -179,7 +180,8 @@
                          /*inclusive=*/true);
 }
 
-HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
+HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops,
+                                   OpBuilder::Listener *listener) {
   valid = false;
 
   // Get at most `numLoops` of immediately enclosing loops.
@@ -215,7 +217,7 @@
   // If the padded data is not yet available before entering the outermost
   // enclosing loop, try to apply hoisting on this outermost loop.
   // TODO: we may want finer-grained hoisting of only that particular `sliceOp`.
-  IRRewriter rewriter(outermostEnclosingForOp->getContext());
+  IRRewriter rewriter(outermostEnclosingForOp->getContext(), listener);
   if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
     outermostEnclosingForOp =
         hoistRedundantSubsetExtractInsert(rewriter, outermostEnclosingForOp);
@@ -657,7 +659,7 @@
                                      SmallVectorImpl<GenericOp> &transposeOps) {
   LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
              DBGS() << " by " << numLoops << " loops\n");
-  HoistingAnalysis analysis(opToHoist, numLoops);
+  HoistingAnalysis analysis(opToHoist, numLoops, rewriter.getListener());
   if (!analysis.isValid()) {
     LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
     return failure();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -343,6 +343,13 @@
     (void)getHandlesForPayloadValue(opResult, valueHandles);
     assert(valueHandles.empty() && "expected no mapping to old results");
   }
+
+  if (options.getExpensiveChecksEnabled()) {
+    auto it = cachedNames.find(op);
+    assert(it != cachedNames.end() && "entry not found");
+    assert(it->second == op->getName() && "operation name mismatch");
+    cachedNames.erase(it);
+  }
 #endif // NDEBUG
 
   // TODO: consider invalidating the handles to nested objects here.
@@ -357,6 +364,16 @@
     return success();
   }
 
+#ifndef NDEBUG
+  if (options.getExpensiveChecksEnabled()) {
+    auto insertion = cachedNames.insert({replacement, replacement->getName()});
+    if (!insertion.second) {
+      assert(insertion.first->second == replacement->getName() &&
+             "operation is already cached with a different name");
+    }
+  }
+#endif // NDEBUG
+
   // Otherwise, replace the pointed-to object of all handles while preserving
   // their relative order. First, replace the mapped operation if present.
   for (Value handle : opHandles) {
@@ -676,6 +693,7 @@
     LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
         llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
   });
+
   if (options.getExpensiveChecksEnabled()) {
     FULL_LDBG("ExpensiveChecksEnabled\n");
     if (failed(checkAndRecordHandleInvalidation(transform)))
@@ -716,6 +734,23 @@
       FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
     }
   }
+
+#ifndef NDEBUG
+  // Cache Operation* -> OperationName mappings. These will be checked after
+  // the transform has been applied to detect incorrect memory side effects
+  // and missing op tracking.
+  for (auto &it : mappings) {
+    Mappings &mapping = it.second;
+    for (auto &mapped : mapping.reverse) {
+      Operation *op = mapped.first;
+      auto insertion = cachedNames.insert({op, op->getName()});
+      if (!insertion.second) {
+        assert(insertion.first->second == op->getName() &&
+               "operation is already cached with a different name");
+      }
+    }
+  }
+#endif // NDEBUG
 }
 
 // Find which operands are consumed.
@@ -742,11 +777,19 @@
   // IR after that.
   SmallVector<Value> origOpFlatResults;
   SmallVector<Operation *> origAssociatedOps;
+  DenseSet<Operation *> consumedPayloadOps;
   for (unsigned index : consumedOperands) {
     Value operand = transform->getOperand(index);
     if (operand.getType().isa<TransformHandleTypeInterface>()) {
-      for (Operation *payloadOp : getPayloadOps(operand))
+      for (Operation *payloadOp : getPayloadOps(operand)) {
         llvm::append_range(origOpFlatResults, payloadOp->getResults());
+        if (options.getExpensiveChecksEnabled()) {
+          // Store all consumed payload ops (and their nested ops) in a set
+          // for extra error checking.
+          payloadOp->walk(
+              [&](Operation *op) { consumedPayloadOps.insert(op); });
+        }
+      }
       continue;
     }
     if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
@@ -806,6 +849,44 @@
     }
   }
 
+#ifndef NDEBUG
+  if (options.getExpensiveChecksEnabled()) {
+    // Check cached operation names.
+    for (auto &it : mappings) {
+      Mappings &mapping = it.second;
+      SmallVector<Operation *> consumedOps;
+      for (auto &mapped : mapping.reverse) {
+        Operation *op = mapped.first;
+        if (consumedPayloadOps.contains(op)) {
+          consumedOps.push_back(op);
+          continue;
+        }
+        // Make sure that the name of the op has not changed. If it has
+        // changed, the op was removed and a new op was allocated at the same
+        // memory location. This means that we are missing op tracking
+        // somewhere.
+        auto cacheIt = cachedNames.find(op);
+        assert(cacheIt != cachedNames.end() && "operation not in cache");
+        assert(cacheIt->second == op->getName() && "operation name mismatch");
+      }
+      // These payload ops were consumed but are still mapped to one or more
+      // handles. Erase them from all mappings only after the iteration over
+      // `mapping.reverse` has finished, so that no iterators are invalidated
+      // and there are no dangling pointers in the transform dialect state.
+      for (Operation *op : consumedOps) {
+        for (Value handle : mapping.reverse[op]) {
+          auto posIt = llvm::find(mapping.direct[handle], op);
+          assert(posIt != mapping.direct[handle].end() &&
+                 "inconsistent mapping state");
+          mapping.direct[handle].erase(posIt);
+        }
+        mapping.reverse.erase(op);
+        cachedNames.erase(op);
+      }
+    }
+  }
+#endif // NDEBUG
+
   for (OpResult result : transform->getResults()) {
     assert(result.getDefiningOp() == transform.getOperation() &&
            "payload IR association for a value other than the result of the "
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -339,7 +339,8 @@
   if (!failed) {
     // We will be using the clones, so cancel their scheduled deletion.
     deleteClones.release();
-    IRRewriter rewriter(getContext());
+    TrackingListener listener(state);
+    IRRewriter rewriter(getContext(), &listener);
     for (const auto &kvp : llvm::zip(originals, clones)) {
       Operation *original = std::get<0>(kvp);
       Operation *clone = std::get<1>(kvp);
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -352,6 +352,12 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
 DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
     Operation *target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -203,7 +203,8 @@
 def TestEmitRemarkAndEraseOperandOp
   : Op<Transform_Dialect, "test.emit_remark_and_erase_operand",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
-     MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> {
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     FunctionalStyleTransformOpTrait]> {
   let arguments = (ins PDL_Operation:$target, StrAttr:$remark,
                    UnitAttr:$fail_after_erase);
   let assemblyFormat = "$target `,` $remark attr-dict";