diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -663,6 +663,22 @@ ignoredOps.size(), rootUpdates.size()); } +/// Detach any operations nested in the given operation from their parent +/// blocks, and erase the given operation. This can be used when the nested +/// operations are scheduled for erasure themselves, so deleting the regions of +/// the given operation together with their content would result in double-free. +/// This happens, for example, when rolling back op creation in the reverse +/// order and if the nested ops were created before the parent op. This function +/// does not need to collect nested ops recursively because it is expected to +/// also be called for each nested op when it is about to be deleted. +static void detachNestedAndErase(Operation *op) { + for (Region &region : op->getRegions()) + for (Block &block : region.getBlocks()) + while (!block.getOperations().empty()) + block.getOperations().remove(block.getOperations().begin()); + op->erase(); +} + void ConversionPatternRewriterImpl::resetState(RewriterState state) { // Reset any operations that were updated in place. for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) @@ -686,7 +702,7 @@ // Pop all of the newly created operations. while (createdOps.size() != state.numCreatedOps) { - createdOps.back()->erase(); + detachNestedAndErase(createdOps.back()); createdOps.pop_back(); } @@ -746,7 +762,7 @@ // Remove any newly created ops. 
for (auto *op : llvm::reverse(createdOps)) - op->erase(); + detachNestedAndErase(op); } void ConversionPatternRewriterImpl::applyRewrites() { diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -233,3 +233,18 @@ // expected-remark@+1 {{op 'std.return' is not legalizable}} return } + +// ----- + +// The op in this function is attempted to be rewritten to another illegal op +// with an attached region containing an invalid terminator. The terminator is +// created before the parent op. The deletion should not crash when deleting +// created ops in the inverse order, i.e. deleting the parent op and then the +// child op. +// CHECK-LABEL: @undo_child_created_before_parent +func @undo_child_created_before_parent() { + // expected-remark@+1 {{is not legalizable}} + "test.illegal_op_with_region_anchor"() : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1080,6 +1080,22 @@ Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; +// Check that the conversion infrastructure can properly undo the creation of +// operations where an operation was created before its parent, in this case, +// in the parent's builder. 
+def IllegalOpTerminator : TEST_Op<"illegal_op_terminator", [Terminator]>; +def IllegalOpWithRegion : TEST_Op<"illegal_op_with_region"> { + let skipDefaultBuilders = 1; + let builders = [OpBuilder<"OpBuilder &builder, OperationState &state", + [{ Region *bodyRegion = state.addRegion(); + OpBuilder::InsertionGuard g(builder); + Block *body = builder.createBlock(bodyRegion); + builder.setInsertionPointToEnd(body); + builder.create<IllegalOpTerminator>(state.location); + }]>]; +} +def IllegalOpWithRegionAnchor : TEST_Op<"illegal_op_with_region_anchor">; + // Check that smaller pattern depths are chosen, i.e. prioritize more direct // mappings. def : Pat<(ILLegalOpA), (LegalOpA Test_LegalizerEnum_Success)>; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -443,6 +443,18 @@ /// The conversion target handles bounding the recursion of this pattern. bool hasBoundedRewriteRecursion() const final { return true; } }; + +struct TestNestedOpCreationUndoRewrite + : public OpRewritePattern<IllegalOpWithRegionAnchor> { + using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; + + LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, + PatternRewriter &rewriter) const final { + // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); + rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); + return success(); + }; +}; } // namespace namespace { @@ -498,8 +510,8 @@ TestSplitReturnType, TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, - TestNonRootReplacement, TestBoundedRecursiveRewrite>( - &getContext()); + TestNonRootReplacement, TestBoundedRecursiveRewrite, + TestNestedOpCreationUndoRewrite>(&getContext()); patterns.insert(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);