[mlir] Track replaced block arguments during block signature conversion

ArgConverter::convertSignature and ArgConverter::applySignatureConversion
gain a `SmallVectorImpl<BlockArgument> &argReplacements` out-parameter.
Every original block argument that applySignatureConversion remaps -- either
to a caller-provided replacement value or to the new (possibly materialized)
argument -- is appended to it, and the rewriter-side caller passes its
`argReplacements` through. The intent (inferred from the new test; confirm
against the rewriter's undo logic) is that these argument replacements can be
rolled back when the pattern that triggered the conversion fails.

Test: `test.signature_conversion_undo` is matched by
TestSignatureConversionUndo, which calls convertRegionTypes on the op's region
and then returns failure() to exercise the undo path; per the test comment the
conversion should fail gracefully instead of segfaulting.
TestTypeConsumerForward re-sets the operands of `test.type_consumer` in place
so that target materialization is triggered.

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -256,8 +256,10 @@
   /// Attempt to convert the signature of the given block, if successful a new
   /// block is returned containing the new arguments. Returns `block` if it did
   /// not require conversion.
-  FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
-                                      ConversionValueMapping &mapping);
+  FailureOr<Block *>
+  convertSignature(Block *block, TypeConverter &converter,
+                   ConversionValueMapping &mapping,
+                   SmallVectorImpl<BlockArgument> &argReplacements);
 
   /// Apply the given signature conversion on the given block. The new block
   /// containing the updated signature is returned. If no conversions were
@@ -268,7 +270,8 @@
   Block *applySignatureConversion(
       Block *block, TypeConverter &converter,
       TypeConverter::SignatureConversion &signatureConversion,
-      ConversionValueMapping &mapping);
+      ConversionValueMapping &mapping,
+      SmallVectorImpl<BlockArgument> &argReplacements);
 
   /// Insert a new conversion into the cache.
   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
@@ -425,9 +428,9 @@
 //===----------------------------------------------------------------------===//
 // Conversion
 
-FailureOr<Block *>
-ArgConverter::convertSignature(Block *block, TypeConverter &converter,
-                               ConversionValueMapping &mapping) {
+FailureOr<Block *> ArgConverter::convertSignature(
+    Block *block, TypeConverter &converter, ConversionValueMapping &mapping,
+    SmallVectorImpl<BlockArgument> &argReplacements) {
   // Check if the block was already converted. If the block is detached,
   // conservatively assume it is going to be deleted.
   if (hasBeenConverted(block) || !block->getParent())
@@ -435,14 +438,16 @@
 
   // Try to convert the signature for the block with the provided converter.
   if (auto conversion = converter.convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion, mapping);
+    return applySignatureConversion(block, converter, *conversion, mapping,
+                                    argReplacements);
   return failure();
 }
 
 Block *ArgConverter::applySignatureConversion(
     Block *block, TypeConverter &converter,
     TypeConverter::SignatureConversion &signatureConversion,
-    ConversionValueMapping &mapping) {
+    ConversionValueMapping &mapping,
+    SmallVectorImpl<BlockArgument> &argReplacements) {
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -477,6 +482,7 @@
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
       mapping.map(origArg, inputMap->replacementValue);
+      argReplacements.push_back(origArg);
       continue;
     }
 
@@ -492,6 +498,7 @@
       newArg = replArgs.front();
     }
     mapping.map(origArg, newArg);
+    argReplacements.push_back(origArg);
     info.argInfo[i] =
         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
@@ -1113,9 +1120,10 @@
     Block *block, TypeConverter &converter,
     TypeConverter::SignatureConversion *conversion) {
   FailureOr<Block *> result =
-      conversion ? argConverter.applySignatureConversion(block, converter,
-                                                         *conversion, mapping)
-                 : argConverter.convertSignature(block, converter, mapping);
+      conversion ? argConverter.applySignatureConversion(
+                       block, converter, *conversion, mapping, argReplacements)
+                 : argConverter.convertSignature(block, converter, mapping,
+                                                 argReplacements);
   if (Block *newBlock = result.getValue()) {
     if (newBlock != block)
       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -62,3 +62,18 @@
   %result = "test.type_producer"() : () -> f32
   "foo.return"(%result) : (f32) -> ()
 }
+
+// -----
+
+// Should not segfault here but gracefully fail.
+// CHECK-LABEL: func @test_signature_conversion_undo
+func @test_signature_conversion_undo() {
+  // CHECK: test.signature_conversion_undo
+  "test.signature_conversion_undo"() ({
+    // CHECK: ^{{.*}}(%{{.*}}: f32):
+    ^bb0(%arg0: f32):
+    "test.type_consumer"(%arg0) : (f32) -> ()
+    "test.return"(%arg0) : (f32) -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1289,6 +1289,10 @@
   let results = (outs Variadic<AnyType>:$result);
 }
 
+def TestSignatureConversionUndoOp : TEST_Op<"signature_conversion_undo"> {
+  let regions = (region AnyRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // Test parser.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -774,6 +774,34 @@
   }
 };
 
+/// Call signature conversion and then fail the rewrite to trigger the undo
+/// mechanism.
+struct TestSignatureConversionUndo
+    : public OpConversionPattern<TestSignatureConversionUndoOp> {
+  using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
+    return failure();
+  }
+};
+
+/// Just forward the operands to the root op. This is essentially a no-op
+/// pattern that is used to trigger target materialization.
+struct TestTypeConsumerForward
+    : public OpConversionPattern<TestTypeConsumerOp> {
+  using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); });
+    return success();
+  }
+};
+
 struct TestTypeConversionDriver
     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -836,7 +864,8 @@
 
     // Initialize the set of rewrite patterns.
     OwningRewritePatternList patterns;
-    patterns.insert<TestTypeProducer>(converter, &getContext());
+    patterns.insert<TestTypeProducer, TestSignatureConversionUndo,
+                    TestTypeConsumerForward>(converter, &getContext());
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
 