diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -563,6 +563,18 @@ /// Each region must be an ancestor of the following regions in this list. /// These are also the keys for "mappings". SmallVector regionStack; + + /// This cache stores operation names for operations that are tracked in the + /// transform dialect state. It is used to detect missing memory side effects + /// and op tracking. + /// + /// All tracked ops are added to this cache before a transform op is applied. + /// After the application of the transform op, the names of all tracked ops + /// are compared with the names in the cache. If there is a mismatch (or a + /// crash), op tracking is missing somewhere. This is typically a missing + /// "consumesHandle" side effect or a pattern that removes an op without + /// notifying a TrackingListener. 
+ DenseMap cachedNames; #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1818,7 +1818,8 @@ transform::TransformState &state) { tensor::PadOp hoistedPadOp; SmallVector transposeOps; - IRRewriter rewriter(target->getContext()); + TrackingListener listener(state); + IRRewriter rewriter(target->getContext(), &listener); FailureOr result = hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -345,6 +345,15 @@ } #endif // NDEBUG +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (options.getExpensiveChecksEnabled()) { + auto it = cachedNames.find(op); + assert(it != cachedNames.end() && "entry not found"); + assert(it->second == op->getName() && "operation name mismatch"); + cachedNames.erase(it); + } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + // TODO: consider invalidating the handles to nested objects here. // If replacing with null, that is erasing the mapping, drop the mapping @@ -357,6 +366,16 @@ return success(); } +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (options.getExpensiveChecksEnabled()) { + auto insertion = cachedNames.insert({replacement, replacement->getName()}); + if (!insertion.second) { + assert(insertion.first->second == replacement->getName() && + "operation is already cached with a different name"); + } + } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + // Otherwise, replace the pointed-to object of all handles while preserving // their relative order. First, replace the mapped operation if present. 
for (Value handle : opHandles) { @@ -722,6 +741,28 @@ FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); } } + +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + // Cache Operation* -> OperationName mappings. These will be checked after + // the transform has been applied to detect incorrect memory side effects + // and missing op tracking. + for (Mappings &mapping : llvm::make_second_range(mappings)) { + for (Operation *op : llvm::make_first_range(mapping.reverse)) { + auto insertion = cachedNames.insert({op, op->getName()}); + if (!insertion.second) { + if (insertion.first->second != op->getName()) { + // Operation is already in the cache, but with a different name. + DiagnosedDefiniteFailure diag = + emitDefiniteFailure(transform->getLoc()) + << "expensive checks failure: operation mismatch, expected " + << insertion.first->second; + diag.attachNote(op->getLoc()) << "payload op: " << op->getName(); + return diag; + } + } + } + } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } // Find which operands are consumed. @@ -748,11 +789,23 @@ // IR after that. SmallVector origOpFlatResults; SmallVector origAssociatedOps; +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + DenseSet consumedPayloadOps; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (unsigned index : consumedOperands) { Value operand = transform->getOperand(index); if (operand.getType().isa()) { - for (Operation *payloadOp : getPayloadOps(operand)) + for (Operation *payloadOp : getPayloadOps(operand)) { llvm::append_range(origOpFlatResults, payloadOp->getResults()); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (options.getExpensiveChecksEnabled()) { + // Store all consumed payload ops (and their nested ops) in a set for + // extra error checking. 
+ payloadOp->walk( + [&](Operation *op) { consumedPayloadOps.insert(op); }); + } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } continue; } if (operand.getType().isa()) { @@ -812,6 +865,61 @@ } } +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (options.getExpensiveChecksEnabled()) { + // Remove erased ops from the transform state. + for (Operation *op : consumedPayloadOps) { + // This payload op was consumed but it may still be mapped to one or + // multiple handles. Forget all handles that are mapped to the op, so that + // there are no dangling pointers in the transform dialect state. This is + // necessary so that the `cachedNames`-based checks work correctly. + // + // Note: Dangling pointers to erased payload ops are allowed if the + // corresponding handles are not used anymore. There is another + // "expensive-check" that looks for future uses of dangling payload op + // pointers (through arbitrary handles). Removing handles to erased ops + // does not interfere with the other expensive checks: handle invalidation + // happens earlier and keeps track of invalidated handles with + // pre-generated error messages, so we do not need the association to + // still be there when the invalidated handle is accessed. + SmallVector handles; + (void)getHandlesForPayloadOp(op, handles); + for (Value handle : handles) + forgetMapping(handle, /*origOpFlatResults=*/ValueRange()); + cachedNames.erase(op); + } + + // Check cached operation names. + for (Mappings &mapping : llvm::make_second_range(mappings)) { + for (Operation *op : llvm::make_first_range(mapping.reverse)) { + // Make sure that the name of the op has not changed. If it has changed, + // the op was removed and a new op was allocated at the same memory + // location. This means that we are missing op tracking somewhere. 
+ auto cacheIt = cachedNames.find(op); + if (cacheIt == cachedNames.end()) { + DiagnosedDefiniteFailure diag = + emitDefiniteFailure(transform->getLoc()) + << "expensive checks failure: operation not found in cache"; + diag.attachNote(op->getLoc()) << "payload op"; + return diag; + } + // If the `getName` call (or the above `attachNote`) is crashing, we + // have a dangling pointer. This usually means that an op was erased but + // the transform dialect was not made aware of that; e.g., missing + // "consumesHandle" or rewriter usage. + if (cacheIt->second != op->getName()) { + DiagnosedDefiniteFailure diag = + emitDefiniteFailure(transform->getLoc()) + << "expensive checks failure: operation mismatch, expected " + << cacheIt->second; + diag.attachNote(op->getLoc()) << "payload op: " << op->getName(); + return diag; + } + } + } + } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + for (OpResult result : transform->getResults()) { assert(result.getDefiningOp() == transform.getOperation() && "payload IR association for a value other than the result of the " diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -349,7 +349,8 @@ if (!failed) { // We will be using the clones, so cancel their scheduled deletion. 
deleteClones.release(); - IRRewriter rewriter(getContext()); + TrackingListener listener(state); + IRRewriter rewriter(getContext(), &listener); for (const auto &kvp : llvm::zip(originals, clones)) { Operation *original = std::get<0>(kvp); Operation *clone = std::get<1>(kvp); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -352,6 +352,12 @@ return DiagnosedSilenceableFailure::success(); } +/* Records this op's effects in `effects`: the handle returned by getTarget() is marked as consumed and the payload IR as modified. NOTE(review): presumably needed because the op erases the payload op behind its operand, so the handle must be invalidated - confirm against the op's apply implementation. */ void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -203,7 +203,8 @@ def TestEmitRemarkAndEraseOperandOp : Op, - MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> { + DeclareOpInterfaceMethods, + FunctionalStyleTransformOpTrait]> { let arguments = (ins PDL_Operation:$target, StrAttr:$remark, UnitAttr:$fail_after_erase); let assemblyFormat = "$target `,` $remark attr-dict";