diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -182,7 +182,20 @@
 // allocated via malloc.
 Operation::~Operation() {
   assert(block == nullptr && "operation destroyed but still in a block");
-
+#ifndef NDEBUG
+  // Destroying an op whose results still have uses would leave those users
+  // referring to freed values; report every user and abort in debug builds.
+  if (!use_empty()) {
+    // Scope the diagnostic so it is emitted before the fatal error aborts.
+    {
+      InFlightDiagnostic diag =
+          emitOpError("operation destroyed but still has uses");
+      for (Operation *user : getUsers())
+        diag.attachNote(user->getLoc()) << "- use: " << *user << "\n";
+    }
+    llvm::report_fatal_error("operation destroyed but still has uses");
+  }
+#endif
   // Explicitly run the destructors for the operands.
   if (hasOperandStorage)
     getOperandStorage().~OperandStorage();
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -916,9 +916,13 @@
 
   // In a second pass, erase all of the replaced operations in reverse. This
   // allows processing nested operations before their parent region is
-  // destroyed.
-  for (auto &repl : llvm::reverse(replacements))
+  // destroyed. Because we process in reverse order, a producer may be erased
+  // before its users (e.g. a pattern that erases a consumer and then its
+  // producer), so first drop all remaining uses explicitly.
+  for (auto &repl : llvm::reverse(replacements)) {
+    repl.first->dropAllUses();
     repl.first->erase();
+  }
 
   argConverter.applyRewrites(mapping);
 
@@ -2230,13 +2234,19 @@
   // legalized.
   if (failed(finalize(rewriter)))
     return rewriterImpl.discardRewrites(), failure();
-
   // After a successful conversion, apply rewrites if this is not an analysis
   // conversion.
if (mode == OpConversionMode::Analysis) rewriterImpl.discardRewrites(); else rewriterImpl.applyRewrites(); + + // It is possible for a later pattern to erase an op that was originally + // identified as illegal and added to the trackedOps, remove it now after + // replacements have been computed. + for (auto &repl : rewriterImpl.replacements) + trackedOps->erase(repl.first); + return success(); } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -270,3 +270,17 @@ // expected-remark@+1 {{op 'std.return' is not legalizable}} return } + + +// ----- + + +// Check that a conversion pattern on `test.blackhole` can mark the producer +// for deletion. +// CHECK-LABEL: @blackhole +func @blackhole() { + %input = "test.blackhole_producer"() : () -> (i32) + "test.blackhole"(%input) : (i32) -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1278,6 +1278,11 @@ let assemblyFormat = "$depth attr-dict"; } +// Test legalization pattern: this op will be erase and will also erase the +// producer of its operand. +def BlackHoleOp : TEST_Op<"blackhole">, + Arguments<(ins AnyType)>; + //===----------------------------------------------------------------------===// // Test Type Legalization //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -489,6 +489,19 @@ return success(); }; }; + +// This pattern matches `test.blackhole` and delete this op and its producer. 
+struct TestReplaceEraseOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BlackHoleOp op, + PatternRewriter &rewriter) const final { + Operation *producer = op.getOperand().getDefiningOp(); + rewriter.eraseOp(op); + rewriter.eraseOp(producer); + return success(); + }; +}; } // namespace namespace { @@ -566,7 +579,8 @@ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, - TestNestedOpCreationUndoRewrite>(&getContext()); + TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( + &getContext()); patterns.add(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); mlir::populateCallOpTypeConversionPattern(patterns, converter);