diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -991,7 +991,8 @@
       Type oldType = oldResult.getType();
+      // Cast the new result back to the old type so it can replace oldResult.
       replacements.push_back(
           (newType != oldType)
-              ? rewriter.create<tensor::CastOp>(loc, newType, newResult)
+              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
               : newResult);
     }
     rewriter.replaceOp(genericOp, replacements);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -780,3 +780,29 @@
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
 // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
 }
+
+// -----
+
+// The generic op's result is refined to tensor<1x?x?xf32>; the replacement
+// must cast it back to the original tensor<?x?x?xf32> result type.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_dest
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<1x?x?xf32>,
+func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor<?x?x?xf32> {
+  %0 = linalg.init_tensor [%arg2, %arg3, %arg4] : tensor<?x?x?xf32>
+  %1 = tensor.cast %arg1 : tensor<1x?x?xf32> to tensor<?x?x?xf32>
+  %2 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<1x?x?xf32>)
+    outs(%0 : tensor<?x?x?xf32>) {
+  ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+    %3 = arith.subf %arg5, %arg6 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x?x?xf32>
+  return %2 : tensor<?x?x?xf32>
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%{{.*}}, %[[ARG1]] : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<1x?x?xf32>)
+// CHECK: tensor.cast %[[GENERIC_OP]] : tensor<1x?x?xf32> to tensor<?x?x?xf32>
+}