diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2086,7 +2086,9 @@
       auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
       newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
-      newResultTypes.push_back(newOperands.back().getType());
+      // Only tensor outputs produce op results; skip memref (buffer) outputs.
+      if (!newOperands.back().getType().isa<MemRefType>())
+        newResultTypes.push_back(newOperands.back().getType());
     }
     // Clone op.
     Operation *newOp =
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -870,3 +870,27 @@
   } -> tensor<4xf32>
   return
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
+  %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
+  linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%0 : tensor<?xf32>)
+    outs(%arg1 : memref<?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @cast_producer_mixed
+// CHECK-SAME: (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
+// CHECK: linalg.generic {
+// CHECK-SAME: indexing_maps = [#map, #map],
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {