diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -682,14 +682,6 @@
 
     // Process the remapping for each of the original arguments.
     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
-      // FIXME: We should run the below checks even if a type converter wasn't
-      // provided, but a lot of existing lowering rely on the block argument
-      // being blindly replaced. We should rework argument materialization to be
-      // more robust for temporary source materializations, update existing
-      // patterns, and remove these checks.
-      if (!blockInfo.converter && blockInfo.argInfo[i])
-        continue;
-
       // If the type of this argument changed and the argument is still live, we
       // need to materialize a conversion.
       BlockArgument origArg = origBlock->getArgument(i);
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -98,3 +98,17 @@
   }) : () -> ()
   return
 }
+
+// -----
+
+// Make sure argument type changes aren't implicitly forwarded.
+func @test_signature_conversion_no_converter() {
+  "test.signature_conversion_no_converter"() ({
+  // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+  ^bb0(%arg0: f32):
+    // expected-note@below {{see existing live user here}}
+    "test.type_consumer"(%arg0) : (f32) -> ()
+    "test.return"(%arg0) : (f32) -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1520,6 +1520,11 @@
   let regions = (region AnyRegion);
 }
 
+def TestSignatureConversionNoConverterOp
+    : TEST_Op<"signature_conversion_no_converter"> {
+  let regions = (region AnyRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // Test parser.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -950,6 +950,34 @@
   }
 };
 
+/// Call signature conversion without providing a type converter to handle
+/// materializations.
+struct TestTestSignatureConversionNoConverter
+    : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
+  TestTestSignatureConversionNoConverter(TypeConverter &converter,
+                                         MLIRContext *context)
+      : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
+        converter(converter) {}
+
+  LogicalResult
+  matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    Region &region = op->getRegion(0);
+    Block *entry = &region.front();
+
+    // Convert the original entry arguments.
+    TypeConverter::SignatureConversion result(entry->getNumArguments());
+    if (failed(
+            converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
+      return failure();
+    rewriter.updateRootInPlace(
+        op, [&] { rewriter.applySignatureConversion(&region, result); });
+    return success();
+  }
+
+  TypeConverter &converter;
+};
+
 /// Just forward the operands to the root op. This is essentially a no-op
 /// pattern that is used to trigger target materialization.
 struct TestTypeConsumerForward
@@ -1041,11 +1069,17 @@
       // Allow casts from F64 to F32.
       return (*op.operand_type_begin()).isF64() && op.getType().isF32();
     });
+    target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
+        [&](TestSignatureConversionNoConverterOp op) {
+          return converter.isLegal(op.getRegion().front().getArgumentTypes());
+        });
 
     // Initialize the set of rewrite patterns.
     RewritePatternSet patterns(&getContext());
     patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
-                 TestSignatureConversionUndo>(converter, &getContext());
+                 TestSignatureConversionUndo,
+                 TestTestSignatureConversionNoConverter>(converter,
+                                                         &getContext());
     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
 