diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -359,6 +359,10 @@ OpFoldResult mul(AffineValueExpr lhs, AffineValueExpr rhs) { return makeComposedFoldedAffineApply(b, loc, {lhs.e * rhs.e}, {lhs, rhs}); } + OpFoldResult floor(AffineValueExpr lhs, AffineValueExpr rhs) { + return makeComposedFoldedAffineApply(b, loc, {lhs.e.floorDiv(rhs.e)}, + {lhs, rhs}); + } OpFoldResult ceil(AffineValueExpr lhs, AffineValueExpr rhs) { return makeComposedFoldedAffineApply(b, loc, {lhs.e.ceilDiv(rhs.e)}, {lhs, rhs}); diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -57,9 +57,11 @@ LINK_LIBS PUBLIC MLIRAffineDialect + MLIRAffineUtils MLIRDialectUtils MLIRIR MLIRLinalgDialect + MLIRLinalgUtils MLIRSCFDialect MLIRSupport MLIRTensorDialect diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" @@ -71,6 +72,38 @@ } }; +template +static SmallVector getPackUnPackIterationDomain(OpTy op, + OpBuilder &builder) { + static_assert(llvm::is_one_of::value, + "applies to only pack or unpack operations"); + OpBuilder::InsertionGuard g(builder); + Location loc = op.getLoc(); + int64_t rank = (std::is_same::value) ? 
op.getSourceRank() + : op.getDestRank(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + ReifiedRankedShapedTypeDims resultShape; + (void)op.reifyResultShapes(builder, resultShape); + SmallVector loopBounds(rank); + for (auto dim : llvm::seq(0, rank)) { + loopBounds[dim].offset = zero; + loopBounds[dim].stride = one; + loopBounds[dim].size = resultShape[0][dim]; + } + return loopBounds; +} + +static void applyInversePermToRange(SmallVector &offsets, + SmallVector &sizes, + ArrayRef permutation) { + if (permutation.empty()) + return; + SmallVector inversedPerm = invertPermutationVector(permutation); + applyPermutationToVector(offsets, inversedPerm); + applyPermutationToVector(sizes, inversedPerm); +} + struct PackOpTiling : public TilingInterface::ExternalModel { @@ -85,21 +118,7 @@ } SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { - OpBuilder::InsertionGuard guard(b); - auto packOp = cast(op); - Location loc = packOp.getLoc(); - int64_t rank = packOp.getSourceRank(); - Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); - ReifiedRankedShapedTypeDims resultShape; - (void)packOp.reifyResultShapes(b, resultShape); - SmallVector loopRanges(rank); - for (auto dim : llvm::seq(0, rank)) { - loopRanges[dim].offset = zero; - loopRanges[dim].stride = one; - loopRanges[dim].size = resultShape[0][dim]; - } - return loopRanges; + return getPackUnPackIterationDomain(cast(op), b); } SmallVector @@ -112,15 +131,9 @@ // The tiling is applied on interchanged dimensions. We have to undo the // interchange to map sizes and offsets to the original input. 
int64_t inputRank = packOp.getSourceRank(); - ArrayRef dimsToOuterBlock(packOp.getOuterDimsPerm()); SmallVector origOffsets(offsets.begin(), offsets.end()); SmallVector origSizes(sizes.begin(), sizes.end()); - if (!dimsToOuterBlock.empty()) { - SmallVector inversedPerm = - invertPermutationVector(dimsToOuterBlock); - applyPermutationToVector(origOffsets, inversedPerm); - applyPermutationToVector(origSizes, inversedPerm); - } + applyInversePermToRange(origOffsets, origSizes, packOp.getOuterDimsPerm()); DenseMap dimAndTileMapping = packOp.getDimAndTileMapping(); @@ -210,6 +223,192 @@ } }; +struct UnpackTileDimInfo { + bool isAlignedToInnerTileSize; + OpFoldResult sourceOffset; + OpFoldResult sourceSize; + OpFoldResult resultOffset; + OpFoldResult destExpandedSize; +}; + +static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, + int64_t tileDim, + OpFoldResult tileOffset, + OpFoldResult tileSize) { + UnpackTileDimInfo info; + Attribute zeroAttr = b.getIndexAttr(0); + Attribute oneAttr = b.getIndexAttr(1); + DenseMap dimAndTileMapping = + unpackOp.getDimAndTileMapping(); + // The dimension is not one of packed data dimension. 
+ if (!dimAndTileMapping.count(tileDim)) { + info.isAlignedToInnerTileSize = true; + info.sourceOffset = tileOffset; + info.sourceSize = tileSize; + info.resultOffset = zeroAttr; + info.destExpandedSize = tileSize; + return info; + } + + Location loc = unpackOp.getLoc(); + using AV = AffineValueExpr; + AffineBuilder ab(b, loc); + AffineExpr dim0, dim1, sym0; + bindDims(b.getContext(), dim0, dim1); + bindSymbols(b.getContext(), sym0); + + OpFoldResult innerTileSize = dimAndTileMapping[tileDim]; + + info.isAlignedToInnerTileSize = false; + FailureOr cstSize = linalg::getConstantUpperBoundForIndex( + getValueOrCreateConstantIndexOp(b, loc, tileSize)); + Optional cstInnerSize = getConstantIntValue(innerTileSize); + if (!failed(cstSize) && cstInnerSize) { + if (cstSize.value() % cstInnerSize.value() == 0) + info.isAlignedToInnerTileSize = true; + + // If the tiling size equals the inner tile size, the outer dims are + // always 1. + if (cstInnerSize.value() == cstSize.value()) { + auto lhs = AV(dim0).bind(tileOffset); + auto rhs = AV(dim1).bind(innerTileSize); + info.sourceOffset = ab.floor(lhs, rhs); + info.sourceSize = oneAttr; + info.resultOffset = zeroAttr; + info.destExpandedSize = tileSize; + return info; + } + } + + if (info.isAlignedToInnerTileSize) { + info.sourceOffset = + ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize)); + info.resultOffset = zeroAttr; + info.destExpandedSize = tileSize; + + // The ceilDiv is needed here because there could be an incomplete tile even + // in perfect tiling cases. E.g., + // %0 = unpack tensor<33x2xf32> into tensor<64xf32> + // If the tiling size is 32, there will be three tiles. Two of them have + // size=32; one of them has size=2. The size is represented using + // affine_min op; we need ceilDiv. 
+ info.sourceSize = + ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize)); + return info; + } + + DivModValue firstCoord = + getDivMod(b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset), + getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); + OpFoldResult tileExclusiveBound = + ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize)); + DivModValue lastCoord = getDivMod( + b, loc, + getValueOrCreateConstantIndexOp( + b, loc, + ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))), + getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); + + OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient), + AV(dim1).bind(firstCoord.quotient)); + info.sourceSize = + ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr)); + info.sourceOffset = firstCoord.quotient; + info.resultOffset = firstCoord.remainder; + info.destExpandedSize = + ab.mul(AV(dim0).bind(info.sourceSize), AV(sym0).bind(innerTileSize)); + return info; +} + +struct UnPackOpTiling + : public TilingInterface::ExternalModel { + + SmallVector getLoopIteratorTypes(Operation *op) const { + auto unpackOp = cast(op); + SmallVector iteratorTypes( + unpackOp.getDestRank(), utils::IteratorType::parallel); + return iteratorTypes; + } + + SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { + return getPackUnPackIterationDomain(cast(op), b); + } + + SmallVector + getTiledImplementation(Operation *op, OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes) const { + auto unpackOp = cast(op); + int64_t srcRank = unpackOp.getSourceRank(); + int64_t destRank = unpackOp.getDestRank(); + int64_t numInnerTiles = srcRank - destRank; + Location loc = unpackOp.getLoc(); + + // The perfect tiling case indicates that the tiling sizes are multiples + // of inner_tile_size. In this context, the indices of the input slice are all + // aligned to head. No extra data is needed when representing the tiled + // unpack op. 
+ bool isPerfectTilingCase = true; + Attribute oneAttr = b.getIndexAttr(1); + SmallVector sliceSrcStrides(destRank, oneAttr); + SmallVector sliceSrcIndices, sliceSrcSizes; + SmallVector destExpandedSizes, resultOffsetsFromDest; + for (auto dim : llvm::seq(0, destRank)) { + UnpackTileDimInfo info = + getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]); + if (!info.isAlignedToInnerTileSize) + isPerfectTilingCase = false; + sliceSrcIndices.push_back(info.sourceOffset); + sliceSrcSizes.push_back(info.sourceSize); + destExpandedSizes.push_back(info.destExpandedSize); + resultOffsetsFromDest.push_back(info.resultOffset); + } + + applyInversePermToRange(sliceSrcIndices, sliceSrcSizes, + unpackOp.getOuterDimsPerm()); + Attribute zeroAttr = b.getIndexAttr(0); + sliceSrcIndices.append(numInnerTiles, zeroAttr); + sliceSrcSizes.append(unpackOp.getMixedTiles()); + sliceSrcStrides.append(numInnerTiles, oneAttr); + Value sliceSource = + b.create(loc, unpackOp.getSource(), sliceSrcIndices, + sliceSrcSizes, sliceSrcStrides); + + SmallVector destStrides(destRank, oneAttr); + Value sliceDest; + if (isPerfectTilingCase) { + sliceDest = b.create(loc, unpackOp.getDest(), offsets, + sizes, destStrides); + } else { + sliceDest = b.create(loc, destExpandedSizes, + unpackOp.getDestType().getElementType()); + } + + Operation *tiledUnpackOp = + b.create(loc, TypeRange{sliceDest.getType()}, + ValueRange{sliceSource, sliceDest}, op->getAttrs()); + + if (isPerfectTilingCase) + return {tiledUnpackOp}; + + Operation *extractSlice = + b.create(loc, tiledUnpackOp->getResult(0), + resultOffsetsFromDest, sizes, destStrides); + return {tiledUnpackOp, extractSlice}; + } + + LogicalResult + getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) const { + resultOffsets = llvm::to_vector(offsets); + resultSizes = llvm::to_vector(sizes); + return success(); + } +}; + } // namespace 
Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, @@ -425,5 +624,6 @@ registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { tensor::PadOp::attachInterface(*ctx); tensor::PackOp::attachInterface(*ctx); + tensor::UnPackOp::attachInterface(*ctx); }); } diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir --- a/mlir/test/Dialect/Tensor/tiling.mlir +++ b/mlir/test/Dialect/Tensor/tiling.mlir @@ -212,3 +212,181 @@ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] } + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 16)> +// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 16 - d0 floordiv 16 + 1)> +// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 16) * 16 - (d0 floordiv 16) * 16 + 16)> +// CHECK: func.func @NCnc_to_NC +// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %{{.+}} = scf.for %[[I:.+]] = %[[C0]] to %[[C256]] step %[[C2]] +// CHECK: %{{.+}} = scf.for %[[J:.+]] = %[[C0]] to %[[C128]] step %[[C4]] +// CHECK-DAG: %[[IN_I:.+]] = affine.apply #[[MAP0]](%[[I]]) +// CHECK-DAG: %[[OFFSET_I:.+]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[IN_I_SZ:.+]] = affine.apply 
#[[MAP2]](%[[I]]) +// CHECK-DAG: %[[IN_J:.+]] = affine.apply #[[MAP4]](%[[J]]) +// CHECK-DAG: %[[OFFSET_J:.+]] = affine.apply #[[MAP5]](%[[J]]) +// CHECK-DAG: %[[IN_J_SZ:.+]] = affine.apply #[[MAP6]](%[[J]]) +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN]] +// CHECK-SAME: [%[[IN_I]], %[[IN_J]], 0, 0] [%[[IN_I_SZ]], %[[IN_J_SZ]], 32, 16] +// CHECK-SAME: : tensor<8x8x32x16xf32> to tensor +// CHECK: %[[EMPTY:.+]] = tensor.empty +// CHECK: %[[UNPACK:.+]] = tensor.unpack +// CHECK-SAME: %[[SLICE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] +// CHECK-SAME: into %[[EMPTY]] +// CHECK: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]] +// CHECK-SAME: [%[[OFFSET_I]], %[[OFFSET_J]]] [2, 4] +// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK_SLICE]] +// CHECK-SAME: into %{{.+}}[%[[I]], %[[J]]] [2, 4] +// CHECK: scf.yield %[[RES]] +func.func @NCnc_to_NC(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> { + %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32> + return %0 : tensor<256x128xf32> +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 8)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 8)> +// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 8 - d0 floordiv 8 + 1)> +// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 8) * 8 - (d0 floordiv 8) * 8 
+ 8)> +// CHECK: func.func @CKkc_to_KC +// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C128]] step %[[C2]] +// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[C256]] step %[[C4]] +// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) +// CHECK-DAG: %[[OFFSET_K:.+]] = affine.apply #[[MAP1]](%[[K]]) +// CHECK-DAG: %[[IN_K_SZ:.+]] = affine.apply #[[MAP2]](%[[K]]) +// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP4]](%[[C]]) +// CHECK-DAG: %[[OFFSET_C:.+]] = affine.apply #[[MAP5]](%[[C]]) +// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.apply #[[MAP6]](%[[C]]) +// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]] +// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [%[[IN_C_SZ]], %[[IN_K_SZ]], 32, 8] +// CHECK: %[[EMPTY:.+]] = tensor.empty +// CHECK: %[[UNPACK:.+]] = tensor.unpack +// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] +// CHECK-SAME: into %[[EMPTY]] +// CHECK: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]] +// CHECK-SAME: [%[[OFFSET_K]], %[[OFFSET_C]]] [2, 4] +// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK_SLICE]] +// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [2, 4] +// CHECK: scf.yield %[[RES]] +func.func @CKkc_to_KC(%source: tensor<32x4x32x8xf32>, %dest: tensor<128x256xf32>) -> tensor<128x256xf32> { + %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %dest : tensor<32x4x32x8xf32> -> tensor<128x256xf32> + return %0 : tensor<128x256xf32> +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 + %1, %loops:2 = 
transform.structured.tile_to_scf_for %0 [2, 4] +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 4)> +// CHECK: func.func @perfect_CKkc_to_KC +// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C8]] step %[[C2]] +// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[C128]] step %[[C4]] +// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) +// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]]) +// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]] +// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 2, 4] +// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [2, 4] +// TODO: Add FoldTensorCastOp patterns for unpack op, then we do not need +// tensor.cast here. 
+// CHECK: %[[ITER_CAST:.+]] = tensor.cast %[[ITER_SLICE]] +// CHECK: %[[UNPACK:.+]] = tensor.unpack +// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4] +// CHECK-SAME: into %[[ITER_CAST]] +// CHECK: %[[UNPACK_CAST:.+]] = tensor.cast %[[UNPACK]] +// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK_CAST]] +// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [2, 4] +// CHECK: scf.yield %[[RES]] +func.func @perfect_CKkc_to_KC(%source: tensor<32x4x2x4xf32>, %dest: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %dest : tensor<32x4x2x4xf32> -> tensor<8x128xf32> + return %0 : tensor<8x128xf32> +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 2)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 ceildiv 2)> +// CHECK: func.func @dynamic_perfect_CKkc_to_KC +// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[OUT]], %[[C0]] +// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[OUT]], %[[C1]] +// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[DIM_0]] step %[[C2]] +// CHECK-DAG: %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]] +// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[DIM_1]] step %[[C4]] +// CHECK-DAG: %[[OUT_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])[%[[DIM_1]]] +// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP2]](%[[K]]) +// 
CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]]) +// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.apply #[[MAP3]](%[[OUT_C_SZ]]) +// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]] +// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [%[[IN_C_SZ]], 1, 2, 2] +// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]] +// CHECK: %[[UNPACK:.+]] = tensor.unpack +// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] +// CHECK-SAME: into %[[ITER_SLICE]] +// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]] +// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]] +// CHECK: scf.yield %[[RES]] +func.func @dynamic_perfect_CKkc_to_KC(%source: tensor, %dest: tensor) -> tensor { + %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %dest : tensor -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] +}