diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -434,14 +434,14 @@
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                      genericOp, opOperand);
 
-  // We'll replace the init operand with the destination of pack op if the init
-  // operand has not users in the body of the linalg.generic (pure elementwise).
-  // If it has users we need to pack the init operand too and replace the init
-  // with the packing result.
-  Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
-                   ? packOpDest
-                   : packedOutOperand;
-
+  // If the dps init operand of the generic is a tensor.empty, forward the pack
+  // op destination.
+  Value dest = packedOutOperand;
+  if (auto initTensor = genericOp.getDpsInitOperand(0)
+                            ->get()
+                            .getDefiningOp<tensor::EmptyOp>()) {
+    dest = packOpDest;
+  }
   return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
                            *packInfo);
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -726,9 +726,6 @@
 // CHECK: func.func @reduction_pack_transpose_inner_dims
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[ORIG_INIT:.+]] = tensor.empty() : tensor<128x256xi32>
-// CHECK: %[[INIT_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
-// CHECK: %[[PACK_INIT:.+]] = tensor.pack %[[ORIG_INIT]]
 // CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32>
 // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
 // CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
@@ -737,7 +734,7 @@
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[PACK_ARG0]]
-// CHECK-SAME: outs(%[[PACK_INIT]]
+// CHECK-SAME: outs(%[[DEST]]
 // CHECK: return %[[RED]] : tensor<4x16x16x32xi32>
 
 // -----
@@ -776,11 +773,7 @@
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<100x128x256xi32>
 // CHECK: %[[INIT_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32>
-// CHECK: %[[PACKED_INIT:.+]] = tensor.pack %[[INIT]]
-// CHECK-SAME: outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32]
-// CHECK-SAME: into %[[INIT_EMPTY]]
 // CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x200x100x16x32xi32>
 // CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
 // CHECK-SAME: outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
@@ -792,7 +785,7 @@
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
 // CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
-// CHECK-SAME: outs(%[[PACKED_INIT]]
+// CHECK-SAME: outs(%[[INIT_EMPTY]]
 
 // -----
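
For context, here is a minimal sketch of the elementwise case the changed code handles. It is not part of the patch; the function name, element type, and tile sizes are made up for illustration. Because `%init` is produced by `tensor.empty`, the pattern forwards the `tensor.pack` destination `%dest` as the packed generic's out operand; an init coming from any other producer would instead be packed itself, which is what the updated reduction tests above check.

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>

// Input IR: an elementwise generic whose init is a tensor.empty, followed by a pack.
func.func @forward_pack_dest(%src: tensor<128x256xi32>,
                             %dest: tensor<4x16x32x16xi32>) -> tensor<4x16x32x16xi32> {
  %init = tensor.empty() : tensor<128x256xi32>
  %elem = linalg.generic {indexing_maps = [#map, #map],
                          iterator_types = ["parallel", "parallel"]}
      ins(%src : tensor<128x256xi32>) outs(%init : tensor<128x256xi32>) {
  ^bb0(%in: i32, %out: i32):
    %0 = arith.addi %in, %in : i32
    linalg.yield %0 : i32
  } -> tensor<128x256xi32>
  // inner_tiles [32, 16] on dims [0, 1]: 128/32 = 4 and 256/16 = 16 outer dims.
  %pack = tensor.pack %elem inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %dest : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
  return %pack : tensor<4x16x32x16xi32>
}

// After propagation (roughly): the pack moves above the generic and, since
// %init was a tensor.empty, %dest is forwarded as the packed generic's out:
//   %src_packed = tensor.pack %src ... into %src_empty
//   %res = linalg.generic ... ins(%src_packed ...) outs(%dest : tensor<4x16x32x16xi32>) ...
```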