diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -768,6 +768,10 @@
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
           genericOp.getLoc(), expandedOutputType, opOperand->get(),
           reassociation));
+    } else {
+      // Outputs that do not need expansion are forwarded unchanged; dropping
+      // them would desynchronize the init operands from the op's results.
+      outputs.push_back(opOperand->get());
     }
   }
 
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -539,3 +539,37 @@
 // CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] :
 // CHECK-SAME: outs(%[[RESHAPE2]], %[[RESHAPE3]] :
 // CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @multi_result_op_expansion(%arg0: tensor<512xf32>, %arg1: tensor<512xf32>,
+      %arg2: tensor<512xf32>, %arg3: tensor<200x512xf32>) -> tensor<25x8x1x512xf32> {
+    %0:2 = linalg.generic {
+        indexing_maps = [#map0, #map0, #map0, #map1],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0, %arg1 : tensor<512xf32>, tensor<512xf32>)
+        outs(%arg2, %arg3 : tensor<512xf32>, tensor<200x512xf32>) {
+      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):
+        %2 = arith.addf %arg4, %arg5 : f32
+        linalg.yield %2, %2 : f32, f32
+      } -> (tensor<512xf32>, tensor<200x512xf32>)
+    %1 = tensor.expand_shape %0#1 [[0, 1, 2], [3]] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
+    return %1 : tensor<25x8x1x512xf32>
+  }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @multi_result_op_expansion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<512xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<512xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<200x512xf32>
+// CHECK: %[[OUTS:.+]] = tensor.expand_shape %[[ARG3]] {{\[}}[0, 1, 2], [3]{{\]}}
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME: outs(%[[ARG2]], %[[OUTS]] :
+// CHECK: return %[[GENERIC]]#1