diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -364,11 +364,6 @@ // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue)); - - // If all users of the cast were removed, we can drop it. Otherwise, keep - // the operation alive and let the user handle any remaining usages. - if (castValue.use_empty() && castValue.getDefiningOp()) - castValue.getDefiningOp()->erase(); } } } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -36,8 +36,9 @@ // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16) func @remap_input_1_to_N(%arg0: f32) -> f32 { - // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> () - "test.return"(%arg0) : (f32) -> () + // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32 + // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> () + "test.return"(%arg0) : (f32) -> () } // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)