diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -465,10 +465,69 @@
   }
 };
 
+/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
+/// add as many zero padding dimensions in `high` and `low` based on the number
+/// of point loops.
+struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::UnPackOp unpackOp =
+        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
+    if (!unpackOp)
+      return failure();
+
+    Location loc = padOp.getLoc();
+    // Bail out if one of the padded dimension is a tiled one.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.flip(dim);
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    // If we have `outer_dims_perms` we need to adjust the padded dimensions.
+    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // Add zero padding for the point loops.
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
+        paddingVal, padOp.getNofold());
+
+    // Inject the tensor.unpack right after the packed padOp.
+    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
+        loc, padOp.getResultType().getShape(),
+        padOp.getResultType().getElementType());
+
+    Value replacement = rewriter.create<tensor::UnPackOp>(
+        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
+        unpackOp.getMixedTiles(), outerDimsPerm);
+    rewriter.replaceOp(padOp, replacement);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<BubbleUpPackOpThroughGenericOpPattern, PushDownUnPackOpThroughGenericOp>(patterns.getContext());
+  patterns
+      .insert<BubbleUpPackOpThroughGenericOpPattern,
+              PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+          patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -471,4 +471,70 @@
 // CHECK-SAME: outs(%[[DEST]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+      tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
+  return %padded : tensor<1x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[1, 1, 1, 0] high[0, 1, 1, 0] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+      tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
+  return %padded : tensor<2x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
+
+// -----
+
+func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 1] high[0, 1, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+      tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
+  return %padded : tensor<1x58x58x66xf32>
+}
+
+// CHECK: func.func @pad_along_unpacked_dim(
+// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]