diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -221,6 +221,11 @@
 
     // Check if the result was an SSA value.
     if (auto repl = foldResults[i].dyn_cast<Value>()) {
+      // A folder may hand back an existing value whose type differs from the
+      // result it would replace (e.g. a naive passthrough fold); substituting
+      // it would produce ill-typed IR, so refuse to fold instead.
+      if (repl.getType() != op->getResult(i).getType())
+        return failure();
       results.emplace_back(repl);
       continue;
     }
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func @remove_op_with_inner_ops_pattern() {
   // CHECK-NEXT: return
   "test.op_with_region_pattern"() ({
-    "foo.op_with_region_terminator"() : () -> ()
+    "test.op_with_region_terminator"() : () -> ()
   }) : () -> ()
   return
 }
@@ -13,7 +13,7 @@
 func @remove_op_with_inner_ops_fold_no_side_effect() {
   // CHECK-NEXT: return
   "test.op_with_region_fold_no_side_effect"() ({
-    "foo.op_with_region_terminator"() : () -> ()
+    "test.op_with_region_terminator"() : () -> ()
   }) : () -> ()
   return
 }
@@ -23,7 +23,7 @@
 func @remove_op_with_inner_ops_fold(%arg0 : i32) -> (i32) {
   // CHECK-NEXT: return %[[ARG_0]]
   %0 = "test.op_with_region_fold"(%arg0) ({
-    "foo.op_with_region_terminator"() : () -> ()
+    "test.op_with_region_terminator"() : () -> ()
   }) : (i32) -> (i32)
   return %0 : i32
 }
@@ -51,3 +51,15 @@
   // CHECK-NEXT: return %[[O0]], %[[O1]]
   return %y, %z: i32, i32
 }
+
+// CHECK-LABEL: func @typemismatch
+func @typemismatch() -> i32 {
+  %c42 = constant 42.0 : f32
+
+  // The "passthrough_fold" folder will naively return its operand, but we don't
+  // want to fold here because of the type mismatch.
+
+  // CHECK: "test.passthrough_fold"
+  %0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
+  return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -628,6 +628,12 @@
   return {};
 }
 
+// Deliberately naive: returns the operand even when its type differs from the
+// result type, so the generic folding logic must reject the mismatch itself.
+OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
+  return getOperand();
+}
+
 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
     MLIRContext *, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -900,6 +900,13 @@
   let hasFolder = 1;
 }
 
+// An op whose folder always returns its operand, regardless of type.
+def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
+  let arguments = (ins AnyType:$op);
+  let results = (outs AnyType);
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)