diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Dominance.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" #include @@ -298,10 +299,7 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter, tensor::PackOp packOp) { auto genericOp = packOp.getSource().getDefiningOp(); - if (!genericOp) - return failure(); - - if (!isElementwise(genericOp)) + if (!genericOp || !isElementwise(genericOp)) return failure(); // TODO: Relax the restriction. We are able to bubble up the pack op through @@ -309,6 +307,34 @@ if (genericOp.getNumResults() != 1) return failure(); + // Bail-out if the result of the generic has multiple uses, as bubbling up + // creates recomputation if the generic has multiple users. + if (!genericOp->getResult(0).hasOneUse()) + return failure(); + + // We want to move the pack not the generic. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(genericOp); + + // We need to handle two cases: + // 1) The tensor.pack destination is a tensor.empty. If this is the case, we + // create a new tensor.empty to avoid breaking dominance, as we are moving the + // tensor.pack above the linalg.generic. + // 2) The destination is not a tensor.empty. In this case we can replace only + // if the destination of the tensor.pack dominates the linalg.generic. + Value packOpDest = packOp.getDest(); + if (!packOpDest.hasOneUse()) + return failure(); + if (auto emptyOp = packOpDest.getDefiningOp()) { + packOpDest = rewriter.create( + genericOp->getLoc(), emptyOp.getMixedSizes(), + emptyOp.getType().getElementType()); + } else { + DominanceInfo dom(genericOp); + if (!dom.properlyDominates(packOpDest, genericOp)) + return failure(); + } + // TODO: Add an option for allowing padding values. It could introduce // undefined behavior if we unconditionally propagate pack op through all // the ops. E.g., if the padding value is zero and there are division ops in @@ -330,7 +356,7 @@ // If it has users we need to pack the init operand too and replace the init // with the packing result. Value dest = (genericOp.getRegionOutputArgs()[0].use_empty()) - ? packOp.getDest() + ? packOpDest : packedOutOperand; return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap, diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -621,3 +621,34 @@ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32> // CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1] + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg3 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %dest = bufferization.alloc_tensor() : tensor<4x16x16x32xi32> + %pack = tensor.pack %elem + inner_dims_pos = [1, 0] + inner_tiles = [16, 32] + into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32> + return %pack : tensor<4x16x16x32xi32> +} + +// CHECK: func.func @would_break_dominance( +// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xi32>) +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x256xi32> +// CHECK-NEXT: %[[GEN:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32> +// CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]] +// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] +// CHECK-SAME: into %[[ALLOC]]