diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2588,6 +2588,11 @@
         return !necessaryMaterializations.count(matIt->second);
       return rewriterImpl.isOpIgnored(user);
     };
+    // This value may be replacing another value that has a live user.
+    for (Value inv : inverseMapping.lookup(value))
+      if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
+        return true;
+    // Or have live users itself.
     return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
   };
 
diff --git a/mlir/test/Transforms/test-legalize-target-materialization-no-uses.mlir b/mlir/test/Transforms/test-legalize-target-materialization-no-uses.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/test-legalize-target-materialization-no-uses.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -test-target-materialization-with-no-uses %s | FileCheck %s
+
+// The conversion is set up as follows:
+// - type_changer ops are illegal;
+// - type_changer ops are replaced with their operands;
+// - i16 types are converted to i64 by the type conversion;
+// - the rest of the types are legal.
+// The first type_changer is replaced with its operand. For the pattern to
+// apply to the second type_changer, the conversion infra creates a dummy
+// cast operation to cast from the i32 to i64 because the original op takes an
+// (illegal) i16 that became i64. This dummy operation should be replaced by
+// the one produced by the target materialization hook. At the moment when the
+// materialization decision is taken, the i64 replacement of the first type
+// change (the result of the dummy cast) has no uses, but the value it replaces
+// does, so the infra must call the materialization rather than assume the
+// dummy cast to be dead.
+
+// CHECK-LABEL: @foo
+func @foo() {
+  %0 = "test.type_producer"() : () -> i32
+  // CHECK: test.cast
+  // CHECK-NOT: test.type_changer
+  %1 = "test.type_changer"(%0) : (i32) -> i16
+  %2 = "test.type_changer"(%1) : (i16) -> i64
+  "test.type_consumer"(%2) : (i64) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1603,6 +1603,8 @@
   Results<(outs AnyType)>;
 def TestTypeConsumerOp : TEST_Op<"type_consumer">,
   Arguments<(ins AnyType)>;
+def TestTypeChangerOp : TEST_Op<"type_changer">,
+  Arguments<(ins AnyType)>, Results<(outs AnyType)>;
 def TestValidOp : TEST_Op<"valid", [Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1135,6 +1135,61 @@
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Test Target Materialization With No Uses
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Replaces each TestTypeChangerOp with its adaptor operands.
+struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
+  using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
+/// Converts i16 to i64, marks TestTypeChangerOp illegal, and materializes
+/// target-type values with TestCastOp.
+struct TestTargetMaterializationWithNoUses
+    : public PassWrapper<TestTargetMaterializationWithNoUses,
+                         OperationPass<ModuleOp>> {
+  StringRef getArgument() const final {
+    return "test-target-materialization-with-no-uses";
+  }
+  StringRef getDescription() const final {
+    return "Test a special case of target materialization in DialectConversion";
+  }
+
+  void runOnOperation() override {
+    TypeConverter converter;
+    converter.addConversion([](Type t) { return t; });
+    converter.addConversion([](IntegerType intTy) -> Type {
+      if (intTy.getWidth() == 16)
+        return IntegerType::get(intTy.getContext(), 64);
+      return intTy;
+    });
+    converter.addTargetMaterialization(
+        [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
+          return builder.create<TestCastOp>(loc, type, inputs).getResult();
+        });
+
+    ConversionTarget target(getContext());
+    target.addIllegalOp<TestTypeChangerOp>();
+
+    RewritePatternSet patterns(&getContext());
+    patterns.add<ForwardOperandPattern>(converter, &getContext());
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Test Block Merging
 //===----------------------------------------------------------------------===//
@@ -1317,6 +1372,7 @@
   PassRegistration<TestUnknownRootOpDriver>();
 
   PassRegistration<TestTypeConversionDriver>();
+  PassRegistration<TestTargetMaterializationWithNoUses>();
 
   PassRegistration<TestMergeBlocksPatternDriver>();
   PassRegistration<TestSelectiveReplacementPatternDriver>();