diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -69,12 +69,23 @@ template void applyPermutationToVector(SmallVector &inVec, ArrayRef permutation) { + assert(inVec.size() == permutation.size()); SmallVector auxVec(inVec.size()); for (const auto &en : enumerate(permutation)) auxVec[en.index()] = inVec[en.value()]; inVec = auxVec; } +template +void undoPermutationToVector(SmallVector &inVec, + ArrayRef permutation) { + assert(inVec.size() == permutation.size()); + SmallVector auxVec = llvm::to_vector(inVec); + for (const auto &en : llvm::enumerate(permutation)) + auxVec[en.value()] = inVec[en.index()]; + inVec = auxVec; +} + /// Helper that returns a subset of `arrayAttr` as a vector of int64_t. SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, unsigned dropBack = 0); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Interfaces/TilingInterface.h" using namespace mlir; @@ -68,6 +70,194 @@ } }; +/// Returns the permutation based on `dimsPos` and `rank`. The `dimsPos` is a +/// projected permutation which contains a list of positions. The method returns +/// the reverse map of the projected permutation. +/// E.g., dimsPos = [2, 0] and rank = 3, the method returns [1, 0]. 
+static SmallVector +computeInversePermutationForDimsPos(ArrayRef dimsPos, int64_t rank) { + SmallVector inVec; + inVec.reserve(dimsPos.size()); + // First map each dim to its position in `dimsPos`. For example, dimsPos = [2, 0] + // maps to: + // [ + // [ key: 2, value: 0] + // [ key: 0, value: 1] + // ] + // where key is an entry of `dimsPos` and value is its position in `dimsPos`. + DenseMap dimsAndPosMapping; + for (auto dimsIdx : llvm::seq(0, dimsPos.size())) + dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx; + + // Scan the dims in [0, rank) in order and append the mapped position of each + // dim that is present in the map to build the interchange vector. + for (auto dimsIdx : llvm::seq(0, rank)) + if (dimsAndPosMapping.count(dimsIdx)) + inVec.push_back(dimsAndPosMapping[dimsIdx]); + + return inVec; +} + +/// Builds an `affine_min` of `v1` and `v2`. +static OpFoldResult buildMin(OpBuilder &b, Location loc, OpFoldResult v1, + OpFoldResult v2) { + return makeComposedFoldedAffineMin( + b, loc, AffineMap::getMultiDimIdentityMap(2, loc.getContext()), {v1, v2}); +} + +/// Builds an `affine_apply` which subtracts `v2` from `v1`, i.e., "v1 - v2". +static OpFoldResult buildSub(OpBuilder &b, Location loc, OpFoldResult v1, + OpFoldResult v2) { + AffineExpr dim0, dim1; + bindDims(loc.getContext(), dim0, dim1); + auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); + return makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2}); +} + +/// Builds an `affine_apply` which takes the product of `v1` and `v2`. Note that +/// `v1` is bound to a dim expr and `v2` is bound to a symbol expr. 
+static OpFoldResult buildMul(OpBuilder &b, Location loc, OpFoldResult v1, + OpFoldResult v2) { + MLIRContext *ctx = loc.getContext(); + AffineExpr i, tile; + bindDims(ctx, i); + bindSymbols(ctx, tile); + return makeComposedFoldedAffineApply(b, loc, i * tile, + ArrayRef{v1, v2}); +} + +struct PackOpTiling + : public TilingInterface::ExternalModel { + + SmallVector getLoopIteratorTypes(Operation *op) const { + // Note that here we only consider untiled dimensions and outer tiled data + // dimensions, the inner tiled data dimensions are materialized when + // building the body of the operation. + auto packOp = cast(op); + SmallVector iteratorTypes( + packOp.getSourceRank(), utils::IteratorType::parallel); + return iteratorTypes; + } + + SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { + OpBuilder::InsertionGuard guard(b); + auto packOp = cast(op); + Location loc = packOp.getLoc(); + int64_t rank = packOp.getSourceRank(); + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + ReifiedRankedShapedTypeDims resultShape; + (void)packOp.reifyResultShapes(b, resultShape); + SmallVector loopRanges(rank); + for (auto dim : llvm::seq(0, rank)) { + loopRanges[dim].offset = zero; + loopRanges[dim].stride = one; + loopRanges[dim].size = resultShape[0][dim]; + } + return loopRanges; + } + + SmallVector + getTiledImplementation(Operation *op, OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes) const { + auto packOp = cast(op); + Location loc = packOp.getLoc(); + + // The tiling is applied on interchanged dimensions: iteration dim i addresses source dim outer_dims_perm[i]. We have to undo the + // interchange to map sizes and offsets to the original input, i.e. origOffsets[d] = offsets[inversePerm[d]] (apply the inverse permutation). 
+ int64_t inputRank = packOp.getSourceRank(); + ArrayRef dimsToOuterBlock(packOp.getOuterDimsPerm()); + SmallVector origOffsets(offsets.begin(), offsets.end()); + SmallVector origSizes(sizes.begin(), sizes.end()); + if (!dimsToOuterBlock.empty()) { + SmallVector vec = + computeInversePermutationForDimsPos(dimsToOuterBlock, inputRank); + applyPermutationToVector(origOffsets, vec); + applyPermutationToVector(origSizes, vec); + } + + DenseMap dimAndTileMapping = + packOp.getDimAndTileMapping(); + SmallVector srcDimValues = + tensor::createDimValues(b, loc, packOp.getSource()); + SmallVector inputIndices, inputSizes; + for (auto dim : llvm::seq(0, inputRank)) { + if (dimAndTileMapping.count(dim)) { + // If the data dimension is tiled, the i-th index is the product of + // offset_i and tile_i, and the i-th size is the product of sizes_i and + // tile_i. + inputIndices.push_back( + buildMul(b, loc, origOffsets[dim], dimAndTileMapping[dim])); + inputSizes.push_back( + buildMul(b, loc, origSizes[dim], dimAndTileMapping[dim])); + } else { + inputIndices.push_back(origOffsets[dim]); + inputSizes.push_back(origSizes[dim]); + } + + // Limit the size of the input operand for incomplete tiles. 
+ OpFoldResult dimSize = srcDimValues[dim]; + inputSizes.back() = + buildMin(b, loc, inputSizes.back(), + buildSub(b, loc, dimSize, inputIndices.back())); + } + + auto oneAttr = b.getI64IntegerAttr(1); + SmallVector strides(inputRank, oneAttr); + + SmallVector tiledOperands; + tiledOperands.push_back(b.create( + loc, packOp.getSource(), inputIndices, inputSizes, strides)); + + SmallVector outputOffsets, outputSizes; + if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets, + outputSizes))) + return {}; + + strides.append(packOp.getDestRank() - inputRank, oneAttr); + auto extractSlice = b.create( + loc, packOp.getDest(), outputOffsets, outputSizes, strides); + tiledOperands.push_back(extractSlice); + + if (auto val = packOp.getPaddingValue()) + tiledOperands.push_back(val); + for (auto tile : packOp.getInnerTiles()) + tiledOperands.push_back(tile); + + Operation *tiledPackOp = b.create( + loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); + + return {tiledPackOp}; + } + + LogicalResult + getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) const { + // The iteration domain is over outer dimensions of packed layout. In this + // context, the outer dimensions of `resultOffsets` are `offsets`. The + // inner dimensions of `resultOffsets` are zeros because tiling is not + // applied to them. 
+ auto packOp = cast(op); + int64_t inputRank = packOp.getSourceRank(); + int64_t outputRank = packOp.getDestRank(); + auto zeroAttr = b.getI64IntegerAttr(0); + resultOffsets.assign(offsets.begin(), offsets.end()); + resultOffsets.append(outputRank - inputRank, zeroAttr); + + ReifiedRankedShapedTypeDims outputShape; + (void)packOp.reifyResultShapes(b, outputShape); + resultSizes.assign(sizes.begin(), sizes.end()); + for (auto dataTileDim : llvm::seq(inputRank, outputRank)) + resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim])); + + return success(); + } +}; + } // namespace Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, @@ -282,5 +472,6 @@ DialectRegistry &registry) { registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { tensor::PadOp::attachInterface(*ctx); + tensor::PackOp::attachInterface(*ctx); }); } diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/tiling.mlir @@ -0,0 +1,214 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 64)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -32 + 256, 128)> +// CHECK: func.func @NC_to_NCnc +// CHECK-SAME: %[[IN:.*]]: tensor<128x256xf32>, +// CHECK-SAME: %[[OUT:.*]]: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[RES0:.*]] = scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<4x8x32x32xf32>) { +// CHECK: %[[RES1:.+]] = scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = 
%[[ITER0]]) -> (tensor<4x8x32x32xf32>) { +// CHECK-DAG: %[[IN_N:.+]] = affine.apply #[[MAP0]](%[[N]]) +// CHECK-DAG: %[[IN_N_SZ:.*]] = affine.min #[[MAP1]] +// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]]) +// CHECK-DAG: %[[IN_C_SZ:.*]] = affine.min #[[MAP2]] +// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor +// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [2, 4, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x4x32x32xf32> +// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]] +// CHECK: %[[SUB_RES:.*]] = tensor.pack +// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[CAST_OUT]] +// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]] +// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]] +// CHECK: scf.yield %[[INSERT]] : tensor<4x8x32x32xf32> +// CHECK: } +// CHECK: scf.yield %[[RES1:.*]] : tensor<4x8x32x32xf32> +// CHECK: } +// CHECK: return %[[RES0:.*]] : tensor<4x8x32x32xf32> +// CHECK: } +func.func @NC_to_NCnc(%arg0: tensor<128x256xf32>, %arg1: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32> + return %0 : tensor<4x8x32x32xf32> +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] +} + +// ----- + +// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 8)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 16)> +// CHECK: func.func @KC_to_CKkc +// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index 
+// CHECK: scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C2]] +// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]]) +// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]]) +// CHECK: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]] +// CHECK-SAME: [0, %[[IN_C]]] [128, %[[IN_C_SZ]]] +// CHECK: %[[CAST_IN:.+]] = tensor.cast %[[INPUT_SLICE]] +// CHECK: %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], 0, 0, 0] [2, 4, 32, 8] +// CHECK: %[[CAST_OUT:.+]] = tensor.cast %[[OUTPUT_SLICE]] +// CHECK: tensor.pack +// CHECK-SAME: %[[CAST_IN]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] +// CHECK-SAME: into %[[CAST_OUT]] +func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> { + %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32> + return %0 : tensor<32x4x32x8xf32> +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -2 + 15, 8)> +// CHECK: func.func @pad_and_pack_static( +// CHECK-SAME: %[[IN:.*]]: tensor<13x15xf32>, +// CHECK-SAME: %[[OUT:.*]]: tensor<2x8x8x2xf32>, +// CHECK-SAME: %[[PAD:.*]]: f32) -> tensor<2x8x8x2xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[RES0:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[OUT]]) -> (tensor<2x8x8x2xf32>) { +// CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP0]](%[[J]]) +// CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]]) +// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice 
%[[IN]][0, %[[IN_J]]] [13, %[[IN_J_SZ]]] [1, 1] +// CHECK: %[[CAST_IN:.*]] = tensor.cast %[[SUB_IN]] +// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][0, %[[J]], 0, 0] [2, 4, 8, 2] [1, 1, 1, 1] +// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]] +// CHECK: %[[SUB_RES:.*]] = tensor.pack +// CHECK-SAME: %[[CAST_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] +// CHECK-SAME: into %[[CAST_OUT]] +// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]] +// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]] +// CHECK: scf.yield %[[INSERT]] : tensor<2x8x8x2xf32> +// CHECK: } +// CHECK: return %[[RES0:.*]] : tensor<2x8x8x2xf32> +// CHECK: } +func.func @pad_and_pack_static(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: f32) -> tensor<2x8x8x2xf32> { + %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] +} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -8 + s0, d0 * 8)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -2 + s0, d0 * 2)> +// CHECK: func.func @pad_and_pack_partially_dynamic( +// CHECK-SAME: %[[IN:.*]]: tensor, +// CHECK-SAME: %[[OUT:.*]]: tensor, +// CHECK-SAME: %[[PAD:.*]]: f32) -> tensor { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// 
CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor +// CHECK-DAG: %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor +// CHECK: %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor) { +// CHECK-DAG: %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]] +// CHECK: %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor) { +// CHECK-DAG: %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]] +// CHECK-DAG: %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]]) +// CHECK-DAG: %[[IN_I_SZ:.*]] = affine.min #[[MAP3]] +// CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP4]](%[[J]]) +// CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP5]] +// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor to tensor +// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], 8, 2] [1, 1, 1, 1] : tensor to tensor +// CHECK: %[[SUB_RES:.*]] = tensor.pack +// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] +// CHECK-SAME: into %[[SUB_OUT]] +// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]] +// CHECK: scf.yield %[[INSERT]] : tensor +// CHECK: } +// CHECK: scf.yield %[[RES1:.*]] : tensor +// CHECK: } +// CHECK: return %[[VAL_34:.*]] : tensor +// CHECK: } +func.func @pad_and_pack_partially_dynamic(%input: tensor, %output: tensor, %pad: f32) -> tensor { + %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] 
+} + +// ----- + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * s0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0, -(d1 * s0) + s1)> +// CHECK: func.func @pad_and_pack_fully_dynamic( +// CHECK-SAME: %[[IN:.*]]: tensor, +// CHECK-SAME: %[[OUT:.*]]: tensor, +// CHECK-SAME: %[[PAD:.*]]: f32, +// CHECK-SAME: %[[TILE_0:.*]]: index, +// CHECK-SAME: %[[TILE_1:.*]]: index) -> tensor { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor +// CHECK-DAG: %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor +// CHECK: %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor) { +// CHECK: %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]] +// CHECK: %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor) { +// CHECK: %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]] +// CHECK: %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]] +// CHECK: %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]] +// CHECK: %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])[%[[TILE_0]]] +// CHECK: %[[IN_I_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_I_SZ]], %[[I]])[%[[TILE_0]], %[[IN_D0]]] +// CHECK: %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]] +// CHECK: %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_J_SZ]], %[[J]])[%[[TILE_1]], %[[IN_D1]]] +// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor to tensor +// CHECK: %[[OUT_D2:.+]] = tensor.dim %[[OUT]], 
%[[C2]] +// CHECK: %[[OUT_D3:.+]] = tensor.dim %[[OUT]], %[[C3]] +// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], %[[OUT_D2]], %[[OUT_D3]]] [1, 1, 1, 1] : tensor to tensor +// CHECK: %[[PACK:.*]] = tensor.pack +// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_0]], %[[TILE_1]]] +// CHECK-SAME: into %[[SUB_OUT]] +// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[PACK]] into %[[ITER1]] +// CHECK: scf.yield %[[INSERT]] : tensor +// CHECK: } +// CHECK: scf.yield %[[RES1:.*]] : tensor +// CHECK: } +// CHECK: return %[[RES0:.*]] : tensor +// CHECK: } +func.func @pad_and_pack_fully_dynamic(%source: tensor, %dest: tensor, %pad: f32, %tile_n : index, %tile_m : index) -> tensor { + %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 + %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4] +}