diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -152,6 +152,18 @@
   }
 }
 
+// Insert a tile `source` into the destination tensor `dest`. The position at
+// which the tile is inserted (as well as the size of the tile) is taken from
+// a given ExtractSliceOp `sliceOp`.
+static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
+                                   tensor::ExtractSliceOp sliceOp, Value source,
+                                   Value dest) {
+  return b.create<tensor::InsertSliceOp>(
+      loc, sliceOp.source().getType(), source, dest, sliceOp.offsets(),
+      sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
+      sliceOp.static_sizes(), sliceOp.static_strides());
+}
+
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -259,11 +271,8 @@
     // `tiledOperands`.
     Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
     if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-      tensorResults.push_back(b.create<tensor::InsertSliceOp>(
-          loc, sliceOp.source().getType(), res->getResult(resultIdx),
-          sliceOp.source(), sliceOp.offsets(), sliceOp.sizes(),
-          sliceOp.strides(), sliceOp.static_offsets(), sliceOp.static_sizes(),
-          sliceOp.static_strides()));
+      tensorResults.push_back(insertSliceIntoTensor(
+          b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
     } else {
       tensorResults.push_back(res->getResult(resultIdx));
     }
@@ -341,6 +350,84 @@
   return llvm::None;
 }
 
+/// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
+/// is an output parameter and returns the new (tiled) PadTensorOp.
+static LogicalResult tilePadTensorOp(PatternRewriter &rewriter, PadTensorOp op,
+                                     PadTensorOp &newPadOp,
+                                     const LinalgTilingOptions &options) {
+  // Only PadTensorOps that have an output operand can be tiled.
+  if (!op.output())
+    return failure();
+
+  Location loc = op.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Clone the PadTensorOp so that the existing op can be replaced more easily.
+  newPadOp = cast<PadTensorOp>(rewriter.clone(*op.getOperation()));
+  // Get rank and tile sizes.
+  int64_t rank = op.getResultType().getRank();
+  SmallVector<Value> tileSizes =
+      options.tileSizeComputationFunction(rewriter, op);
+  assert(static_cast<int64_t>(tileSizes.size()) == rank);
+  // Compute lower and upper bounds of the loop nest.
+  SmallVector<Value> lbs, dims, steps;
+  for (int64_t i = 0; i < rank; ++i) {
+    if (!isZero(tileSizes[i])) {
+      lbs.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
+      dims.push_back(rewriter.create<tensor::DimOp>(loc, op.output(), i));
+      steps.push_back(tileSizes[i]);
+    }
+  }
+  // Generate the loop nest: one loop per tiled dimension.
+  LoopNest loopNest = mlir::scf::buildLoopNest(
+      rewriter, loc, lbs, /*ubs=*/dims, steps, ValueRange(op.output()),
+      [&](OpBuilder &b, Location loc, ValueRange localIvs,
+          ValueRange iterArgs) -> scf::ValueVector {
+        // Compute offsets and sizes of the ExtractSliceOp.
+        SmallVector<Value> offsets =
+            computeTileOffsets(b, loc, localIvs, tileSizes);
+        SmallVector<Value> sizes =
+            computeTileSizes(b, loc, localIvs, tileSizes, dims);
+        // Create ExtractSliceOp: Extract a tile from the PadTensorOp.
+        // Note: The PadTensorOp is located outside of the loop nest. It is
+        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
+        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
+        Value tiledOutput = makeTiledShape(b, loc, newPadOp->getResult(0),
+                                           tileSizes, map, offsets, sizes);
+        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
+        assert(sliceOp && "expected ExtractSliceOp");
+        // Insert the tile into the output tensor.
+        Value yieldValue =
+            insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
+        return scf::ValueVector({yieldValue});
+      });
+  // Replace all uses of the original PadTensorOp.
+  rewriter.replaceOp(op, loopNest.getResults()[0]);
+  return success();
+}
+
+namespace {
+struct PadTensorOpTilingPattern : public OpRewritePattern<PadTensorOp> {
+  PadTensorOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
+      : OpRewritePattern<PadTensorOp>(ctx), options(opt) {}
+
+  LogicalResult matchAndRewrite(PadTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
+      return failure();
+    PadTensorOp newPadOp;
+    if (failed(tilePadTensorOp(rewriter, op, newPadOp, options)))
+      return failure();
+    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
+                      rewriter.getUnitAttr());
+    return success();
+  }
+
+  LinalgTilingOptions options;
+};
+} // namespace
+
 namespace {
 /// Helper classes for type list expansion.
 template <typename... OpTypes>
@@ -408,6 +495,7 @@
   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
+  PadTensorOp::getCanonicalizationPatterns(patterns, ctx);
   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
   CanonicalizationPatternList<
 #define GET_OP_LIST
@@ -422,6 +510,8 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
       >::insert(patterns, options);
+  patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
+  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
 }
 
 static void
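For illustration, tilePadTensorOp on its own only builds the loop nest: the cloned pad stays outside, and each iteration round-trips one tile through an extract_slice/insert_slice pair. A minimal sketch of that intermediate IR for a 2-D pad with tile sizes 2 and 3, before ExtractSliceOfPadTensorSwapPattern moves the pad inside (all value names are illustrative; %sz0 and %sz1 stand for the bounded sizes produced by computeTileSizes):

  %padded = linalg.pad_tensor %in low[3, 4] high[5, 3] into %out {
  ^bb0(%i: index, %j: index):
    linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  %d0 = tensor.dim %out, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %out, %c1 : tensor<?x?xf32>
  %r = scf.for %iv0 = %c0 to %d0 step %c2 iter_args(%arg0 = %out) -> (tensor<?x?xf32>) {
    %r1 = scf.for %iv1 = %c0 to %d1 step %c3 iter_args(%arg1 = %arg0) -> (tensor<?x?xf32>) {
      // Extract a tile from the (not yet tiled) padded tensor ...
      %tile = tensor.extract_slice %padded[%iv0, %iv1] [%sz0, %sz1] [1, 1]
          : tensor<?x?xf32> to tensor<?x?xf32>
      // ... and insert it at the same offsets into the iteration argument.
      %acc = tensor.insert_slice %tile into %arg1[%iv0, %iv1] [%sz0, %sz1] [1, 1]
          : tensor<?x?xf32> into tensor<?x?xf32>
      scf.yield %acc : tensor<?x?xf32>
    }
    scf.yield %r1 : tensor<?x?xf32>
  }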
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
@@ -550,7 +551,7 @@
     if (!isTiled(map.getSubMap({r}), tileSizes)) {
       offsets.push_back(builder.getIndexAttr(0));
       Value dim = createOrFoldDimOp(builder, loc, valueToTile, r);
-      sizes.push_back(dim);
+      sizes.push_back(getAsOpFoldResult(dim));
       strides.push_back(builder.getIndexAttr(1));
       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
       continue;
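The Utils.cpp change feeds statically known sizes to the slice builder as attributes: getAsOpFoldResult(dim) returns the constant attribute when dim folds to a constant, and the SSA value otherwise. Since ExtractSliceOp infers a dynamic result dimension for every size passed as an SSA value, attribute sizes let the slice keep its static dimensions, which is why the tensor.cast checks disappear from the padded-matmul test below. A hypothetical sketch for an operand %t of type tensor<8x?xi8> (names illustrative):

  // Size passed as an SSA value: the result dimension must be dynamic.
  %dim = tensor.dim %t, %c0 : tensor<8x?xi8>   // folds to a constant 8
  %s0 = tensor.extract_slice %t[0, %j] [%dim, 3] [1, 1]
      : tensor<8x?xi8> to tensor<?x3xi8>

  // Size passed as a constant attribute: the static dimension survives.
  %s1 = tensor.extract_slice %t[0, %j] [8, 3] [1, 1]
      : tensor<8x?xi8> to tensor<8x3xi8>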
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -105,14 +105,12 @@
 // CHECK-1DIM-TILE:    %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
 // CHECK-1DIM-TILE:      %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
 // CHECK-1DIM-TILE:        %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
-// CHECK-1DIM-TILE:        %[[sTAc:.*]] = tensor.cast %[[sTA]] : tensor<?x8xi8> to tensor<?x?xi8>
 // CHECK-1DIM-TILE:        %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
-// CHECK-1DIM-TILE:        %[[sTBc:.*]] = tensor.cast %[[sTB]] : tensor<8x?xi8> to tensor<?x?xi8>
 // CHECK-1DIM-TILE:        %[[sTC:.*]] = tensor.extract_slice %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-// CHECK-1DIM-TILE:        %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK-1DIM-TILE:          : tensor<?x?xi8> to tensor<2x8xi8>
-// CHECK-1DIM-TILE:        %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK-1DIM-TILE:          : tensor<?x?xi8> to tensor<8x3xi8>
+// CHECK-1DIM-TILE:        %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:          : tensor<?x8xi8> to tensor<2x8xi8>
+// CHECK-1DIM-TILE:        %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:          : tensor<8x?xi8> to tensor<8x3xi8>
 // CHECK-1DIM-TILE:        %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 // CHECK-1DIM-TILE:          : tensor<?x?xi32> to tensor<2x3xi32>
 // CHECK-1DIM-TILE:        %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
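The new test file below exercises PadTensorOpTilingPattern together with ExtractSliceOfPadTensorSwapPattern. After the swap, each loop iteration either materializes a tile that lies entirely in the padding region or pads a slice of the input. Roughly, the inner loop body looks as follows (a sketch with illustrative names; the real offsets, sizes, and padding amounts come from affine computations):

  %tile = scf.if %tile_is_all_padding -> (tensor<?x?xf32>) {
    // The tile does not overlap the source: generate the padding value.
    %gen = tensor.generate %sz0, %sz1 {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad_value : f32
    } : tensor<?x?xf32>
    scf.yield %gen : tensor<?x?xf32>
  } else {
    // The tile overlaps the source: pad only the in-bounds piece.
    %slice = tensor.extract_slice %in[%o0, %o1] [%s0, %s1] [1, 1]
        : tensor<?x?xf32> to tensor<?x?xf32>
    %pad = linalg.pad_tensor %slice low[%l0, %l1] high[%h0, %h1] {
    ^bb0(%i: index, %j: index):
      linalg.yield %pad_value : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
    scf.yield %pad : tensor<?x?xf32>
  }
  %updated = tensor.insert_slice %tile into %iter[%iv0, %iv1] [%sz0, %sz1] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>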
diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \
+// RUN:   FileCheck %s -check-prefix=TILE2
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -cse -split-input-file | \
+// RUN:   FileCheck %s -check-prefix=TILE1
+
+// TILE2-LABEL: func @dynamic_pad_tensor(
+// TILE2-SAME:  %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+// TILE2-DAG:   %[[C0:.*]] = constant 0 : index
+// TILE2-DAG:   %[[C1:.*]] = constant 1 : index
+// TILE2-DAG:   %[[C2:.*]] = constant 2 : index
+// TILE2-DAG:   %[[C3:.*]] = constant 3 : index
+// TILE2:       %[[DIM0:.*]] = tensor.dim %[[OUT]], %[[C0]]
+// TILE2:       %[[DIM1:.*]] = tensor.dim %[[OUT]], %[[C1]]
+// TILE2:       %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM0]] step %[[C2]]
+// TILE2:         scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+// TILE2:           %[[SWAP_RESULT:.*]] = scf.if
+// TILE2:             tensor.generate
+// TILE2:           else
+// TILE2:             %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+// TILE2:             %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]]
+// TILE2:           tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+// TILE2:       return %[[RESULT]]
+
+// TILE1-LABEL: func @dynamic_pad_tensor(
+// TILE1-SAME:  %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+// TILE1-DAG:   %[[C0:.*]] = constant 0 : index
+// TILE1-DAG:   %[[C1:.*]] = constant 1 : index
+// TILE1-DAG:   %[[C3:.*]] = constant 3 : index
+// TILE1:       %[[DIM1:.*]] = tensor.dim %[[OUT]], %[[C1]]
+// TILE1:       %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+// TILE1:         %[[DIM0:.*]] = tensor.dim %[[OUT]], %[[C0]]
+// TILE1:         %[[SWAP_RESULT:.*]] = scf.if
+// TILE1:           tensor.generate
+// TILE1:         else
+// TILE1:           %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+// TILE1:           %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[3, %{{.*}}] high[{{.*}}, {{.*}}]
+// TILE1:         tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
+// TILE1:       return %[[RESULT]]
+
+func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
+                         %output_tensor: tensor<?x?xf32>,
+                         %pad_value: f32) -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %input_tensor
+      low[3, 4] high[5, 3] into %output_tensor {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// TILE2-LABEL: func @static_pad_tensor(
+// TILE2-SAME:  %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+// TILE2-DAG:   %[[C0:.*]] = constant 0 : index
+// TILE2-DAG:   %[[C2:.*]] = constant 2 : index
+// TILE2-DAG:   %[[C3:.*]] = constant 3 : index
+// TILE2-DAG:   %[[C15:.*]] = constant 15 : index
+// TILE2-DAG:   %[[C16:.*]] = constant 16 : index
+// TILE2:       %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
+// TILE2:         scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+// TILE2:           %[[SWAP_RESULT:.*]] = scf.if
+// TILE2:             tensor.generate
+// TILE2:           else
+// TILE2:             %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+// TILE2:             %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]]
+// TILE2:           tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+// TILE2:       return %[[RESULT]]
+
+
+// TILE1-LABEL: func @static_pad_tensor(
+// TILE1-SAME:  %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+// TILE1-DAG:   %[[C0:.*]] = constant 0 : index
+// TILE1-DAG:   %[[C3:.*]] = constant 3 : index
+// TILE1-DAG:   %[[C16:.*]] = constant 16 : index
+// TILE1:       %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+// TILE1:         %[[SWAP_RESULT:.*]] = scf.if
+// TILE1:           tensor.generate
+// TILE1:         else
+// TILE1:           %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
+// TILE1:           %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
+// TILE1:         tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
+// TILE1:       return %[[RESULT]]
+
+func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
+                        %output_tensor: tensor<15x16xf32>,
+                        %pad_value: f32) -> tensor<15x16xf32> {
+  %0 = linalg.pad_tensor %input_tensor
+      low[3, 4] high[5, 3] into %output_tensor {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+  } : tensor<7x9xf32> to tensor<15x16xf32>
+  return %0 : tensor<15x16xf32>
+}
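Note that only pad_tensor ops with an explicit output operand (the `into` form used above) are tiled: tilePadTensorOp returns failure() when op.output() is absent. A pad written without `into`, for example (%t and %cst are illustrative):

  %0 = linalg.pad_tensor %t low[3, 4] high[5, 3] {
  ^bb0(%i: index, %j: index):
    linalg.yield %cst : f32
  } : tensor<7x9xf32> to tensor<15x16xf32>

is left unchanged by PadTensorOpTilingPattern.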