diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -75,6 +75,15 @@
   inVec = auxVec;
 }
 
+template <typename T>
+SmallVector<T> undoPermutationToVector(ArrayRef<T> inVec,
+                                       ArrayRef<int64_t> permutation) {
+  SmallVector<T> vec = llvm::to_vector(inVec);
+  for (auto en : llvm::enumerate(permutation))
+    vec[en.value()] = inVec[en.index()];
+  return vec;
+}
+
 /// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 using namespace mlir;
@@ -68,6 +70,190 @@
   }
 };
 
+/// Returns the permutation based on `dimsPos`.
+SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
+                                                  int64_t inputRank) {
+  SmallVector<int64_t> interchangeVector;
+  interchangeVector.reserve(dimsPos.size());
+  // First map dims and their position. For example, dims_pos = [2, 0] will
+  // map to:
+  //   [
+  //    [ key: 2, value: 0]
+  //    [ key: 0, value: 1]
+  //   ]
+  // where the key is the dimension listed in dims_pos and the value is its
+  // position within dims_pos.
+  DenseMap<int64_t, int64_t> dimsAndPosMapping;
+  for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++)
+    dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
+
+  // Scan the dimensions in order and, for each dimension present in the map,
+  // append its position to compute the interchange vector.
+  for (int64_t dimsIdx = 0; dimsIdx < inputRank; dimsIdx++)
+    if (dimsAndPosMapping.count(dimsIdx))
+      interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
+
+  return interchangeVector;
+}
+
+struct PackOpTiling
+    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
+
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    // Note that here we only consider untiled dimensions and outer tiled data
+    // dimensions; the inner tiled data dimensions are materialized when
+    // building the body of the operation.
+    auto packOp = cast<PackOp>(op);
+    SmallVector<utils::IteratorType> iteratorTypes(
+        packOp.getSourceRank(), utils::IteratorType::parallel);
+    return iteratorTypes;
+  }
+
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    OpBuilder::InsertionGuard guard(b);
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+    int64_t rank = packOp.getSourceRank();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    ReifiedRankedShapedTypeDims resultShape;
+    (void)packOp.reifyResultShapes(b, resultShape);
+    SmallVector<Range> loopRanges(rank);
+    for (auto dim : llvm::seq<int64_t>(0, rank)) {
+      loopRanges[dim].offset = zero;
+      loopRanges[dim].stride = one;
+      loopRanges[dim].size = resultShape[0][dim];
+    }
+    return loopRanges;
+  }
+
+  SmallVector<Operation *>
+  getTiledImplementation(Operation *op, OpBuilder &b,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes) const {
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+    auto ctx = b.getContext();
+
+    // Take the minimum of two integers.
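+    // (Illustrative note: once composed with the tile-size arithmetic below,
+    // these folded min/sub helpers produce the clamped sizes checked in the
+    // tests, e.g. affine_map<(d0) -> (d0 * -32 + 128, 64)> for a 128-wide dim
+    // packed with inner tile 32 and tiled with outer size 2.)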
+    auto idMap = AffineMap::getMultiDimIdentityMap(2, ctx);
+    auto min = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
+      return makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
+    };
+    // Subtract two integers.
+    AffineExpr dim0, dim1;
+    bindDims(ctx, dim0, dim1);
+    auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+    auto sub = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
+      return makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
+    };
+
+    // The tiling is applied on interchanged dimensions. We have to undo the
+    // interchange to map sizes and offsets to the original input.
+    int64_t inputRank = packOp.getSourceRank();
+    SmallVector<int64_t> dimsToOuterBlock(packOp.getOuterDimsPerm());
+    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
+    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
+    if (!dimsToOuterBlock.empty()) {
+      SmallVector<int64_t> vec =
+          computeInterchangeFromDimPos(dimsToOuterBlock, inputRank);
+      origOffsets = undoPermutationToVector<OpFoldResult>(origOffsets, vec);
+      origSizes = undoPermutationToVector<OpFoldResult>(origSizes, vec);
+    }
+
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        packOp.getDimAndTileMapping();
+    SmallVector<OpFoldResult> srcDimValues =
+        tensor::createDimValues(b, loc, packOp.getSource());
+    SmallVector<OpFoldResult> inputIndices, inputSizes;
+    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+      if (dimAndTileMapping.count(dim)) {
+        // If the data dimension is tiled, the i-th index is the product of
+        // offset_i and tile_i, and the i-th size is the product of sizes_i
+        // and tile_i.
+        AffineExpr i, tile;
+        bindDims(ctx, i);
+        bindSymbols(ctx, tile);
+        OpFoldResult inputIndex = makeComposedFoldedAffineApply(
+            b, loc, i * tile,
+            ArrayRef<OpFoldResult>{origOffsets[dim], dimAndTileMapping[dim]});
+        inputIndices.push_back(inputIndex);
+
+        OpFoldResult inputSize = makeComposedFoldedAffineApply(
+            b, loc, i * tile,
+            ArrayRef<OpFoldResult>{origSizes[dim], dimAndTileMapping[dim]});
+        inputSizes.push_back(inputSize);
+      } else {
+        inputIndices.push_back(origOffsets[dim]);
+        inputSizes.push_back(origSizes[dim]);
+      }
+
+      // Limit the size of the input operand for incomplete tiles.
+      OpFoldResult dimSize = srcDimValues[dim];
+      inputSizes.back() =
+          min(inputSizes.back(), sub(dimSize, inputIndices.back()));
+    }
+
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+
+    SmallVector<Value> tiledOperands;
+    tiledOperands.push_back(b.create<tensor::ExtractSliceOp>(
+        loc, packOp.getSource(), inputIndices, inputSizes, strides));
+
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
+                                     outputSizes))) {
+      return {};
+    }
+    strides.append(packOp.getDestRank() - inputRank, oneAttr);
+    auto extractSlice = b.create<tensor::ExtractSliceOp>(
+        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+    tiledOperands.push_back(extractSlice);
+
+    if (auto val = packOp.getPaddingValue())
+      tiledOperands.push_back(val);
+    for (auto tile : packOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
+    Operation *tiledPackOp = b.create<PackOp>(
+        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+
+    return {tiledPackOp};
+  }
+
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    // The tiling is applied on the outer dimensions, so the outer dimensions
+    // of the result tile position match the given offsets and sizes. The
+    // inner offsets are all zeros because tiling is not applied to the inner
+    // data tile dimensions.
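+    // For instance, in the NC_to_NCnc test below, tiling the 4x8x32x32
+    // packed result with sizes [2, 4] at offsets [%n, %c] gives
+    // resultOffsets = [%n, %c, 0, 0] and resultSizes = [2, 4, 32, 32].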
+    auto packOp = cast<PackOp>(op);
+    int64_t inputRank = packOp.getSourceRank();
+    int64_t outputRank = packOp.getDestRank();
+    auto zeroAttr = b.getI64IntegerAttr(0);
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultOffsets.append(outputRank - inputRank, zeroAttr);
+
+    ReifiedRankedShapedTypeDims outputShape;
+    if (failed(packOp.reifyResultShapes(b, outputShape)))
+      return op->emitOpError("failed to reify result shape");
+    if (outputShape.size() != 1 ||
+        static_cast<int64_t>(outputShape[0].size()) != packOp.getDestRank()) {
+      return op->emitOpError("expected shape of one result value of rank")
+             << outputRank;
+    }
+
+    resultSizes.assign(sizes.begin(), sizes.end());
+    for (auto dataTileDim : llvm::seq<int64_t>(inputRank, outputRank))
+      resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
+
+    return success();
+  }
+};
+
 } // namespace
 
 Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
@@ -282,5 +468,6 @@
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
   });
 }
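For reference, a minimal standalone sketch of the interchange round trip implemented above, using STL containers in place of the LLVM ADTs (the helper names interchangeFromDimPos and undoPermutation and the sample values are illustrative, not part of the patch):

#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Same algorithm as computeInterchangeFromDimPos above: map each dim listed
// in dimsPos to its position, then scan the dims in order.
std::vector<int64_t> interchangeFromDimPos(const std::vector<int64_t> &dimsPos,
                                           int64_t inputRank) {
  std::map<int64_t, int64_t> posOf;
  for (int64_t i = 0, e = dimsPos.size(); i < e; ++i)
    posOf[dimsPos[i]] = i;
  std::vector<int64_t> interchange;
  for (int64_t d = 0; d < inputRank; ++d)
    if (posOf.count(d))
      interchange.push_back(posOf[d]);
  return interchange;
}

// Same algorithm as undoPermutationToVector above: element i moves back to
// location permutation[i].
std::vector<int64_t> undoPermutation(const std::vector<int64_t> &in,
                                     const std::vector<int64_t> &perm) {
  std::vector<int64_t> out = in;
  for (int64_t i = 0, e = perm.size(); i < e; ++i)
    out[perm[i]] = in[i];
  return out;
}

int main() {
  // outer_dims_perm = [1, 0], as in the KC_to_CKkc test below.
  std::vector<int64_t> perm = interchangeFromDimPos({1, 0}, 2);
  assert((perm == std::vector<int64_t>{1, 0}));
  // Offsets handed to getTiledImplementation are in interchanged (dest)
  // order; undoing the permutation maps them back to source order.
  std::vector<int64_t> destOffsets = {/*c=*/8, /*k=*/0};
  std::vector<int64_t> srcOffsets = undoPermutation(destOffsets, perm);
  assert((srcOffsets == std::vector<int64_t>{/*k=*/0, /*c=*/8}));
  return 0;
}

With outer_dims_perm = [1, 0], loop offsets arrive in the interchanged dest order and are mapped back to source order before the input slice indices are computed.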
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -0,0 +1,214 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 64)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -32 + 256, 128)>
+// CHECK: func.func @NC_to_NCnc
+// CHECK-SAME: %[[IN:.*]]: tensor<128x256xf32>,
+// CHECK-SAME: %[[OUT:.*]]: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[RES0:.*]] = scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<4x8x32x32xf32>) {
+// CHECK: %[[RES1:.+]] = scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<4x8x32x32xf32>) {
+// CHECK-DAG: %[[IN_N:.+]] = affine.apply #[[MAP0]](%[[N]])
+// CHECK-DAG: %[[IN_N_SZ:.*]] = affine.min #[[MAP1]]
+// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
+// CHECK-DAG: %[[IN_C_SZ:.*]] = affine.min #[[MAP2]]
+// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor<?x?xf32>
+// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [2, 4, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x4x32x32xf32>
+// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
+// CHECK: %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[CAST_OUT]]
+// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK: scf.yield %[[INSERT]] : tensor<4x8x32x32xf32>
+// CHECK: }
+// CHECK: scf.yield %[[RES1:.*]] : tensor<4x8x32x32xf32>
+// CHECK: }
+// CHECK: return %[[RES0:.*]] : tensor<4x8x32x32xf32>
+// CHECK: }
+func.func @NC_to_NCnc(%arg0: tensor<128x256xf32>, %arg1: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
+  return %0 : tensor<4x8x32x32xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 16)>
+// CHECK: func.func @KC_to_CKkc
+// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
+// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C2]]
+// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
+// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])
+// CHECK: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK-SAME: [0, %[[IN_C]]] [128, %[[IN_C_SZ]]]
+// CHECK: %[[CAST_IN:.+]] = tensor.cast %[[INPUT_SLICE]]
+// CHECK: %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], 0, 0, 0] [2, 4, 32, 8]
+// CHECK: %[[CAST_OUT:.+]] = tensor.cast %[[OUTPUT_SLICE]]
+// CHECK: tensor.pack
+// CHECK-SAME: %[[CAST_IN]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+// CHECK-SAME: into %[[CAST_OUT]]
+func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
+  return %0 : tensor<32x4x32x8xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
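+// Padding semantics: 13x15 does not divide evenly into 8x2 inner tiles
+// (ceil(13/8) = 2 and ceil(15/2) = 8 give the 2x8 outer dims), so the
+// trailing incomplete tiles are filled with the padding value.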
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -2 + 15, 8)>
+// CHECK: func.func @pad_and_pack_static(
+// CHECK-SAME: %[[IN:.*]]: tensor<13x15xf32>,
+// CHECK-SAME: %[[OUT:.*]]: tensor<2x8x8x2xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[RES0:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[OUT]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP0]](%[[J]])
+// CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])
+// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][0, %[[IN_J]]] [13, %[[IN_J_SZ]]] [1, 1]
+// CHECK: %[[CAST_IN:.*]] = tensor.cast %[[SUB_IN]]
+// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][0, %[[J]], 0, 0] [2, 4, 8, 2] [1, 1, 1, 1]
+// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
+// CHECK: %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME: %[[CAST_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME: into %[[CAST_OUT]]
+// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK: scf.yield %[[INSERT]] : tensor<2x8x8x2xf32>
+// CHECK: }
+// CHECK: return %[[RES0:.*]] : tensor<2x8x8x2xf32>
+// CHECK: }
+func.func @pad_and_pack_static(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: f32) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -8 + s0, d0 * 8)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -2 + s0, d0 * 2)>
+// CHECK: func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUT:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32) -> tensor<?x?x8x2xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x8x2xf32>
+// CHECK-DAG: %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x8x2xf32>
+// CHECK: %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG: %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK: %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG: %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK-DAG: %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK-DAG: %[[IN_I_SZ:.*]] = affine.min #[[MAP3]]
+// CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP4]](%[[J]])
+// CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP5]]
+// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], 8, 2] [1, 1, 1, 1] : tensor<?x?x8x2xf32> to tensor<?x?x8x2xf32>
+// CHECK: %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME: into %[[SUB_OUT]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]]
+// CHECK: scf.yield %[[INSERT]] : tensor<?x?x8x2xf32>
+// CHECK: }
+// CHECK: scf.yield %[[RES1:.*]] : tensor<?x?x8x2xf32>
+// CHECK: }
+// CHECK: return %[[VAL_34:.*]] : tensor<?x?x8x2xf32>
+// CHECK: }
+func.func @pad_and_pack_partially_dynamic(%input: tensor<?x?xf32>, %output: tensor<?x?x8x2xf32>, %pad: f32) -> tensor<?x?x8x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %0 : tensor<?x?x8x2xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
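+// With runtime inner tiles the arithmetic stays symbolic: the input offset
+// folds to i * tile (MAP2 below) and the clamped input size to
+// min(out_sz * tile, in_dim - i * tile) (MAP3 below).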
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0, -(d1 * s0) + s1)>
+// CHECK: func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUT:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[PAD:.*]]: f32,
+// CHECK-SAME: %[[TILE_0:.*]]: index,
+// CHECK-SAME: %[[TILE_1:.*]]: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x?x?xf32>) {
+// CHECK: %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK: %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x?x?xf32>) {
+// CHECK: %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK: %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]]
+// CHECK: %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]]
+// CHECK: %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])[%[[TILE_0]]]
+// CHECK: %[[IN_I_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_I_SZ]], %[[I]])[%[[TILE_0]], %[[IN_D0]]]
+// CHECK: %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]]
+// CHECK: %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_J_SZ]], %[[J]])[%[[TILE_1]], %[[IN_D1]]]
+// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[OUT_D2:.+]] = tensor.dim %[[OUT]], %[[C2]]
+// CHECK: %[[OUT_D3:.+]] = tensor.dim %[[OUT]], %[[C3]]
+// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], %[[OUT_D2]], %[[OUT_D3]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: %[[PACK:.*]] = tensor.pack
+// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_0]], %[[TILE_1]]]
+// CHECK-SAME: into %[[SUB_OUT]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[PACK]] into %[[ITER1]]
+// CHECK: scf.yield %[[INSERT]] : tensor<?x?x?x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[RES1:.*]] : tensor<?x?x?x?xf32>
+// CHECK: }
+// CHECK: return %[[RES0:.*]] : tensor<?x?x?x?xf32>
+// CHECK: }
+func.func @pad_and_pack_fully_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x?x?xf32>, %pad: f32, %tile_n : index, %tile_m : index) -> tensor<?x?x?x?xf32> {
+  %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
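+
+// Note: each transform.structured.tile_to_scf_for above uses sizes [2, 4],
+// tiling the first two iteration dimensions (the outer dims of the packed
+// layout) into two scf.for loops, returned in %loops:2; the inner tile
+// dimensions are materialized by the tiled tensor.pack itself.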