diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1663,6 +1663,170 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + +class Tensor_RelayoutOp traits = []> : + Tensor_Op, + DestinationStyleOpInterface, + ConditionallySpeculatable, NoMemoryEffect, + DeclareOpInterfaceMethods, + TypesMatchWith<"result type matches type of dest", + "dest", "result", + "$_self">])> { + + code commonExtraClassDeclaration = [{ + int64_t getSourceRank() { return getSource().getType().getRank(); } + int64_t getDestRank() { return getDest().getType().getRank(); } + RankedTensorType getSourceType() { + return getSource().getType().cast(); } + RankedTensorType getDestType() { + return getDest().getType().cast(); } + + /// Return position for init operand. Init operand is `dest`. + std::pair getDpsInitsPositionRange() { + return {1, 2}; // `dest` operand + } + + /// Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + + /// Return a mapping from positions `inner_dims_pos` to their + /// tile factors. + DenseMap getDimAndTileMapping(); + + /// Return the tile sizes as OpFoldResult. + SmallVector getMixedTiles(); + + /// Return the tile sizes as `int64_t`. If a tile size is dynamic + /// a sentinel `kDynamic` is introduced at that position in + /// the returned vector. + SmallVector getStaticTiles(); + }]; + + let hasVerifier = 1; +} + +def Tensor_PackOp : Tensor_RelayoutOp<"pack", [ + AttrSizedOperandSegments]> { + let summary = "tensor pack operation"; + let description = [{ + The pack operation converts an input tensor to a higher-dimensional tensor + with a tiled and packed layout.
The mandatory `inner_dims_pos` attribute + specifies a permutation for the original dimensions, while `inner_tiles` is the + tiling factor for each dimension. The optional attribute `outer_dims_perm` + specifies the order for the tiled data dimensions, while the attribute + `padding_value` specifies a padding value at the boundary on non-perfectly + divisible dimensions. Padding is optional: + - If absent, it is UB if the tile does not perfectly divide the dimension. + - If present, it will pad along high dimensions (high-padding) to make the + tile complete. + + Example NC_to_NCnc: + + ```mlir + tensor.pack %source inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32> + ``` + Example CK to KCck: + + ```mlir + tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<8x16x8x32xf32> + ``` + + In all cases, dimension at position 0 in the input tensor (128) is tiled + with a factor of 8, while dimension at position 1 (256) is tiled with a factor + of 32. In the second example, the outer data dimensions are interchanged + according to `outer_dims_perm`. + + Example NC_to_NCnc with padding: + + ```mlir + tensor.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1] + inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + ``` + + }]; + let arguments = (ins AnyRankedTensor:$source, + AnyRankedTensor:$dest, + Optional:$padding_value, + DefaultValuedOptionalAttr:$outer_dims_perm, + DenseI64ArrayAttr:$inner_dims_pos, + Variadic:$inner_tiles, + I64ArrayAttr:$static_inner_tiles); + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $source + (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? + (`outer_dims_perm` `=` $outer_dims_perm^)?
+ `inner_dims_pos` `=` $inner_dims_pos + `inner_tiles` `=` + custom($inner_tiles, $static_inner_tiles, + "ShapedType::kDynamic") + `into` $dest attr-dict `:` type($source) `->` type($dest) + }]; + + let extraClassDeclaration = commonExtraClassDeclaration # [{ + // Method to get the `ShapedType` of the result based on the inner tiles, + // position of the inner tiles (innerDimsPos) and interchange vector of + // outer loops (outerDimsPerm). + static ShapedType inferPackedType(ShapedType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + }]; +} + +//===----------------------------------------------------------------------===// +// UnPackOp +//===----------------------------------------------------------------------===// + +def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> { + let summary = "tensor unpack operation"; + let description = [{ + The unpack operation converts a tensor with a tiled and packed layout to a + lower-dimensional tensor. Similar to `pack`, the mandatory attribute + `inner_dims_pos` specifies a permutation for the inner data dimensions, while + `inner_tiles` is the tiling factor. The attribute `outer_dims_perm` has the + same behavior as the one described in `pack`. In `unpack`, it is UB if the + tile does not perfectly divide the dimension.
+ + Example NCnc_to_NC: + + ```mlir + tensor.unpack %source inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32> + ``` + + Example KCck to CK: + + ```mlir + tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] + inner_tiles = [8, 32] into %dest : tensor<8x16x8x32xf32> -> tensor<128x256xf32> + ``` + }]; + let arguments = (ins AnyRankedTensor:$source, + AnyRankedTensor:$dest, + DefaultValuedOptionalAttr:$outer_dims_perm, + DenseI64ArrayAttr:$inner_dims_pos, + Variadic:$inner_tiles, + I64ArrayAttr:$static_inner_tiles); + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $source + (`outer_dims_perm` `=` $outer_dims_perm^)? + `inner_dims_pos` `=` $inner_dims_pos + `inner_tiles` `=` + custom($inner_tiles, $static_inner_tiles, + "ShapedType::kDynamic") + `into` $dest attr-dict `:` type($source) `->` type($dest) + }]; + + let extraClassDeclaration = commonExtraClassDeclaration; +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -2944,6 +2945,369 @@ return SplatElementsAttr::get(getType(), {constOperand}); } +//===----------------------------------------------------------------------===// +// PackOp/UnPackOp Common +//===----------------------------------------------------------------------===// + +template +static LogicalResult +reifyResultShapesImpl(OpTy op, OpBuilder &builder, +
ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + int64_t destRank = op.getDestRank(); + reifiedReturnShapes.resize(1, SmallVector(destRank)); + for (auto dim : llvm::seq(0, destRank)) { + reifiedReturnShapes[0][dim] = + builder.createOrFold(op.getLoc(), op.getDest(), dim); + } + return success(); +} + +template +static DenseMap getDimAndTileMappingImpl(OpTy op) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + DenseMap dimAndTileMapping; + ArrayRef dimsToTile = op.getInnerDimsPos(); + SmallVector tiles = op.getMixedTiles(); + assert(tiles.size() == dimsToTile.size() && + "tiles must match indices of dimension to block"); + // Bind the dimension `dimsToTile[i]` to its tile factor. + for (auto i : llvm::seq(0, dimsToTile.size())) + dimAndTileMapping[dimsToTile[i]] = tiles[i]; + return dimAndTileMapping; +} + +template +static SmallVector getMixedTilesImpl(OpTy op) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + SmallVector mixedInnerTiles; + unsigned dynamicValIndex = 0; + for (Attribute attr : op.getStaticInnerTiles()) { + auto tileAttr = attr.cast(); + if (!ShapedType::isDynamic(tileAttr.getInt())) + mixedInnerTiles.push_back(tileAttr); + else + mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); + } + return mixedInnerTiles; +} + +template +static SmallVector getStaticTilesImpl(OpTy op) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + SmallVector dynamicTiles; + SmallVector staticTiles; + dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles, + ShapedType::kDynamic); + return staticTiles; +} + +/// Returns true if `dimsPos` is invalid. It is invalid when: +/// a) It contains duplicates. +/// b) At least one dimension is out of bounds (`dimPos` must be >= 0 and < rank). +/// c) The number of elements in `dimsPos` is greater than `rank`.
+static bool isInvalidPackingPosSpecification(ArrayRef dimsPos, + size_t rank) { + size_t dimsPosSize = dimsPos.size(); + if (dimsPosSize > rank) + return true; + DenseSet uniqued; + for (int64_t dim : dimsPos) + uniqued.insert(dim); + if (dimsPosSize != uniqued.size()) + return true; + return llvm::any_of(dimsPos, [rank](int64_t dimPos) { + return dimPos < 0 || dimPos >= static_cast(rank); + }); +} + +/// Returns true if every dimension of `sourceShape` is smaller than or equal to +/// the matching dimension of `limitShape`; dynamic dimensions are in bound. +static bool areAllInBound(ArrayRef sourceShape, + ArrayRef limitShape) { + assert( + sourceShape.size() == limitShape.size() && + "expected source shape rank, and limit of the shape to have same rank"); + return llvm::all_of( + llvm::zip(sourceShape, limitShape), [](std::tuple it) { + int64_t sourceExtent = std::get<0>(it); + int64_t limit = std::get<1>(it); + return ShapedType::isDynamic(sourceExtent) || + ShapedType::isDynamic(limit) || sourceExtent <= limit; + }); +} + +template +static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + Operation *op = packOrUnPack.getOperation(); + + // Return true if we have a zero-value tile. + auto hasZeros = [&](ArrayRef tiles) { + return llvm::any_of( + tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); + }; + + // Verify tiles. Do not allow zero tiles. + SmallVector mixedTiles = packOrUnPack.getMixedTiles(); + if (hasZeros(mixedTiles)) + return op->emitError("invalid zero tile factor"); + + // Verify inner_dims_pos and outer_dims_perm. + ShapedType unpackedType = (std::is_same::value) + ?
packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); + size_t unpackedRank = unpackedType.getRank(); + ArrayRef innerDimsPos = packOrUnPack.getInnerDimsPos(); + ArrayRef outerDimPerm = packOrUnPack.getOuterDimsPerm(); + if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank)) + return op->emitError("invalid inner_dims_pos vector"); + if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank)) + return op->emitError("invalid outer_dims_perm vector"); + // A partial outer_dims_perm would make inferPackedType duplicate dims. + if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank) + return op->emitError("outer_dims_perm must be a permutation or empty"); + + // Tiling factors must be less than or equal to the input rank for pack (or + // output rank for unpack), and must match the number of `inner_dims_pos`. + if (mixedTiles.size() > unpackedRank) + return op->emitError("tiling factors must be less than or equal to the " + "input rank for pack or output rank for unpack"); + if (mixedTiles.size() != innerDimsPos.size()) + return op->emitError( + "tiling factors must equal the number of dimensions to tile"); + + ShapedType packedType = (std::is_same::value) + ? packOrUnPack.getDestType() + : packOrUnPack.getSourceType(); + size_t packedRank = packedType.getRank(); + // Require output rank to match input rank + number of blocking factors. + if (unpackedRank + mixedTiles.size() != packedRank) + return op->emitError( + "packed rank must equal unpacked rank + tiling factors"); + + // Verify result shape is greater than the minimum expected + // by the pack operation, and that the output shape + // represents full tiles. + ShapedType expectedPackedType = PackOp::inferPackedType( + unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); + if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { + return op->emitError("the shape of output is not large enough to hold the " + "packed data.
Expected at least ") + << expectedPackedType << ", got " << packedType; + } + if (!llvm::all_of( + llvm::zip(packedType.getShape().take_back(mixedTiles.size()), + mixedTiles), + [](std::tuple it) { + Optional constTileSize = + getConstantIntValue(std::get<1>(it)); + int64_t shape = std::get<0>(it); + if (!constTileSize) { + // If specified tile size is dynamic, output shape should + // be dynamic too. + return ShapedType::isDynamic(shape); + } else { + if (ShapedType::isDynamic(shape)) { + // For the shape being dynamic when tile size is + // specified, return true. In canonical form a constant + // tile size should lead to constant shape of the tiled + // dimension, but not needed for verification. + return true; + } + return shape == constTileSize.value(); + } + })) { + return op->emitError("mismatch in inner tile sizes specified and shaped of " + "tiled dimension in the packed type"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + +void PackOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), "pack"); +} + +LogicalResult +PackOp::reifyResultShapes(OpBuilder &builder, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); +} + +DenseMap PackOp::getDimAndTileMapping() { + return getDimAndTileMappingImpl(*this); +} + +SmallVector PackOp::getMixedTiles() { + return getMixedTilesImpl(*this); +} + +SmallVector PackOp::getStaticTiles() { + return getStaticTilesImpl(*this); +} + +/// Returns true when static information proves that some constant tile size +/// does not perfectly divide the matching static dimension of `inputShape`.
+static bool +areNotFullTiles(ArrayRef inputShape, + DenseMap const &dimAndTileMapping) { + int64_t rank = inputShape.size(); + for (int64_t dim = 0; dim < rank; dim++) { + if (ShapedType::isDynamic(inputShape[dim])) + continue; + auto it = dimAndTileMapping.find(dim); + if (it == dimAndTileMapping.end()) + continue; + Optional constantTile = getConstantIntValue(it->second); + if (!constantTile) + continue; + if (inputShape[dim] % (*constantTile) != 0) + return true; + } + return false; +} + +LogicalResult PackOp::verify() { + if (failed(commonVerifierPackAndUnPackOp(*this))) + return failure(); + + // Verify padding value, and bail out if the tile does not divide the + // dimension fully. In the case of dynamic tile factors or dimensions, having + // a partial tile is undefined behavior. + auto paddingValue = getPaddingValue(); + if (paddingValue && + paddingValue.getType() != getSourceType().getElementType()) { + return emitOpError("expected padding_value has ") + << getSourceType().getElementType() + << " but got: " << paddingValue.getType(); + } + + auto dimAndTileMapping = getDimAndTileMapping(); + if (!paddingValue && + areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) { + return emitOpError("invalid tile factor provided. Only full tiles are " + "supported when padding_value is not set"); + } + return success(); +} + +/// Returns a copy of `elements` where position `offset + i` holds +/// `elements[offset + interchangeVector[i]]`; other positions are unchanged. +template +static SmallVector interchange(ArrayRef elements, + ArrayRef interchangeVector, + int offset = 0) { + SmallVector vec = llvm::to_vector(elements); + for (auto en : llvm::enumerate(interchangeVector)) + vec[en.index() + offset] = elements[en.value() + offset]; + + return vec; +} + +/// Get the expected packed type based on source type, tile factors, position of +/// the inner tiles and permutation of the outer tiled loop.
ShapedType PackOp::inferPackedType(ShapedType sourceType, + ArrayRef innerTileSizes, + ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + SmallVector resultShape = llvm::to_vector(sourceType.getShape()); + for (auto tiledDim : llvm::enumerate(innerDimsPos)) { + if (ShapedType::isDynamic(resultShape[tiledDim.value()])) + continue; + if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { + resultShape[tiledDim.value()] = ShapedType::kDynamic; + continue; + } + resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()], + innerTileSizes[tiledDim.index()]); + } + + resultShape = interchange(resultShape, outerDimsPerm); + + // Append the inner tile dimensions. + resultShape.append(innerTileSizes.begin(), innerTileSizes.end()); + return RankedTensorType::get(resultShape, sourceType.getElementType()); +} + +/// Returns true if all tiles and the packed type's tiled dims are static. +template +static bool areTilesAndTiledDimsAllConstant(OpTy op) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + ShapedType packedType = (std::is_same::value) + ? op.getDestType() + : op.getSourceType(); + SmallVector mixedTiles = op.getMixedTiles(); + for (auto [dimDest, tile] : llvm::zip( + packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) { + Optional constTileSize = getConstantIntValue(tile); + if (!constTileSize || ShapedType::isDynamic(dimDest)) + return false; + } + return true; +} + +Speculation::Speculatability PackOp::getSpeculatability() { + if (getPaddingValue()) + return Speculation::Speculatable; + + // The verifier already rejects ops if we can statically prove that the + // sizes of the tiles do not perfectly divide the dimension; thus, only + // check that the tiles and the tiled inner dimensions are constant.
+ if (!areTilesAndTiledDimsAllConstant(*this)) + return Speculation::NotSpeculatable; + + return Speculation::Speculatable; +} + +//===----------------------------------------------------------------------===// +// UnPackOp +//===----------------------------------------------------------------------===// + +void UnPackOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "unpack"); +} + +LogicalResult +UnPackOp::reifyResultShapes(OpBuilder &builder, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); +} + +DenseMap UnPackOp::getDimAndTileMapping() { + return getDimAndTileMappingImpl(*this); +} + +SmallVector UnPackOp::getMixedTiles() { + return getMixedTilesImpl(*this); +} + +SmallVector UnPackOp::getStaticTiles() { + return getStaticTilesImpl(*this); +} + +LogicalResult UnPackOp::verify() { + return commonVerifierPackAndUnPackOp(*this); +} + +Speculation::Speculatability UnPackOp::getSpeculatability() { + // See PackOp::getSpeculatability; unpack has no padding_value. + if (!areTilesAndTiledDimsAllConstant(*this)) + return Speculation::NotSpeculatable; + + return Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -522,3 +522,92 @@ %out = tensor.empty(%sz) : tensor<2x?x?x5xf32> return } + +// ----- + +func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> { + // expected-error@+1 {{invalid tile factor provided.
Only full tiles are supported when padding_value is not set}} + %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32> + return %0 : tensor<8x8x16x33xf32> +} + +// ----- + +func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> { + // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}} + %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +// ----- + +func.func @pack_invalid_inner_dims_pos_vector(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid inner_dims_pos vector}} + %0 = tensor.pack %input inner_dims_pos = [2, 0] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @pack_invalid_duplicate_element_in_inner_dims(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid inner_dims_pos vector}} + %0 = tensor.pack %input inner_dims_pos = [1, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @pack_invalid_duplicate_element_in_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid outer_dims_perm vector}} + %0 = tensor.pack %input outer_dims_perm = [1, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @unpack_invalid_out_of_bound_outer_perm(%input: tensor<8x8x32x16xf32>, %output: tensor<256x128xf32>) -> tensor<256x128xf32> { + // expected-error@+1
{{invalid outer_dims_perm vector}} + %0 = tensor.unpack %input outer_dims_perm = [2, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32> + return %0 : tensor<256x128xf32> +} + +// ----- + +func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} + %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- + +func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> { + // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}} + %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32> + return %0 : tensor<256x128xf32> +} + +// ----- + +func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { + // expected-error@+1 {{invalid zero tile factor}} + %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [0, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> + return %0 : tensor<8x8x32x16xf32> +} + +// ----- +func.func @pack_mismatch_inner_tile_size_and_output_shape( + %input : tensor, %output : tensor) -> tensor { + // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} + %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor -> tensor + return %0 : tensor +} + +// ----- + +func.func @unpack_mismatch_inner_tile_size_and_output_shape( + %input : tensor, %output : tensor) -> tensor { + //
expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} + %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor -> tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt --split-input-file %s | mlir-opt | FileCheck %s // CHECK-LABEL: func @cast( func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor) { @@ -13,6 +13,8 @@ return } +// ----- + // CHECK-LABEL: func @empty( // CHECK-SAME: %[[sz:.*]]: index func.func @empty(%sz: index) -> tensor<5x?x6xf32> { @@ -21,6 +23,8 @@ return %0 : tensor<5x?x6xf32> } +// ----- + // CHECK-LABEL: func @empty_with_encoding( // CHECK-SAME: %[[sz:.*]]: index func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> { @@ -29,6 +33,8 @@ return %0 : tensor<5x?x6xf32, "foo"> } +// ----- + // CHECK-LABEL: func @extract( // CHECK-SAME: %[[TENSOR:.*]]: tensor, // CHECK-SAME: %[[INDEX:.*]]: index) { @@ -38,6 +44,8 @@ return } +// ----- + // CHECK-LABEL: func @insert( // CHECK-SAME: %[[SCALAR:.*]]: f32 // CHECK-SAME: %[[INDEX:.*]]: index @@ -48,6 +56,8 @@ return } +// ----- + // CHECK-LABEL: func @tensor.from_elements() { func.func @tensor.from_elements() { %c0 = "arith.constant"() {value = 0: index} : () -> index @@ -74,6 +84,8 @@ return } +// ----- + // CHECK-LABEL: @tensor.generate func.func @tensor.generate(%m : index, %n : index) -> tensor { @@ -85,6 +97,8 @@ return %tnsr : tensor } +// ----- + // CHECK-LABEL: func @tensor_reshape func.func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>, %shape2: tensor<2xi32>, %shape3: tensor) -> tensor<*xf32> { @@ -97,6 +111,8 @@ return %new_unranked : tensor<*xf32> } +// ----- + // CHECK-LABEL: func @slice({{.*}}) { func.func @slice(%t:
tensor<8x16x4xf32>, %idx : index) { %c0 = arith.constant 0 : index @@ -120,6 +136,8 @@ return } +// ----- + // CHECK-LABEL: func @insert_slice({{.*}}) { func.func @insert_slice( %t: tensor<8x16x4xf32>, @@ -154,6 +172,8 @@ return } +// ----- + func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor) -> (tensor, tensor<1x1xf32>) { %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor @@ -164,6 +184,8 @@ // CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor // CHECK: tensor.expand_shape %{{.*}} [] : tensor into tensor<1x1xf32> +// ----- + func.func @legal_collapsing_reshape_dynamic_tensor (%arg0: tensor) -> tensor { @@ -175,6 +197,8 @@ // CHECK: tensor.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] +// ----- + func.func @rank(%t : tensor<4x4x?xf32>) { // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32> %0 = "tensor.rank"(%t) : (tensor<4x4x?xf32>) -> index @@ -184,6 +208,8 @@ return } +// ----- + func.func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index, %pad_value: f32) -> tensor<6x?x?x?xf32> { %0 = tensor.pad %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] { @@ -201,6 +227,8 @@ // CHECK-SAME: high[3, 3, %[[HIGH]], 2] // CHECK: : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32> +// ----- + func.func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> { %0 = tensor.pad %arg0 low[1, 2] high[2, 3] { ^bb0(%arg1 : index, %arg2 : index): @@ -213,6 +241,8 @@ // CHECK: tensor.pad %[[ARG0]] low[1, 2] high[2, 3] // CHECK: : tensor<3x4xf32> to tensor<6x9xf32> +// ----- + func.func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index, %pad_value: f32) -> tensor { %0 = tensor.pad %arg0 low[0, 0] high[%ub0, %ub1] { @@ -230,6 +260,8 @@ // CHECK-SAME: high[%[[UB0]], %[[UB1]]] // CHECK: : tensor<2x3xf32> to tensor +// ----- + func.func @pad_to_static_size(%arg0: tensor, %ub0: index, %ub1: index, %pad_value: f32) -> tensor<2x3xf32> { %0 = tensor.pad %arg0 low[0, 0] high[%ub0, %ub1]
{ @@ -247,6 +279,8 @@ // CHECK-SAME: high[%[[UB0]], %[[UB1]]] // CHECK: : tensor to tensor<2x3xf32> +// ----- + // CHECK-LABEL: func @test_splat_op // CHECK-SAME: [[S:%arg[0-9]+]]: f32 func.func @test_splat_op(%s : f32) { @@ -258,6 +292,8 @@ return } +// ----- + // CHECK-LABEL: func.func @gather_scatter( // CHECK-SAME: %[[ARG0:.*]]: tensor<4x5x6xf32>, // CHECK-SAME: %[[ARG1:.*]]: tensor<1x3x2xindex>, @@ -281,3 +317,106 @@ (tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32> return } + +// ----- + +func.func @pack_nc_to_ncnc(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) -> tensor<128x256xf32> { + %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + %1 = tensor.empty() : tensor<128x256xf32> + %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<4x16x32x16xf32> -> tensor<128x256xf32> + return %2 : tensor<128x256xf32> +} + +// CHECK-LABEL: func.func @pack_nc_to_ncnc( +// CHECK-SAME: %[[SOURCE:.*]]: tensor<128x256xf32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<4x16x32x16xf32>) +// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<4x16x32x16xf32> +// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32> +// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<4x16x32x16xf32> -> tensor<128x256xf32> + +// ----- + +func.func @pack_nc_to_ncnc_with_padding(%source: tensor<13x15xf32>, %dest: tensor<2x8x8x2xf32>, %padding: f32) -> tensor<13x15xf32> { + %0 = tensor.pack %source padding_value(%padding : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + %1 = tensor.empty() : tensor<13x15xf32> + %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32> + return %2 :
tensor<13x15xf32> +} + +// CHECK-LABEL: func.func @pack_nc_to_ncnc_with_padding( +// CHECK-SAME: %[[SOURCE:.*]]: tensor<13x15xf32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<2x8x8x2xf32>, +// CHECK-SAME: %[[PADDING:.*]]: f32) +// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] padding_value(%[[PADDING]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<13x15xf32> -> tensor<2x8x8x2xf32> +// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<13x15xf32> +// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[BUFF]] : tensor<2x8x8x2xf32> -> tensor<13x15xf32> + +// ----- + +func.func @pack_ck_to_kcck(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<128x256xf32> { + %0 = tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<16x4x32x16xf32> + %1 = tensor.empty() : tensor<128x256xf32> + %2 = tensor.unpack %0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<16x4x32x16xf32> -> tensor<128x256xf32> + return %2 : tensor<128x256xf32> +} + +// CHECK-LABEL: func.func @pack_ck_to_kcck( +// CHECK-SAME: %[[SOURCE:.*]]: tensor<128x256xf32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<16x4x32x16xf32>) +// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<16x4x32x16xf32> +// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32> +// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<16x4x32x16xf32> -> tensor<128x256xf32> + +// ----- + +func.func @pad_and_pack_fully_dynamic(%source: tensor, %dest: tensor, %pad: f32, %tile_n : index, %tile_m : index) -> tensor { + %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor
-> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @pad_and_pack_fully_dynamic( +// CHECK-SAME: %[[SOURCE:.*]]: tensor, +// CHECK-SAME: %[[DEST:.*]]: tensor, +// CHECK-SAME: %[[PAD:.*]]: f32, +// CHECK-SAME: %[[TILE_N:.*]]: index, +// CHECK-SAME: %[[TILE_M:.*]]: index) +// CHECK: %{{.*}} = tensor.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor -> tensor + +// ----- + +func.func @pad_and_pack_partially_dynamic(%source: tensor, %dest: tensor, %pad: f32) -> tensor { + %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @pad_and_pack_partially_dynamic( +// CHECK-SAME: %[[SOURCE:.*]]: tensor, +// CHECK-SAME: %[[DEST:.*]]: tensor, +// CHECK-SAME: %[[PAD:.*]]: f32) +// CHECK: %{{.*}} = tensor.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor -> tensor + +// ----- + +func.func @unpack_fully_dynamic(%source: tensor, %dest: tensor, %tile_n : index, %tile_m : index) -> tensor { + %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @unpack_fully_dynamic( +// CHECK-SAME: %[[SOURCE:.*]]: tensor, +// CHECK-SAME: %[[DEST:.*]]: tensor, +// CHECK-SAME: %[[TILE_N:.*]]: index, +// CHECK-SAME: %[[TILE_M:.*]]: index) +// CHECK: %{{.*}} = tensor.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor -> tensor + +// ----- + +func.func @unpack_partially_dynamic(%source: tensor, %dest: tensor) -> tensor { + %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor -> tensor + return %0: tensor +} + +// CHECK-LABEL: func.func @unpack_partially_dynamic( +// CHECK-SAME: %[[SOURCE:.*]]: tensor, +//
CHECK-SAME: %[[DEST:.*]]: tensor) +// CHECK: %{{.*}} = tensor.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor -> tensor diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -874,3 +874,58 @@ return } + +// ----- + +func.func @speculate_static_pack_and_unpack(%source: tensor<128x256xf32>, + %dest: tensor<4x16x32x16xf32>, %lb: index, %ub: index, %step: index) { + + // CHECK: tensor.pack + // CHECK-NEXT: scf.for + scf.for %i = %lb to %ub step %step { + %packed = tensor.pack %source + inner_dims_pos = [0, 1] + inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32> + } + + // CHECK: tensor.unpack + // CHECK-NEXT: scf.for + scf.for %i = %lb to %ub step %step { + %unpacked = tensor.unpack %dest + inner_dims_pos = [0, 1] + inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32> + } + return +} + +// ----- + +func.func @speculate_dynamic_pack_and_unpack(%source: tensor, + %dest: tensor, %lb: index, %ub: index, %step: index, + %tile_m: index, %tile_n: index, %pad: f32) { + + // CHECK: scf.for + // CHECK-NEXT: tensor.pack + scf.for %i = %lb to %ub step %step { + %packed = tensor.pack %source + inner_dims_pos = [0, 1] + inner_tiles = [%tile_n, %tile_m] into %dest : tensor -> tensor + } + + // CHECK: scf.for + // CHECK-NEXT: tensor.unpack + scf.for %i = %lb to %ub step %step { + %unpacked = tensor.unpack %dest + inner_dims_pos = [0, 1] + inner_tiles = [%tile_n, %tile_m] into %source : tensor -> tensor + } + + // CHECK: tensor.pack + // CHECK-NEXT: scf.for + scf.for %i = %lb to %ub step %step { + %packed = tensor.pack %source padding_value(%pad : f32) + inner_dims_pos = [0, 1] + inner_tiles = [%tile_n, %tile_m] into %dest : tensor -> tensor + } + return +}