diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1056,6 +1056,11 @@ // Drop all of the unresolved materialization operations created during // conversion. for (auto &mat : unresolvedMaterializations) { + // Skip materializations that already have a replacement recorded (e.g. + // when they were folded away during conversion); erasing them here as + // well would erase the same op twice (llvm-project issue #64665). + if (replacements.count(mat.getOp())) + continue; mat.getOp()->dropAllUses(); mat.getOp()->erase(); } @@ -3395,7 +3400,8 @@ // Full Conversion LogicalResult -mlir::applyFullConversion(ArrayRef ops, const ConversionTarget &target, +mlir::applyFullConversion(ArrayRef ops, + const ConversionTarget &target, const FrozenRewritePatternSet &patterns) { OperationConverter opConverter(target, patterns, OpConversionMode::Full); return opConverter.convertOperations(ops); diff --git a/mlir/test/IR/test-dialect-conversion-folded-materialization.mlir b/mlir/test/IR/test-dialect-conversion-folded-materialization.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/test-dialect-conversion-folded-materialization.mlir @@ -0,0 +1,23 @@ +// Regression test for https://github.com/llvm/llvm-project/issues/64665, in +// which the dialect conversion framework tried to erase the same inserted +// unrealized_conversion_cast op twice.
+ +// RUN: mlir-opt %s -test-drop-dozing > %t +// RUN: FileCheck %s < %t + +// CHECK-LABEL: example_fn +// CHECK-NOT: fall_asleep +func.func @example_fn( + %s1 : !test.dozing, + %s2 : !test.dozing) -> !test.dozing { + func.return %s1 : !test.dozing +} + +func.func @test_convert_call() { + %0 = arith.constant 7 : i32 + %1 = arith.constant 8 : i32 + %2 = test.fall_asleep %0 : i32 -> !test.dozing + %3 = test.fall_asleep %1 : i32 -> !test.dozing + %4 = func.call @example_fn(%2, %3) : (!test.dozing, !test.dozing) -> !test.dozing + func.return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2884,4 +2884,19 @@ let assemblyFormat = "attr-dict"; } +//===----------------------------------------------------------------------===// +// Test ops for the dialect conversion materialization double-erase regression +// https://github.com/llvm/llvm-project/issues/64665 +//===----------------------------------------------------------------------===// + +def TestDozing_FallAsleepOp : TEST_Op<"fall_asleep", [Pure]> { + let summary = "Rock the input value gently to sleep."; + let arguments = (ins AnyType:$input); + let results = (outs TestDozing:$output); + let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)"; + let builders = [ // NOTE(review): no C++ body in this patch; unused? + OpBuilder<(ins "::mlir::Value":$input)> + ]; +} + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1696,6 +1696,76 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Test Dialect Conversion Materialization Folding +//===----------------------------------------------------------------------===// + +namespace { + +class DropDozingTypeConverter : public TypeConverter { +public: + 
DropDozingTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion( + [](TestDozingType type) -> Type { return type.getValueType(); }); + } +}; + +struct ConvertFallAsleep : public OpConversionPattern { + ConvertFallAsleep(MLIRContext *context) // NOTE(review): appears unused + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(FallAsleepOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceAllUsesWith(op.getResult(), adaptor.getInput()); + rewriter.eraseOp(op); + return success(); + } +}; + +struct TestDialectConversionMaterializationFolding + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestDialectConversionMaterializationFolding) + + StringRef getArgument() const final { return "test-drop-dozing"; } + StringRef getDescription() const final { + return "Test dropping dozing types"; + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + Operation *op = getOperation(); + RewritePatternSet patterns(context); + DropDozingTypeConverter typeConverter; + ConversionTarget target(*context); + + target.addIllegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + + patterns.add(typeConverter, context); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; +} // namespace + 
//===----------------------------------------------------------------------===// // PassRegistration //===----------------------------------------------------------------------===// @@ -1725,6 +1795,7 @@ PassRegistration(); PassRegistration(); + PassRegistration(); } } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -373,4 +373,12 @@ let mnemonic = "i32"; } +def TestDozing : Test_Type<"TestDozing"> { // Wraps a single value type. + let mnemonic = "dozing"; + let summary = "An MLIR type, but very tired."; + let parameters = (ins "Type":$valueType); + let assemblyFormat = "`<` $valueType `>`"; +} + + #endif // TEST_TYPEDEFS