diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -76,6 +76,14 @@ /// Clears all mappings held by the mapper. void clear() { valueMap.clear(); } + /// Returns a new mapper containing the inverse mapping. + BlockAndValueMapping getInverse() const { + BlockAndValueMapping result; + for (const auto &pair : valueMap) + result.valueMap.try_emplace(pair.second, pair.first); + return result; + } + private: /// Utility lookupOrValue that looks up an existing key or returns the /// provided value. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -122,6 +122,9 @@ /// Drop the last mapping for the given value. void erase(Value value) { mapping.erase(value); } + /// Returns the inverse raw value mapping (without recursive query support). + BlockAndValueMapping getInverse() const { return mapping.getInverse(); } + private: /// Current value mappings. BlockAndValueMapping mapping; @@ -2131,7 +2134,8 @@ legalizeChangedResultType(Operation *op, OpResult result, Value newValue, TypeConverter *replConverter, ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl); + ConversionPatternRewriterImpl &rewriterImpl, + const BlockAndValueMapping &inverseMapping); /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -2221,6 +2225,11 @@ if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) return failure(); + if (rewriterImpl.operationsWithChangedResults.empty()) + return success(); + + Optional<BlockAndValueMapping> inverseMapping; + // Process requested operation replacements. 
for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); i != e; ++i) { @@ -2241,11 +2250,15 @@ if (result.getType() == newValue.getType()) continue; + // Compute the inverse mapping only if it is really needed. + if (!inverseMapping) + inverseMapping = rewriterImpl.mapping.getInverse(); + // Legalize this result. rewriter.setInsertionPoint(repl.first); if (failed(legalizeChangedResultType(repl.first, result, newValue, repl.second.converter, rewriter, - rewriterImpl))) + rewriterImpl, *inverseMapping))) return failure(); // Update the end iterator for this loop in the case it was updated @@ -2305,16 +2318,32 @@ return success(); } +/// Finds a user of the given value, or of any other value that the given value +/// replaced, that was not replaced in the conversion process. +static Operation * +findLiveUserOfReplaced(Value value, ConversionPatternRewriterImpl &rewriterImpl, + const BlockAndValueMapping &inverseMapping) { + do { + // Walk the users of this value to see if there are any live users that + // weren't replaced during conversion. + auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + if (liveUserIt != value.user_end()) + return *liveUserIt; + value = inverseMapping.lookupOrNull(value); + } while (value != nullptr); + return nullptr; +} + LogicalResult OperationConverter::legalizeChangedResultType( Operation *op, OpResult result, Value newValue, TypeConverter *replConverter, ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl) { - // Walk the users of this value to see if there are any live users that - // weren't replaced during conversion. 
- auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - if (liveUserIt == result.user_end()) + ConversionPatternRewriterImpl &rewriterImpl, + const BlockAndValueMapping &inverseMapping) { + Operation *liveUser = + findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); + if (!liveUser) return success(); // If the replacement has a type converter, attempt to materialize a @@ -2340,8 +2369,8 @@ << result.getResultNumber() << " of operation '" << op->getName() << "' that remained live after conversion"; - diag.attachNote(liveUserIt->getLoc()) - << "see existing live user here: " << *liveUserIt; + diag.attachNote(liveUser->getLoc()) + << "see existing live user here: " << *liveUser; return failure(); } diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -45,6 +45,26 @@ // ----- +// CHECK-LABEL: @test_transitive_use_materialization +func @test_transitive_use_materialization() { + // CHECK: %[[V:.*]] = "test.type_producer"() : () -> f64 + // CHECK: %[[C:.*]] = "test.cast"(%[[V]]) : (f64) -> f32 + %result = "test.another_type_producer"() : () -> f32 + // CHECK: "foo.return"(%[[C]]) + "foo.return"(%result) : (f32) -> () +} + +// ----- + +func @test_transitive_use_invalid_materialization() { + // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + %result = "test.another_type_producer"() : () -> f16 + // expected-note@below {{see existing live user here}} + "foo.return"(%result) : (f16) -> () +} + +// ----- + func @test_invalid_result_legalization() { // expected-error@below {{failed to legalize conversion operation generated for result #0 of operation 'test.type_producer' that remained live after 
conversion}} %result = "test.type_producer"() : () -> i16 diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1281,6 +1281,8 @@ Arguments<(ins Variadic)>; def TestTypeProducerOp : TEST_Op<"type_producer">, Results<(outs AnyType)>; +def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">, + Results<(outs AnyType)>; def TestTypeConsumerOp : TEST_Op<"type_consumer">, Arguments<(ins AnyType)>; def TestValidOp : TEST_Op<"valid", [Terminator]>, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -801,6 +801,17 @@ } }; +struct TestTypeConversionAnotherProducer + : public OpRewritePattern<TestAnotherTypeProducerOp> { + using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, + PatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); + return success(); + } +}; + struct TestTypeConversionDriver : public PassWrapper> { void getDependentDialects(DialectRegistry &registry) const override { @@ -865,6 +876,7 @@ OwningRewritePatternList patterns; patterns.insert(converter, &getContext()); + patterns.insert<TestTypeConversionAnotherProducer>(&getContext()); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);