diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1817,6 +1817,13 @@
                           Location loc, ArrayRef<int64_t> innerPermutation,
                           ArrayRef<int64_t> outerPermutation);
+
+    /// Check if this PackOp is like a simple pad operation.
+    /// In other words, this operation:
+    /// 1. adds useless dimensions (dimensions of size 1),
+    /// 2. pads the other ones, and
+    /// 3. doesn't shuffle the dimensions.
+    bool isLikePad();
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -1892,6 +1899,12 @@
                             Value transposedSource,
                             ArrayRef<int64_t> innerPermutation,
                             ArrayRef<int64_t> outerPermutation);
+
+    /// Check if this UnPackOp is like a simple unpad operation.
+    /// In other words, this operation:
+    /// 1. drops useless dimensions (dimensions of size 1), and
+    /// 2. reduces dimensions in place (i.e., no transpose).
+    bool isLikeUnPad();
   }];
 
   let hasCanonicalizeMethod = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -198,6 +198,15 @@
 /// Method to check if an interchange vector is a permutation.
 bool isPermutationVector(ArrayRef<int64_t> interchange);
 
+/// Return a permutation vector of size `permSize` that would result in moving
+/// `positions` into `desiredPositions`.
+///
+/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
+/// would result in a {4, 2, 0, 1, 3} permutation vector.
+SmallVector<int64_t>
+computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+                         ArrayRef<int64_t> desiredPositions);
+
 /// Helper to return a subset of `arrayAttr` as a vector of int64_t.
 // TODO: Port everything relevant to DenseArrayAttr and drop this util.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -520,6 +520,19 @@
     RankedTensorType sourceType,
     ArrayRef<ReassociationIndices> reassociationIndices);
 
+struct PackingMetadata {
+  SmallVector<int64_t> insertPositions;
+  SmallVector<ReassociationIndices> reassociations;
+};
+
+/// Given a vector of `positions` indices representing desired packing
+/// insertion points into a target vector (i.e. pack/unpack.inner_dim_pos),
+/// compute the final positions in the target shape as well as the reshape
+/// reassociations.
+// Note: This should not be called with a large positions array (or the
+// implementation needs to be updated to use an N log N sort instead of
+// repeated N^2 counts).
+PackingMetadata computePackingMetadata(int64_t packedRank,
+                                       ArrayRef<int64_t> innerDimPos);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
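
To make the contract of these two helpers concrete, here is a minimal, illustrative sketch of what they compute; the values follow the worked examples in the doc comments above, and the surrounding names are hypothetical, not part of this patch.

  // Move position 2 to index 1 and position 4 to index 0; the remaining
  // positions fill the holes in increasing order.
  SmallVector<int64_t> perm = computePermutationVector(
      /*permSize=*/5, /*positions=*/{2, 4}, /*desiredPositions=*/{1, 0});
  // perm == {4, 2, 0, 1, 3}

  // For a rank-6 packed type with inner_dims_pos = [1, 0] (the ABCD -> ABCDba
  // example used below), the insert positions of the tile dims into the
  // strip-mined AaBbCD shape and the expand/collapse reassociations are:
  PackingMetadata m =
      computePackingMetadata(/*packedRank=*/6, /*innerDimPos=*/{1, 0});
  // m.insertPositions == {3, 1}
  // m.reassociations == {{0, 1}, {2, 3}, {4}, {5}}
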
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -139,81 +139,6 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Return a permutation vector of size permSize that would result in moving
-/// positions into desiredPositions.
-///
-/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
-/// would result in a {4, 2, 0, 1, 3} permutation vector.
-static SmallVector<int64_t>
-computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
-                         ArrayRef<int64_t> desiredPositions) {
-  SmallVector<int64_t> res(permSize, -1);
-  DenseSet<int64_t> seen;
-  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
-    res[desiredPos] = pos;
-    seen.insert(pos);
-  }
-  int64_t nextPos = 0;
-  for (int64_t &entry : res) {
-    if (entry != -1)
-      continue;
-    while (seen.contains(nextPos))
-      ++nextPos;
-    entry = nextPos;
-    ++nextPos;
-  }
-  return res;
-}
-
-struct PackingMetadata {
-  SmallVector<int64_t> insertPositions;
-  SmallVector<ReassociationIndices> reassociations;
-};
-/// Given a vector of `positions` indices representing desired packing
-/// insertion points into a target vector (i.e. pack/unpack.inner_dim_pos),
-/// compute the final positions in the target shape as well as the reshape
-/// reassociations.
-// Note: This should not be called with a large positions array (or the
-// implementation needs to be updated to use an N log N sort instead of
-// repeated N^2 counts).
-static PackingMetadata computePackingMetadata(int64_t packedRank,
-                                              ArrayRef<int64_t> innerDimPos) {
-  PackingMetadata res;
-  res.insertPositions.reserve(innerDimPos.size());
-  // The pack insert position is the position + the number of previously
-  // inserted positions + offset.
-  // The offset controls whether the packing dimension is the first or last.
-  //
-  // Example
-  // =======
-  // Consider packing from a hypothetical ABCD layout to ABCDba whose
-  // pack.inner_dims is [1, 0]. The first step consists in undoing the
-  // permutation and producing AaBbCD. This is achieved purely by computing the
-  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
-  // possibility is to produce insert positions [2, 0], which would result in
-  // an aAbBCD layout (i.e. offset 0). The other possibility is to produce
-  // insert positions [3, 1], which would result in an AaBbCD layout (i.e.
-  // offset 1). The latter is what we expect from packing.
-  int64_t offset = 1;
-  for (int64_t pos : innerDimPos) {
-    int64_t numInsertedBefore = llvm::count_if(
-        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
-    res.insertPositions.push_back(pos + numInsertedBefore + offset);
-  }
-
-  DenseSet<int64_t> posSet(res.insertPositions.begin(),
-                           res.insertPositions.end());
-  res.reassociations.reserve(packedRank);
-  for (int64_t i = 1; i <= packedRank; ++i) {
-    if (!posSet.contains(i)) {
-      res.reassociations.push_back(ReassociationIndices{i - 1});
-      continue;
-    }
-    res.reassociations.push_back(ReassociationIndices{i - 1, i});
-    ++i;
-  }
-  return res;
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -888,27 +813,59 @@
                  llvm::interleaveComma(stripMinedShape,
                                        DBGS() << "stripMinedShape: ");
                  DBGSNL(); DBGS() << "collapsed type: " << collapsed;
                  DBGSNL(););
-  // 5. Expand from the padded result to the stripMinedShape.
-  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
-
-  // 6. Transpose stripMinedShape to packedShape.
-  SmallVector<int64_t> insertPositionsToLastDimsPerm =
-      computePermutationVector(packedRank, packingMetadata.insertPositions,
-                               lastDims);
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(),
-      insertPositionsToLastDimsPerm);
-
-  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
-             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
-             llvm::interleaveComma(insertPositionsToLastDimsPerm,
-                                   DBGS() << "insertPositionsToLastDimsPerm: ");
-             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
-
-  // 7. Replace packOp by transposeOp.
-  rewriter.replaceOp(packOp, transposeOp->getResults());
+  Operation *replacementOp = nullptr;
+  tensor::ExpandShapeOp reshapeOp = nullptr;
+  linalg::TransposeOp transposeOp = nullptr;
+  if (packOp.isLikePad()) {
+    // This pack is just a plain pad.
+    // Just insert the pad in the higher ranked tensor.
+    ArrayRef<int64_t> origShape = collapsed.getShape();
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, packedTensorType,
+                                                    ValueRange{});
+    // Offsets.
+    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+    // Strides.
+    OpFoldResult one = rewriter.getIndexAttr(1);
+    SmallVector<OpFoldResult> ones(packedRank, one);
+    // The inner dimensions stay the same, but the outer ones are additional
+    // 1s.
+    SmallVector<OpFoldResult> sizes(packedRank - origShape.size(), one);
+    for (int64_t dstSize : origShape)
+      sizes.push_back(rewriter.getIndexAttr(dstSize));
+
+    replacementOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, /*source=*/padOp, /*dest=*/emptyOp,
+        /*offsets=*/zeros, sizes,
+        /*strides=*/ones);
+    LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+               DBGS() << "insert_slice op: " << replacementOp; DBGSNL(););
+  } else {
+    // 5. Expand from the padded result to the stripMinedShape.
+    reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+        padOp.getResult(), packingMetadata.reassociations);
+
+    // 6. Transpose stripMinedShape to packedShape.
+    SmallVector<int64_t> insertPositionsToLastDimsPerm =
+        computePermutationVector(packedRank, packingMetadata.insertPositions,
+                                 lastDims);
+    transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, reshapeOp.getResult(), packOp.getDest(),
+        insertPositionsToLastDimsPerm);
+
+    LLVM_DEBUG(
+        DBGSNL(); DBGSNL(); DBGSNL(); DBGS() << "reshape op: " << reshapeOp;
+        DBGSNL();
+        llvm::interleaveComma(insertPositionsToLastDimsPerm,
+                              DBGS() << "insertPositionsToLastDimsPerm: ");
+        DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+    replacementOp = transposeOp;
+  }
+  assert(padOp && replacementOp &&
+         (packOp.isLikePad() || (reshapeOp && transposeOp)) &&
+         "If pack is not a pad, all intermediate steps should happen");
+  // Replace packOp by the final replacement op.
+  rewriter.replaceOp(packOp, replacementOp->getResults());
 
   return LowerPackResult{padOp, reshapeOp, transposeOp};
 }
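
For intuition, here is a minimal sketch of the new fast path with small, hypothetical shapes (the tests added below exercise the real pattern on larger shapes): a pack whose inner_dims_pos is the identity and whose outer dimensions are all 1s now lowers to pad + insert_slice instead of pad + expand_shape + transpose.

  // Input: a pack that only pads and adds unit outer dimensions.
  %pack = tensor.pack %src padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [4, 6] into %dst
      : tensor<3x4xf32> -> tensor<1x1x4x6xf32>

  // Lowered form: pad up to the tile sizes, then insert into the
  // higher-ranked destination at offset 0 with unit strides.
  %padded = tensor.pad %src low[0, 0] high[1, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<3x4xf32> to tensor<4x6xf32>
  %empty = tensor.empty() : tensor<1x1x4x6xf32>
  %res = tensor.insert_slice %padded into %empty[0, 0, 0, 0]
      [1, 1, 4, 6] [1, 1, 1, 1]
      : tensor<4x6xf32> into tensor<1x1x4x6xf32>
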
@@ -958,65 +915,90 @@
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  // 2. Compute the permutation vector to move the last `numPackedDims` into
-  // the `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
   int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape without outer and
-  // inner permutations.
-  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+  tensor::EmptyOp emptyOp = nullptr;
+  linalg::TransposeOp transposeOp = nullptr;
+  tensor::CollapseShapeOp reshapeOp = nullptr;
+  tensor::ExtractSliceOp extractSliceOp = nullptr;
 
-  // 4. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
-
-  LLVM_DEBUG(
-      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
-                                                DBGS() << "insertPositions: ");
-      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
-                                      DBGS() << "packedShape: ");
-      DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
-      DBGSNL(); llvm::interleaveComma(
-          packingMetadata.reassociations, DBGS() << "reassociations: ",
-          [&](ReassociationIndices ri) {
-            llvm::interleaveComma(ri, llvm::dbgs() << "|");
-          });
-      DBGSNL();
-      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
-      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
-
-  // 5. Collapse from the stripMinedShape to the padded result.
-  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
-      loc, collapsedType, transposeOp->getResult(0),
-      packingMetadata.reassociations);
-
-  // 6. ExtractSlice.
-  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
-  int64_t destRank = destTensorType.getRank();
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
-  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-      loc, destTensorType, reshapeOp->getResult(0),
-      SmallVector<OpFoldResult>(destRank, zero),
-      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
-      SmallVector<OpFoldResult>(destRank, one));
-
-  // 7. Replace unPackOp by transposeOp.
+  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  if (unPackOp.isLikeUnPad()) {
+    // This unpack is just a plain unpad.
+    // Just extract the slice from the higher ranked tensor.
+    ArrayRef<int64_t> destShape = destTensorType.getShape();
+    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
+    for (int64_t dstSize : destShape)
+      sizes.push_back(rewriter.getIndexAttr(dstSize));
+
+    extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, destTensorType, unPackOp.getSource(),
+        SmallVector<OpFoldResult>(packedRank, zero), sizes,
+        SmallVector<OpFoldResult>(packedRank, one));
+  } else {
+    // 2. Compute the permutation vector to move the last `numPackedDims` into
+    // the `innerPosDims` of a shape of rank `packedRank`.
+    int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+    auto lastDims = llvm::to_vector(
+        llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+    PackingMetadata packingMetadata =
+        computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+    SmallVector<int64_t> lastDimsToInsertPositionsPerm =
+        computePermutationVector(packedRank, lastDims,
+                                 packingMetadata.insertPositions);
+
+    // 3. Compute the stripMinedShape: this is the packed shape without outer
+    // and inner permutations.
+    SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+    applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+    // 4. Transpose packedShape to stripMinedShape.
+    RankedTensorType stripMinedTensorType =
+        RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+    RankedTensorType collapsedType =
+        tensor::CollapseShapeOp::inferCollapsedType(
+            stripMinedTensorType, packingMetadata.reassociations);
+    emptyOp = rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType,
+                                               ValueRange{});
+    transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+    LLVM_DEBUG(
+        DBGSNL(); DBGSNL(); llvm::interleaveComma(
+            packingMetadata.insertPositions, DBGS() << "insertPositions: ");
+        DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                        DBGS() << "packedShape: ");
+        DBGSNL();
+        llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                              DBGS() << "lastDimsToInsertPositionsPerm: ");
+        DBGSNL(); llvm::interleaveComma(
+            packingMetadata.reassociations, DBGS() << "reassociations: ",
+            [&](ReassociationIndices ri) {
+              llvm::interleaveComma(ri, llvm::dbgs() << "|");
+            });
+        DBGSNL();
+        llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+        DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+    // 5. Collapse from the stripMinedShape to the padded result.
+    reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, collapsedType, transposeOp->getResult(0),
+        packingMetadata.reassociations);
+
+    // 6. ExtractSlice.
+    int64_t destRank = destTensorType.getRank();
+    extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, destTensorType, reshapeOp->getResult(0),
+        SmallVector<OpFoldResult>(destRank, zero),
+        tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+        SmallVector<OpFoldResult>(destRank, one));
+  }
+
+  assert(extractSliceOp &&
+         (unPackOp.isLikeUnPad() || (emptyOp && transposeOp && reshapeOp)) &&
+         "If unPackOp is not an unpad, all intermediate steps should happen");
+  // Replace unPackOp by extractSliceOp.
   rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
 
   return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
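
Mirroring the pack side, a minimal sketch of the unpack fast path with small, hypothetical shapes: when the unpack only drops unit outer dimensions and trims padding, the whole pipeline collapses into a single rank-reducing extract_slice.

  %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 6]
      into %dst : tensor<1x1x4x6xf32> -> tensor<3x4xf32>

  // Lowered form: sizes are [1, 1] for the unit outer dims followed by the
  // destination shape, with zero offsets and unit strides.
  %res = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1]
      : tensor<1x1x4x6xf32> to tensor<3x4xf32>
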
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3709,6 +3709,48 @@
   return success();
 }
 
+template <typename PackOrUnpackOp>
+static bool isLikePadUnPad(PackOrUnpackOp packOp,
+                           RankedTensorType packedTensorType) {
+  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
+                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
+                "Function meant for pack/unpack");
+  // This is a pad if packing only adds ones and we don't transpose dimensions.
+
+  // Check that we are not transposing any dimensions.
+  int64_t expectedDimIdx = 0;
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  for (int64_t dimIdx : innerDimsPos) {
+    if (dimIdx != expectedDimIdx++) {
+      // This dimension doesn't appear in order.
+      return false;
+    }
+  }
+  int64_t numPackedDims = innerDimsPos.size();
+  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
+  int64_t packedRank = packedTensorType.getRank();
+  // At this point we know that we are taking numPackedDims outer
+  // dimensions and pushing them all the way as the innermost dimensions.
+  // What's left on the outermost dimensions is, in this order:
+  // - the factor of the packed dimensions, then
+  // - the untouched dimensions.
+  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
+  // if all the dimensions that bubble outward are ones.
+  // Therefore check that all the dimensions but the numPackedDims innermost
+  // ones are ones.
+  for (int i = 0; i != packedRank - numPackedDims; ++i) {
+    if (packedShape[i] != 1)
+      return false;
+  }
+  return true;
+}
+
+bool PackOp::isLikePad() {
+  auto packedTensorType =
+      (*this)->getResultTypes().front().cast<RankedTensorType>();
+  return isLikePadUnPad(*this, packedTensorType);
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
@@ -3822,6 +3864,10 @@
   return success();
 }
 
+bool UnPackOp::isLikeUnPad() {
+  RankedTensorType packedTensorType = getSourceType();
+  return isLikePadUnPad(*this, packedTensorType);
+}
+
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
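
To illustrate when the predicate rejects a pack (hypothetical shapes; the checked-in tests below cover the positive cases), here are the two distinct failure modes:

  // Rejected: inner_dims_pos = [1, 0] is not in increasing order, so the
  // tiles land transposed relative to the source dimensions.
  %p0 = tensor.pack %src inner_dims_pos = [1, 0] inner_tiles = [6, 4]
      into %d0 : tensor<4x6xf32> -> tensor<1x1x6x4xf32>

  // Rejected: inner_dims_pos = [0] is in order, but the untiled dimension
  // of size 6 bubbles outward, so the outer dims are [1, 6], not all ones.
  %p1 = tensor.pack %src inner_dims_pos = [0] inner_tiles = [4]
      into %d1 : tensor<4x6xf32> -> tensor<1x6x4xf32>
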
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -213,6 +213,27 @@
   return seenVals.size() == interchange.size();
 }
 
+SmallVector<int64_t>
+mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+                               ArrayRef<int64_t> desiredPositions) {
+  SmallVector<int64_t> res(permSize, -1);
+  DenseSet<int64_t> seen;
+  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
+    res[desiredPos] = pos;
+    seen.insert(pos);
+  }
+  int64_t nextPos = 0;
+  for (int64_t &entry : res) {
+    if (entry != -1)
+      continue;
+    while (seen.contains(nextPos))
+      ++nextPos;
+    entry = nextPos;
+    ++nextPos;
+  }
+  return res;
+}
+
 SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                           unsigned dropFront,
                                           unsigned dropBack) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -450,3 +450,42 @@
   return CollapseShapeRankReducingSliceSimplificationInfo{
       sliceType, newReassociationIndices};
 }
+
+PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
+                                             ArrayRef<int64_t> innerDimPos) {
+  PackingMetadata res;
+  res.insertPositions.reserve(innerDimPos.size());
+  // The pack insert position is the position + the number of previously
+  // inserted positions + offset.
+  // The offset controls whether the packing dimension is the first or last.
+  //
+  // Example
+  // =======
+  // Consider packing from a hypothetical ABCD layout to ABCDba whose
+  // pack.inner_dims is [1, 0]. The first step consists in undoing the
+  // permutation and producing AaBbCD. This is achieved purely by computing the
+  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
+  // possibility is to produce insert positions [2, 0], which would result in
+  // an aAbBCD layout (i.e. offset 0). The other possibility is to produce
+  // insert positions [3, 1], which would result in an AaBbCD layout (i.e.
+  // offset 1). The latter is what we expect from packing.
+  int64_t offset = 1;
+  for (int64_t pos : innerDimPos) {
+    int64_t numInsertedBefore = llvm::count_if(
+        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
+    res.insertPositions.push_back(pos + numInsertedBefore + offset);
+  }
+
+  DenseSet<int64_t> posSet(res.insertPositions.begin(),
+                           res.insertPositions.end());
+  res.reassociations.reserve(packedRank);
+  for (int64_t i = 1; i <= packedRank; ++i) {
+    if (!posSet.contains(i)) {
+      res.reassociations.push_back(ReassociationIndices{i - 1});
+      continue;
+    }
+    res.reassociations.push_back(ReassociationIndices{i - 1, i});
+    ++i;
+  }
+  return res;
+}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -21,14 +21,79 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
-  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
     : (!pdl.operation) -> !transform.op<"tensor.pack">
-  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
 }
 
 // -----
 
+// CHECK-LABEL: func.func @pack_as_pad(
+func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.pack is lowered to tensor.pad + tensor.insert_slice
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  // offsets.
+  // CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+  // sizes.
+  // CHECK-SAME: [1, 1, 1, 1, 136, 64, 16, 16]
+  // strides multipliers.
+  // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+  // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK: return %[[RES]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// Check that we don't lower the following pack as a pad.
+// Although all the outermost dimensions in the resulting shape are 1s,
+// some of the original dimensions are not part of the inner_dims_pos, hence
+// some transpose needs to happen.
+// CHECK-LABEL: func.func @pack_not_a_pad(
+func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
+  // CHECK: linalg.transpose
+  // CHECK-SAME:  ins(%{{.*}} : tensor<1x136x1x64x16x16xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<1x1x16x16x136x64xf32>)
+  // CHECK-SAME: permutation = [0, 2, 4, 5, 1, 3]
+
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [136, 64] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x16x16x136x64xf32>
+  return %pack : tensor<1x1x16x16x136x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
 // CHECK-LABEL: func.func @unpack(
 func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
@@ -38,9 +103,9 @@
   // CHECK-SAME:  ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
   // CHECK-SAME: outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
   // CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3]
-  // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
-  // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+  // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
   // CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
   %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
     : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -49,10 +114,41 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
-  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.unpack">
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
+        !transform.op<"linalg.transpose">,
+        !transform.op<"tensor.collapse_shape">,
+        !transform.op<"tensor.extract_slice">)
+}
+
+// -----
+
+// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
+// CHECK-LABEL: func.func @unpack_as_pad(
+func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // offsets.
+  // CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+  // sizes.
+  // CHECK-SAME: [1, 1, 1, 1, 129, 47, 16, 16]
+  // strides multipliers.
+  // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+  // CHECK-SAME: : tensor<1x1x1x1x136x64x16x16xf32> to tensor<129x47x16x16xf32>
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
     : (!pdl.operation) -> !transform.op<"tensor.unpack">
-  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
-    -> (!transform.op<"tensor.empty">,
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
         !transform.op<"linalg.transpose">,
         !transform.op<"tensor.collapse_shape">,
         !transform.op<"tensor.extract_slice">)