diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -44,6 +44,11 @@
 /// tensor.[extract_slice|cast|expand_shape|collapse_shape].
 void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that fold operations like `tensor.pad`
+/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
+/// respectively.
+void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   Bufferize.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
+  FoldIntoPackAndUnpackPatterns.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   ReshapePatterns.cpp
   SplitPaddingPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -0,0 +1,87 @@
+//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace tensor {
+namespace {
+
+/// Returns true if each OpFoldResult in `ofrs` is the constant integer `value`.
+static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
+  return llvm::all_of(
+      ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
+}
+
+/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
+/// the pad op has zero low paddings.
+struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto padOp = packOp.getSource().getDefiningOp<PadOp>();
+
+    // Only fold when the pad is foldable, adds no low padding, and pads with
+    // a single constant value (so it can become the pack's padding_value).
+    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
+      return failure();
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue)
+      return failure();
+
+    // If the pack already pads, its padding value must match the pad op's.
+    if (auto paddingValue = packOp.getPaddingValue())
+      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
+        return failure();
+
+    rewriter.replaceOpWithNewOp<PackOp>(
+        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
+        packOp.getMixedTiles(), constantPaddingValue,
+        packOp.getOuterDimsPerm());
+    return success();
+  }
+};
+
+/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
+/// has extract_slice semantics.
+struct FoldUnpackWithExtractSliceOp
+    : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
+    if (!unpackOp)
+      return failure();
+
+    // Check all offsets are zeros, and all strides are 1.
+    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "expects offsets to be 0s and strides to be 1s");
+    }
+
+    // Create a new empty output tensor with the slice's sizes; the unpack
+    // then writes directly into it, subsuming the extract_slice.
+    Type elementType = unpackOp.getDestType().getElementType();
+    Value output = rewriter.create<EmptyOp>(
+        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
+    rewriter.replaceOpWithNewOp<UnPackOp>(
+        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
+        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
+    return success();
+  }
+};
+} // namespace
+
+void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
+  patterns.insert<FoldPadWithPackOp, FoldUnpackWithExtractSliceOp>(
+      patterns.getContext());
+}
+
+} // namespace tensor
+} // namespace mlir
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
+
+func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK: func @fold_unpack_slice(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK: tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK: tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]]
+// CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack
+
+// -----
+
+func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst0 = arith.constant 0.000000e+00 : f32
+  %cst1 = arith.constant 1.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst0 : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack_different_padding_value
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -74,6 +74,11 @@
       *this, "test-empty-op-folding",
       llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)};
 
+  Option<bool> testFoldIntoPackAndUnpack{
+      *this, "test-fold-into-pack-and-unpack",
+      llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -95,6 +100,12 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applySplitPaddingPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSplitPaddingPatterns(patterns);
@@ -276,6 +287,8 @@
     applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testEmptyOpFolding)
     applyEmptyOpFoldingPatterns(rootOp);
+  if (testFoldIntoPackAndUnpack)
+    applyFoldIntoPackAndUnpackPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
         applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))