diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -454,7 +454,7 @@ loc, collapsedType, transposeOp->getResult(0), packingMetadata.reassociations); - // 6. ExtractSlice + // 6. ExtractSlice. int64_t destRank = destTensorType.getRank(); auto extractSliceOp = rewriter.create( loc, destTensorType, reshapeOp->getResult(0), @@ -462,8 +462,12 @@ tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)), SmallVector(destRank, one)); - // 7. Replace unPackOp by extractSliceOp. - rewriter.replaceOp(unPackOp, extractSliceOp->getResults()); + // 7. Inject a copy into unPackOp's destination to preserve DPS. + auto copyOp = rewriter.create( + loc, extractSliceOp->getResult(0), unPackOp.getDest()); + + // 8. Replace unPackOp by copyOp. + rewriter.replaceOp(unPackOp, copyOp->getResults()); return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp}; } diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -123,16 +123,18 @@ // CHECK-LABEL: func.func @unpack( func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> { %cst_0 = arith.constant 0.0 : f32 - - // CHECK: tensor.empty() : tensor<17x8x2x32x16x16xf32> - // CHECK: linalg.transpose - // CHECK-SAME: ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>) - // CHECK-SAME: outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>) + // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32> + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32> + // CHECK: %[[TRAN:.*]] = linalg.transpose + // CHECK-SAME: ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>) + // CHECK-SAME: outs(%[[EMPTY]] : 
tensor<17x8x2x32x16x16xf32>) // CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3] - // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]] + // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]] // CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32> - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1] + // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1] // CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32> + // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>) + // CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>) %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1 : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32> return %pack : tensor<129x47x16x16xf32>