diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1663,6 +1663,163 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + +class Tensor_RelayoutOp traits = []> : + Tensor_Op, + DestinationStyleOpInterface, + Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result type matches type of dest", + "dest", "result", + "$_self">])> { + + code commonExtraClassDeclaration = [{ + int64_t getSourceRank() { return getSource().getType().getRank(); }; + int64_t getDestRank() { return getDest().getType().getRank(); }; + RankedTensorType getSourceType() { + return getSource().getType().cast(); }; + RankedTensorType getDestType() { + return getDest().getType().cast(); }; + + // Return position for init operand. Init operand is `dest`. + std::pair getDpsInitsPositionRange() { + return {1, 2}; // `dest` operand + } + + // Return a mapping from positions `inner_dims_pos` to their + // tile factors. + DenseMap getDimAndTileMapping(); + + // Return the tile sizes as OpFoldResult. + SmallVector getMixedTiles(); + + // Return the tile sizes as `int64_t`. If a tile size is dynamic + // a sentinel `kDynamicSize` is introduced at that position in + // the returned vector. + SmallVector getStaticTiles(); + }]; + + let hasVerifier = 1; +} + +def Tensor_PackOp : Tensor_RelayoutOp<"pack", [ + AttrSizedOperandSegments]> { + let summary = "tensor pack operation"; + let description = [{ + The pack operation converts an `input` into a tiled and packed layout. The + dimensions to be tiled are obtained from `inner_dims_pos` and the size of the + tile is obtained from `inner_tiles`.
The dimensions listed in `inner_dims_pos` + do not need to be contiguous in which case the tile will get transposed. We + handle only full tiles if `padding_value` is not set; it is UB if the tile does + not perfectly divide the dimension. If `padding_value` is set, it will pad + along high dimensions, to make full tiles. As optional input, the operation + takes `outer_dims_perm` that allows permuting the tiled loops. + + Example NC_to_NCnc: + + ```mlir + tensor.pack %source inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32> + ``` + Example CK to KCck: + + ```mlir + tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<8x16x8x32xf32> + ``` + + In all cases, dimension at position 0 in the input tensor (128) is tiled + with a factor of 8, while dimension at position 1 (256) is tiled with a factor + of 32. In the second example, the outer loops are interchanged according to + `outer_dims_perm`. + + Example NC_to_NCnc with padding: + + ```mlir + tensor.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1] + inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + ``` + + }]; + let arguments = (ins AnyRankedTensor:$source, + AnyRankedTensor:$dest, + Optional:$padding_value, + DefaultValuedOptionalAttr:$outer_dims_perm, + I64ArrayAttr:$inner_dims_pos, + Variadic:$inner_tiles, + I64ArrayAttr:$static_inner_tiles); + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $source + (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? + (`outer_dims_perm` `=` $outer_dims_perm^)?
+ `inner_dims_pos` `=` $inner_dims_pos + `inner_tiles` `=` + custom($inner_tiles, $static_inner_tiles, + "ShapedType::kDynamicSize") + `into` $dest attr-dict `:` type($source) `->` type($dest) + }]; + + let extraClassDeclaration = commonExtraClassDeclaration # [{ + // Method to get the `ShapedType` of the result based on the inner tiles, + // position of the inner tiles (innerDimsPos) and interchange vector of + // outer loops (outerDimsPerm). + static ShapedType getPackedType(ShapedType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + }]; +} + +//===----------------------------------------------------------------------===// +// UnPackOp +//===----------------------------------------------------------------------===// + +def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> { + let summary = "tensor unpack operation"; + let description = [{ + The unpack operation converts a tiled and packed input to an unpacked + output. See `pack` for more details on `inner_tiles`, `inner_dims_pos` and + `outer_dims_perm`; it is UB if the tile does not perfectly divide the + dimension. Optionally, the operation also supports permuting the tiled loops. + + Example NCnc_to_NC: + + ```mlir + tensor.unpack %source inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32> + ``` + + Example KCck to CK: + + ```mlir + tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<8x16x8x32xf32> -> tensor<128x256xf32> + ``` + }]; + let arguments = (ins AnyRankedTensor:$source, + AnyRankedTensor:$dest, + DefaultValuedOptionalAttr:$outer_dims_perm, + I64ArrayAttr:$inner_dims_pos, + Variadic:$inner_tiles, + I64ArrayAttr:$static_inner_tiles); + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $source + (`outer_dims_perm` `=` $outer_dims_perm^)?
+ `inner_dims_pos` `=` $inner_dims_pos + `inner_tiles` `=` + custom($inner_tiles, $static_inner_tiles, + "ShapedType::kDynamicSize") + `into` $dest attr-dict `:` type($source) `->` type($dest) + }]; + + let extraClassDeclaration = commonExtraClassDeclaration; +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -18,10 +18,12 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" #include using namespace mlir; @@ -2944,6 +2946,342 @@ return SplatElementsAttr::get(getType(), {constOperand}); } +//===----------------------------------------------------------------------===// +// PackOp/UnPackOp Common +//===----------------------------------------------------------------------===// + +template +static LogicalResult +reifyResultShapesImpl(OpTy op, OpBuilder &builder, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + int64_t destRank = op.getDestRank(); + reifiedReturnShapes.resize(1, SmallVector(destRank)); + for (auto dim : llvm::seq(0, destRank)) { + reifiedReturnShapes[0][dim] = + builder.createOrFold(op.getLoc(), op.getDest(), dim); + } + return success(); +} + +template +static DenseMap getDimAndTileMappingImpl(OpTy op) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + DenseMap dimAndTileMapping; + SmallVector dimsToTile = +
extractFromI64ArrayAttr(op.getInnerDimsPos()); + SmallVector tiles = op.getMixedTiles(); + assert(tiles.size() == dimsToTile.size() && + "tiles must match indices of dimension to block"); + // Bind each tiled dimension to its tile factor. + for (auto i : llvm::seq(0, dimsToTile.size())) + dimAndTileMapping[dimsToTile[i]] = tiles[i]; + return dimAndTileMapping; +} + +template +static SmallVector getMixedTilesImpl(OpTy op) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + SmallVector mixedInnerTiles; + unsigned dynamicValIndex = 0; + for (Attribute attr : op.getStaticInnerTiles()) { + auto tileAttr = attr.cast(); + if (!ShapedType::isDynamic(tileAttr.getInt())) + mixedInnerTiles.push_back(tileAttr); + else + mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); + } + return mixedInnerTiles; +} + +template +static SmallVector getStaticTilesImpl(OpTy op) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + SmallVector dynamicTiles; + SmallVector staticTiles; + dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles, + ShapedType::kDynamicSize); + return staticTiles; +} + +/// Return true if `dimsPos` is invalid. It is invalid when: a) it contains +/// duplicates. b) At least one dimension is out of bounds (`dimPos` is < 0 or +/// >= rank). c) the number of elements in `dimsPos` is > than `rank`. +static bool isInvalid(ArrayRef dimsPos, size_t rank) { + // early exit. + size_t dimsPosSize = dimsPos.size(); + if (dimsPosSize > rank) + return true; + DenseSet uniqued; + for (int64_t dim : dimsPos) + uniqued.insert(dim); + if (dimsPosSize != uniqued.size()) + return true; + return llvm::any_of(dimsPos, [rank](int64_t dimPos) { + return dimPos < 0 || dimPos >= static_cast(rank); + }); +} + +/// Returns true if each dimension of `sourceShape` is less than or equal to +/// the corresponding dimension of `limitShape`. Dynamic extents always pass.
+static bool isSmallerThan(ArrayRef sourceShape, + ArrayRef limitShape) { + assert( + sourceShape.size() == limitShape.size() && + "expected source shape rank, and limit of the shape to have same rank"); + return llvm::all_of( + llvm::zip(sourceShape, limitShape), [](std::tuple it) { + int64_t sourceExtent = std::get<0>(it); + int64_t limit = std::get<1>(it); + return sourceExtent == ShapedType::kDynamicSize || + limit == ShapedType::kDynamicSize || sourceExtent <= limit; + }); +} + +template +static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + Operation *op = packOrUnPack.getOperation(); + + // Return true if we have a zero-value tile. + auto hasZeros = [&](ArrayRef tiles) { + return llvm::any_of( + tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); + }; + + // Verify tiles. Do not allow zero tiles. + SmallVector mixedTiles = packOrUnPack.getMixedTiles(); + if (hasZeros(mixedTiles)) + return op->emitError("invalid zero tile factor"); + + // Verify inner_dims_pos and outer_dims_perm. + ShapedType unpackedType = (std::is_same::value) + ? packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); + size_t unpackedRank = unpackedType.getRank(); + SmallVector innerDimsPos = + extractFromI64ArrayAttr(packOrUnPack.getInnerDimsPos()); + SmallVector outerDimPerm = + extractFromI64ArrayAttr(packOrUnPack.getOuterDimsPerm()); + if (isInvalid(innerDimsPos, unpackedRank)) + return op->emitError("invalid inner_dims_pos vector"); + if (isInvalid(outerDimPerm, unpackedRank)) + return op->emitError("invalid outer_dims_perm vector"); + // A non-empty outer_dims_perm must permute all the outer dimensions. + if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank) + return op->emitError("outer_dims_perm must be a permutation or empty"); + + // Tiling factors must be less or equal than the input rank for pack (or + // output rank for unpack), and must match the number of `inner_dims_pos`.
+ if (mixedTiles.size() > unpackedRank) { + return op->emitError("tiling factors must be less or equal than the " + "input rank for pack or output rank for unpack"); + } + if (mixedTiles.size() != innerDimsPos.size()) { + return op->emitError( + "tiling factors must equal the number of dimensions to tile"); + } + + ShapedType packedType = (std::is_same::value) + ? packOrUnPack.getDestType() + : packOrUnPack.getSourceType(); + size_t packedRank = packedType.getRank(); + // Require output rank to match input rank + number of blocking factors. + if (unpackedRank + mixedTiles.size() != packedRank) { + return op->emitError( + "packed rank must equal unpacked rank + tiling factors"); + } + + // Verify result shape is greater than the minimum expected + // by the pack operation, and that the output shape + // represents full tiles. + ShapedType expectedPackedType = PackOp::getPackedType( + unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); + if (!isSmallerThan(expectedPackedType.getShape(), packedType.getShape())) { + return op->emitError("the shape of output is not large enough to hold the " + "packed data. Expected at least ") + << expectedPackedType << ", got " << packedType; + } + if (!llvm::all_of( + llvm::zip(packedType.getShape().take_back(mixedTiles.size()), + mixedTiles), + [](std::tuple it) { + Optional constTileSize = + getConstantIntValue(std::get<1>(it)); + int64_t shape = std::get<0>(it); + if (!constTileSize) { + // If specified tile size is dynamic, output shape should + // be dynamic too. + return shape == ShapedType::kDynamicSize; + } else { + if (shape == ShapedType::kDynamicSize) { + // For the shape being dynamic when tile size is + // specified, return true. In canonical form a constant + // tile size should lead to constant shape of the tiled + // dimension, but not needed for verification.
+ return true; + } + return shape == constTileSize.value(); + } + })) { + return op->emitError("mismatch in inner tile sizes specified and shaped of " + "tiled dimension in the packed type"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + +void PackOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), "pack"); +} + +LogicalResult +PackOp::reifyResultShapes(OpBuilder &builder, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); +} + +DenseMap PackOp::getDimAndTileMapping() { + return getDimAndTileMappingImpl(*this); +} + +SmallVector PackOp::getMixedTiles() { + return getMixedTilesImpl(*this); +} + +SmallVector PackOp::getStaticTiles() { + return getStaticTilesImpl(*this); +} + +/// Return true if the static shape and tile sizes prove that at least one tile +/// does not perfectly divide its dimension of the input tensor. +static bool +areNotFullTiles(ArrayRef inputShape, + DenseMap const &dimAndTileMapping) { + int64_t rank = inputShape.size(); + for (int64_t dim = 0; dim < rank; dim++) { + if (inputShape[dim] == ShapedType::kDynamicSize) + continue; + auto it = dimAndTileMapping.find(dim); + if (it != dimAndTileMapping.end()) { + Optional constantTile = getConstantIntValue(it->second); + if (!constantTile) + continue; + if (inputShape[dim] % (*constantTile) != 0) + return true; + } + } + return false; +} + +LogicalResult PackOp::verify() { + if (failed(commonVerifierPackAndUnPackOp(*this))) + return failure(); + // Verify padding value, and bail out if the tile does not divide the + // dimension fully. In the case of dynamic tile factors or dimensions, having + // a partial tile is undefined behavior.
+ if (auto paddingValue = getPaddingValue()) { + if (paddingValue.getType() != getSourceType().getElementType()) { + return emitOpError("expected padding_value has ") + << getSourceType().getElementType() + << " but got: " << paddingValue.getType(); + } + } + auto dimAndTileMapping = getDimAndTileMapping(); + if (!getPaddingValue() && + areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) { + return emitOpError("invalid tile factor provided. Only full tiles are " + "supported when padding_value is not set"); + } + return success(); +} + +/// Get the expected packed type based on source type, tile factors, position of +/// the inner tiles and permutation of the outer tiled loop. +ShapedType PackOp::getPackedType(ShapedType sourceType, + ArrayRef innerTileSizes, + ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + // Returns a copy of `elements` where position `i` holds + // `elements[interchangeVector[i]]`; an empty vector leaves it unchanged. + auto interchange = [](ArrayRef elements, + ArrayRef interchangeVector) { + SmallVector vec = llvm::to_vector(elements); + for (auto en : llvm::enumerate(interchangeVector)) { + vec[en.index()] = elements[en.value()]; + } + return vec; + }; + + SmallVector resultShape = llvm::to_vector(sourceType.getShape()); + for (auto tiledDim : llvm::enumerate(innerDimsPos)) { + if (ShapedType::isDynamic(resultShape[tiledDim.value()])) + continue; + if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { + resultShape[tiledDim.value()] = ShapedType::kDynamicSize; + continue; + } + resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()], + innerTileSizes[tiledDim.index()]); + } + + // Swap tile loops if outer_dims_perm is available. + resultShape = interchange(resultShape, outerDimsPerm); + + // Append the inner tile dimensions.
+ resultShape.append(innerTileSizes.begin(), innerTileSizes.end()); + return TypeSwitch(sourceType) + .Case([&](auto shapedType) { + return RankedTensorType::get(resultShape, shapedType.getElementType()); + }) + .Case([&](auto shapedType) { + return MemRefType::get(resultShape, shapedType.getElementType()); + }) + .Default([&](Type t) { + assert(false && "unexpected type"); + return nullptr; + }); +} + +//===----------------------------------------------------------------------===// +// UnPackOp +//===----------------------------------------------------------------------===// + +void UnPackOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "unpack"); +} + +LogicalResult +UnPackOp::reifyResultShapes(OpBuilder &builder, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); +} + +DenseMap UnPackOp::getDimAndTileMapping() { + return getDimAndTileMappingImpl(*this); +} + +SmallVector UnPackOp::getMixedTiles() { + return getMixedTilesImpl(*this); +} + +SmallVector UnPackOp::getStaticTiles() { + return getStaticTilesImpl(*this); +} + +LogicalResult UnPackOp::verify() { + return commonVerifierPackAndUnPackOp(*this); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -522,3 +522,100 @@ %out = tensor.empty(%sz) : tensor<2x?x?x5xf32> return } + +// ----- + +func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> { + // expected-error@+1 {{invalid tile factor provided.
Only full tiles are supported when padding_value is not set}} + %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32> + return %0 : tensor<8x8x16x33xf32> +} + +// ----- + +func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> { + // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}} + %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +// ----- + +func.func @pack_invalid_inner_dims_pos_vector(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid inner_dims_pos vector}} + %0 = tensor.pack %input inner_dims_pos = [2, 0] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @pack_invalid_duplicate_element_in_inner_dims(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid inner_dims_pos vector}} + %0 = tensor.pack %input inner_dims_pos = [1, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @pack_invalid_duplicate_element_in_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid outer_dims_perm vector}} + %0 = tensor.pack %input outer_dims_perm = [1, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @unpack_invalid_out_of_bound_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1
{{invalid outer_dims_perm vector}} + %0 = tensor.unpack %output outer_dims_perm = [2, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %input : tensor<8x8x32x16xf32> -> tensor<256x128xf32> + return %0 : tensor<256x128xf32> +} + +// ----- + +func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} + %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> { + // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}} + %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32> + return %0 : tensor<256x128xf32> +} + +// ----- + +func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid zero tile factor}} + %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [0, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- +func.func @pack_mismatch_inner_tile_size_and_output_shape( + %input : tensor, %output : tensor) -> tensor { + // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} + %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor -> tensor + return %0 : tensor +} + +// ----- + +func.func @unpack_mismatch_inner_tile_size_and_output_shape( + %input : tensor, %output : tensor) -> tensor { + //
expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} + %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor -> tensor + return %0 : tensor +} + +// ----- + +func.func @pack_invalid_outer_dims_perm_size(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{outer_dims_perm must be a permutation or empty}} + %0 = tensor.pack %input outer_dims_perm = [1] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -293,3 +293,43 @@ (tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32> return } + +func.func @pack_nc_to_ncnc(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) -> tensor<128x256xf32> { + %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + %1 = tensor.empty() : tensor<128x256xf32> + %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<4x16x32x16xf32> -> tensor<128x256xf32> + return %2 : tensor<128x256xf32> +} + +// CHECK: func.func @pack_nc_to_ncnc( +// CHECK-SAME: %[[ARG0:.*]]: tensor<128x256xf32>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<4x16x32x16xf32>) -> tensor<128x256xf32> { +// CHECK: %[[PACKED:.*]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[ARG1]] : tensor<128x256xf32> -> tensor<4x16x32x16xf32> +// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32> +// CHECK: %[[UNPACKED:.*]] = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<4x16x32x16xf32> -> tensor<128x256xf32> +// CHECK: return %[[UNPACKED]] : tensor<128x256xf32> +// CHECK: } + +func.func @pack_nc_to_ncnc_with_padding(%source: tensor<13x15xf32>, %dest: tensor<2x8x8x2xf32>, %padding: f32) -> tensor<13x15xf32> { + %0 = tensor.pack %source padding_value(%padding : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + %1 = tensor.empty() : tensor<13x15xf32> + %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 2]
into %1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32> + return %2 : tensor<13x15xf32> +} + +// CHECK: func.func @pack_nc_to_ncnc_with_padding( +// CHECK-SAME: %[[ARG0:.*]]: tensor<13x15xf32>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<2x8x8x2xf32>, +// CHECK-SAME: %[[PADDING:.*]]: f32) -> tensor<13x15xf32> { +// CHECK: %[[PACKED:.*]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[ARG1]] : tensor<13x15xf32> -> tensor<2x8x8x2xf32> +// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<13x15xf32> +// CHECK: %[[UNPACKED:.*]] = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[BUFF]] : tensor<2x8x8x2xf32> -> tensor<13x15xf32> +// CHECK: return %[[UNPACKED]] : tensor<13x15xf32> +// CHECK: } + +func.func @pack_ck_to_kcck(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<128x256xf32> { + %0 = tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<16x4x32x16xf32> + %1 = tensor.empty() : tensor<128x256xf32> + %2 = tensor.unpack %0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<16x4x32x16xf32> -> tensor<128x256xf32> + return %2 : tensor<128x256xf32> +}