diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -866,6 +866,16 @@
                                const SmallVector<Value> &dynSizes) const;
 };
 
+/// Rewrites a tensor::PackOp whose outer dims are all 1s and whose inner tile
+/// sizes are static into a sequence of tensor.pad (only if a padding value is
+/// set) + tensor.extract_slice + linalg.transpose + tensor.insert_slice ops.
+struct GeneralizeOuterUnitDimsPackOpPattern
+    : public OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// Populates `patterns` with patterns that vectorize tensor.pad.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -22,6 +22,8 @@
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -473,6 +475,124 @@
   return success();
 }
 
+/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
+/// source directly. The method assumes that the `packOp` has static shapes.
+static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
+                                           tensor::PackOp packOp) {
+  Value input = packOp.getSource();
+  if (!packOp.getPaddingValue()) {
+    return input;
+  }
+
+  Location loc = packOp.getLoc();
+  ShapedType inputType = packOp.getSourceType();
+  int64_t inputRank = inputType.getRank();
+  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
+                      [](int64_t val) { return val == 1; }));
+
+  SmallVector<int64_t> paddedShape;
+  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
+      packOp.getDimAndTileMapping();
+  for (int64_t dim = 0; dim < inputRank; ++dim) {
+    int64_t size = inputType.getDimSize(dim);
+    if (!tileAndPosMapping.count(dim)) {
+      paddedShape.push_back(size);
+      continue;
+    }
+
+    // The size is less than or equal to tileSize because outer dims are all 1s.
+    Optional<int64_t> tileSize =
+        getConstantIntValue(tileAndPosMapping.lookup(dim));
+    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
+    paddedShape.push_back(tileSize.value());
+  }
+  auto resultType =
+      RankedTensorType::get(paddedShape, inputType.getElementType());
+  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
+                                 /*nofold=*/false, loc, builder);
+}
+
+LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
+    tensor::PackOp packOp, PatternRewriter &rewriter) const {
+  // TODO: support the case that outer dimensions are not all 1s. A
+  // tensor.expand_shape will be generated in this case.
+  int64_t srcRank = packOp.getSourceRank();
+  if (llvm::any_of(packOp.getDestType().getShape().take_front(srcRank),
+                   [](int64_t val) { return val != 1; })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "require the outer dimension of the result are all 1s");
+  }
+
+  if (llvm::any_of(packOp.getMixedTiles(),
+                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
+    return rewriter.notifyMatchFailure(packOp,
+                                       "require inner tile sizes being static");
+  }
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  Location loc = packOp.getLoc();
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+  SmallVector<OpFoldResult> readSizes;
+  SmallVector<int64_t> readShape;
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      packOp.getDimAndTileMapping();
+  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
+    if (!dimAndTileMapping.count(i)) {
+      readSizes.push_back(oneIdxAttr);
+      continue;
+    }
+    readSizes.push_back(dimAndTileMapping[i]);
+    readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
+                            .value_or(ShapedType::kDynamic));
+  }
+  Type elemType = packOp.getSourceType().getElementType();
+  auto readType = RankedTensorType::get(readShape, elemType);
+
+  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
+  Value tile = rewriter.create<tensor::ExtractSliceOp>(
+      loc, readType, input, readOffsets, readSizes, readStrides);
+
+  // 2. Transpose the tile to match the inner tile order.
+  // `vec[d]` is the position of source dim `d` in `inner_dims_pos`, so the
+  // filtered vector maps each dim of the source-ordered tile to its inner tile
+  // position. linalg.transpose and applyPermutationToVector need the opposite
+  // mapping (result dim `i` is taken from input dim `perm[i]`), hence the
+  // inversion. The two only coincide when the permutation is its own inverse.
+  constexpr int64_t kNonTiledMarker = -1;
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  SmallVector<int64_t> vec(srcRank, kNonTiledMarker);
+  for (auto [index, value] : llvm::enumerate(innerDimsPos))
+    vec[value] = index;
+  SmallVector<int64_t> perm = llvm::to_vector(llvm::make_filter_range(
+      vec, [](int64_t v) { return v != kNonTiledMarker; }));
+  perm = invertPermutationVector(perm);
+
+  SmallVector<int64_t> transpShape = readShape;
+  applyPermutationToVector<int64_t>(transpShape, perm);
+
+  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  auto transposedOp =
+      rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
+
+  // 3. Insert the inner tile to the destination.
+  int64_t destRank = packOp.getDestRank();
+  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
+  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> writeSizes(srcRank, oneIdxAttr);
+  for (auto size : transpShape)
+    writeSizes.push_back(rewriter.getIndexAttr(size));
+
+  auto insert = rewriter.create<tensor::InsertSliceOp>(
+      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
+      writeSizes, writeStrides);
+  rewriter.replaceOp(packOp, insert.getResult());
+
+  return success();
+}
+
 // The following are patterns for downscaling convolution ops with size-1
 // window dimensions.
 //
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -0,0 +1,184 @@
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
+
+func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
+  return %0 : tensor<1x1x1x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x32xf32>)
+// CHECK-SAME:      permutation = [1, 0]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
+  return %0 : tensor<1x1x8x2xf32>
+}
+// CHECK-LABEL: func.func @simple_pad_and_pack
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
+// CHECK:         %[[PAD:.+]] = tensor.pad %[[SRC]] low[%[[C0]], %[[C0]]] high[%[[C3]], %[[C1]]]
+// CHECK:           tensor.yield %[[PAD_VAL]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[PAD]] : tensor<8x2xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-SAME:      permutation = [0, 1]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
+  return %0 : tensor<1x1x32x8xf32>
+}
+// CHECK-LABEL: func.func @simple_NC_to_CNnc
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<32x8xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x8xf32>)
+// CHECK-SAME:      permutation = [0, 1]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @simple_CHW_to_CHWwch(%arg0: tensor<4x8x2xf32>, %arg1: tensor<1x1x1x2x4x8xf32>) -> tensor<1x1x1x2x4x8xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [2, 0, 1] inner_tiles = [2, 4, 8] into %arg1 : tensor<4x8x2xf32> -> tensor<1x1x1x2x4x8xf32>
+  return %0 : tensor<1x1x1x2x4x8xf32>
+}
+// CHECK-LABEL: func.func @simple_CHW_to_CHWwch
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x4x8xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<4x8x2xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<2x4x8xf32>)
+// CHECK-SAME:      permutation = [2, 0, 1]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 2, 4, 8] [1, 1, 1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]
+
+// -----
+
+// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s --check-prefix=CHECK-TRANS
+
+func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32>
+  return %0 : tensor<1x1x4x8x8x32xf32>
+}
+// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
+// CHECK-TRANS-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-TRANS-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 64, 8)>
+// CHECK-TRANS:       func.func @KCRS_to_KCRSsr
+// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-TRANS:         %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:           %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:             %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
+// CHECK-TRANS:             %[[IN_R_SZ:.+]] = affine.min #[[MAP1]](%[[R]])
+// CHECK-TRANS:             %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
+// CHECK-TRANS:             %[[IN_S_SZ:.+]] = affine.min #[[MAP3]](%[[S]])
+// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-TRANS-SAME:          [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, %[[IN_R_SZ]], %[[IN_S_SZ]]] [1, 1, 1, 1]
+// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-TRANS-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x?x?xf32> to tensor<32x8xf32>
+// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK-TRANS:             %[[TRANSP:.+]] = linalg.transpose
+// CHECK-TRANS-SAME:          ins(%[[TILE]]
+// CHECK-TRANS-SAME:          outs(%[[EMPTY]]
+// CHECK-TRANS-SAME:          permutation = [1, 0]
+// CHECK-TRANS:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 1]
+}
+
+// -----
+
+func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %arg2: f32) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+// CHECK-TRANS:       func.func @pad_and_pack
+// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-TRANS-SAME:    %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-TRANS:         scf.for
+// CHECK-TRANS:           scf.for
+// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-TRANS:             %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]]
+// CHECK-TRANS:               tensor.yield %[[PAD_VAL]]
+// CHECK-TRANS:             } : tensor<?x?xf32> to tensor<8x2xf32>
+// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK-TRANS:             %[[TRANSP:.+]] = linalg.transpose
+// CHECK-TRANS-SAME:          ins(%[[PAD]] : tensor<8x2xf32>)
+// CHECK-TRANS-SAME:          outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-TRANS-SAME:          permutation = [0, 1]
+// CHECK-TRANS:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1]
+}
+
+// -----
+
+
+func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
+  return %0 : tensor<32x4x32x8xf32>
+}
+// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
+// CHECK-TRANS-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-TRANS-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 8)>
+// CHECK-TRANS:       func.func @KC_to_CKkc
+// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-TRANS:         %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
+// CHECK-TRANS:           %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
+// CHECK-TRANS-DAG:         %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK-TRANS-DAG:         %[[IN_K_SZ:.+]] = affine.min #[[MAP1]](%[[K]])
+// CHECK-TRANS-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
+// CHECK-TRANS-DAG:         %[[IN_C_SZ:.+]] = affine.min #[[MAP3]](%[[C]])
+// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-TRANS-SAME:          [%[[IN_K]], %[[IN_C]]] [%[[IN_K_SZ]], %[[IN_C_SZ]]] [1, 1]
+// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-TRANS-SAME:          [0, 0] [32, 8] [1, 1] : tensor<?x?xf32> to tensor<32x8xf32>
+// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK-TRANS:             %[[TRANSP:.+]] = linalg.transpose
+// CHECK-TRANS-SAME:          ins(%[[TILE]]
+// CHECK-TRANS-SAME:          outs(%[[EMPTY]]
+// CHECK-TRANS-SAME:          permutation = [0, 1]
+// CHECK-TRANS:             %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-TRANS-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32>
+// CHECK-TRANS:             %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}}
+// CHECK-TRANS-SAME:          [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32>
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1]
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -79,6 +79,11 @@
       *this, "test-generalize-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
+  Option<bool> testGeneralizeTensorPackOp{
+      *this, "test-generalize-tensor-pack",
+      llvm::cl::desc("Test transform that generalize pack ops into a sequence "
+                     "of tensor and Linalg ops"),
+      llvm::cl::init(false)};
   Option<bool> testSwapSubTensorPadTensor{
       *this, "test-swap-subtensor-padtensor",
       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
@@ -165,6 +170,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
@@ -207,6 +218,8 @@
     return applyPadTensorToGenericPatterns(getOperation());
   if (testGeneralizePadTensor)
     return applyGeneralizePadTensorPatterns(getOperation());
+  if (testGeneralizeTensorPackOp)
+    return applyGeneralizeTensorPackPatterns(getOperation());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
   if (testBubbleUpExtractSliceOpPattern)