diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -518,6 +518,15 @@
     paddingValueComputationFunction = std::move(fun);
     return *this;
   }
+
+  /// Add label attributes to tiled loops so that they can be easily found
+  /// and processed by other patterns.
+  bool markLoops = false;
+
+  LinalgTilingOptions &setMarkLoops(bool b) {
+    markLoops = b;
+    return *this;
+  }
 };

 /// Canonicalization patterns relevant to apply after tiling patterns. These are
@@ -545,6 +554,9 @@
   LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
                                     TiledLinalgOp &result) const;

+  /// Walk `op` and remove all markers from tiled loops.
+  static void removeMarkers(Operation *op);
+
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
   LinalgTransformationFilter filter;
@@ -585,6 +597,17 @@
   }
 };

+template <typename OpTy>
+struct PeelTiledLoopsPattern : public OpRewritePattern<OpTy> {
+  PeelTiledLoopsPattern(MLIRContext *ctx, int64_t loopId)
+      : OpRewritePattern<OpTy>(ctx), loopId(loopId) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override;
+
+  const int64_t loopId;
+};
+
 struct LinalgFusionOptions {
   /// List of operands indices to use for fusion.
   llvm::SmallSet<unsigned, 1> indicesToFuse = {};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -230,6 +231,8 @@
     : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
       options(options) {}

+static const char kTiledLoopMarker[] = "__linalg_internal_tiled_loop__";
+
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
@@ -243,6 +246,12 @@
   if (!res)
     return failure();

+  // Mark tiled loops.
+  if (options.markLoops)
+    for (auto it : llvm::enumerate(res->loops))
+      it.value()->setAttr(kTiledLoopMarker,
+                          rewriter.getI64IntegerAttr(it.index()));
+
   // Setup RAII guard to return properly.
   LinalgOp tiledOp = res->op;
   auto guard = llvm::make_scope_exit([&]() {
@@ -274,6 +283,35 @@
   return failure();
 }

+void LinalgBaseTilingPattern::removeMarkers(Operation *op) {
+  op->walk([](Operation *c) { c->removeAttr(kTiledLoopMarker); });
+}
+
+template <>
+LogicalResult PeelTiledLoopsPattern<scf::ForOp>::matchAndRewrite(
+    scf::ForOp op, PatternRewriter &rewriter) const {
+  if (!op->hasAttr(kTiledLoopMarker))
+    return failure();
+  auto markerAttr = op->getAttr(kTiledLoopMarker).dyn_cast<IntegerAttr>();
+  assert(markerAttr && "expected integer attribute");
+  if (markerAttr.getInt() != loopId)
+    return failure();
+
+  scf::IfOp ifOp;
+  LogicalResult status = scf::peelAndCanonicalizeForLoop(rewriter, op, ifOp);
+  // Remove the marker so that the loop cannot be peeled a second time.
+  rewriter.updateRootInPlace(op, [&]() { op->removeAttr(kTiledLoopMarker); });
+
+  if (succeeded(status)) {
+    // Do not peel loops inside partial iterations.
+    // TODO: This should be configurable.
+    LinalgBaseTilingPattern::removeMarkers(ifOp);
+    return success();
+  }
+  assert(!ifOp && "expected that loop was not peeled");
+  return failure();
+}
+
 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
   if (tiledOp.loops.empty())
     return tiledOp.op.getOperation()->getResults();
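Note on intended usage: the patch is built around a two-phase rewrite (tile with markers, then peel by loop id), mirrored by applyTilePattern and applyPeelTiledLoopsPattern in TestLinalgTransforms.cpp further down. Below is a minimal sketch of a downstream driver; the function name tileAndPeelMatmul and the include set are illustrative assumptions, while setMarkLoops, LinalgTilingPattern, PeelTiledLoopsPattern, and removeMarkers are the APIs added in this patch.

// Hypothetical downstream driver (not part of this patch); assumes the
// declarations added above in Transforms.h live in namespace mlir::linalg.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void tileAndPeelMatmul(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();

  // Phase 1: tile matmuls that carry the "tile" filter attribute and label
  // every generated loop with the internal marker (id 0 = outermost loop).
  RewritePatternSet tilingPatterns(ctx);
  tilingPatterns.add<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
      ctx,
      linalg::LinalgTilingOptions()
          .setTileSizes({256, 128, 512})
          .setMarkLoops(true),
      linalg::LinalgTransformationFilter(Identifier::get("tile", ctx)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));

  // Phase 2: peel only the outermost tiled loop. Partial iterations lose
  // their markers inside the pattern, so they are not peeled recursively.
  RewritePatternSet peelingPatterns(ctx);
  peelingPatterns.add<linalg::PeelTiledLoopsPattern<scf::ForOp>>(
      ctx, /*loopId=*/0);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns));

  // Strip leftover markers so the internal attribute does not leak.
  linalg::LinalgBaseTilingPattern::removeMarkers(funcOp);
}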
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3,4" -canonicalize | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3,4" -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE

 // CHECK-LABEL: func @matmul_tensors(
 // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
@@ -33,7 +33,7 @@
 // CHECK: scf.yield %[[TD]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[TD2]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[TD1]] : tensor<?x?xi32>
-  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile"}
      ins(%arg0, %arg1: tensor<?x?xi8>, tensor<?x?xi8>)
     outs(%arg2: tensor<?x?xi32>)
     -> tensor<?x?xi32>
@@ -68,7 +68,7 @@
     indexing_maps = [ affine_map<(d0, d1, d2) -> ()>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
     iterator_types = ["parallel", "parallel", "parallel"]}
-   {__internal_linalg_transform__ = "tile-and-pad"}
+   {__internal_linalg_transform__ = "tile"}
   ins(%arg1 : f32) outs(%arg0: tensor<?x?x?xf32>) {
     ^bb(%0: f32, %1: f32) :
@@ -87,7 +87,7 @@
 func @matmul_partially_padded_tensors(
   %arg0: tensor<?x8xi8>, %arg1: tensor<8x?xi8>, %arg2: tensor<?x?xi32>)
     -> tensor<?x?xi32> {
-  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile"}
      ins(%arg0, %arg1: tensor<?x8xi8>, tensor<8x?xi8>)
     outs(%arg2: tensor<?x?xi32>)
     -> tensor<?x?xi32>
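The test updates above also reflect a behavioral split: the old fused test-tile-and-pad-pattern flag is replaced by test-tile-pattern plus an opt-in pad-tiles flag, so padding is now orthogonal to tiling (applyTilePattern below only installs the padding value computation when padTiles is set).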
diff --git a/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
@@ -0,0 +1,112 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern mark-tiled-loops tile-sizes=256,128,512" -canonicalize -for-loop-canonicalization -cse \
+// RUN:   -test-linalg-transform-patterns="test-peel-tiled-loops-pattern peeled-loops=0" -canonicalize -split-input-file | \
+// RUN:   FileCheck %s -check-prefix=CHECK-PEEL-0
+
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern mark-tiled-loops tile-sizes=256,128,512" -canonicalize -for-loop-canonicalization -cse \
+// RUN:   -test-linalg-transform-patterns="test-peel-tiled-loops-pattern peeled-loops=1,2" -canonicalize -split-input-file | \
+// RUN:   FileCheck %s -check-prefix=CHECK-PEEL-12
+
+// CHECK-PEEL-0: func @matmul_static_tensor
+// CHECK-PEEL-0-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-0-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-0-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-0-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-0-DAG: %[[c1280:.*]] = constant 1280 : index
+// CHECK-PEEL-0-DAG: %[[c1600:.*]] = constant 1600 : index
+// CHECK-PEEL-0-DAG: %[[c1700:.*]] = constant 1700 : index
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1280]] step %[[c256]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1700]] step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1600]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<256x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1700]] step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1600]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<220x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<220x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+
+// CHECK-PEEL-12: func @matmul_static_tensor
+// CHECK-PEEL-12-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-12-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-12-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-12-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-12-DAG: %[[c1500:.*]] = constant 1500 : index
+// CHECK-PEEL-12-DAG: %[[c1536:.*]] = constant 1536 : index
+// CHECK-PEEL-12-DAG: %[[c1600:.*]] = constant 1600 : index
+// CHECK-PEEL-12-DAG: %[[c1664:.*]] = constant 1664 : index
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1500]] step %[[c256]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1664]] step %[[c128]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1536]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x512xf32>, tensor<512x128xf32>) outs({{.*}} : tensor<?x128xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x64xf32>, tensor<64x128xf32>) outs({{.*}} : tensor<?x128xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1600]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x36xf32>) outs({{.*}} : tensor<?x36xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+func @matmul_static_tensor(%arg0: tensor<1500x1600xf32>, %arg1: tensor<1600x1700xf32>)
+    -> tensor<1500x1700xf32> {
+  %out = linalg.init_tensor [1500, 1700] : tensor<1500x1700xf32>
+  %r = linalg.matmul {__internal_linalg_transform__ = "tile"}
+      ins(%arg0, %arg1: tensor<1500x1600xf32>, tensor<1600x1700xf32>)
+      outs(%out: tensor<1500x1700xf32>) -> tensor<1500x1700xf32>
+  return %r : tensor<1500x1700xf32>
+}

// -----

+// CHECK-PEEL-0: func @matmul_dynamic_tensor
+// CHECK-PEEL-0-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-0-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-0-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-0-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c256]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: scf.if %{{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }

+// CHECK-PEEL-12: func @matmul_dynamic_tensor
+// CHECK-PEEL-12-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-12-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-12-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-12-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c256]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c128]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: scf.if %{{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: scf.if %{{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+func @matmul_dynamic_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %out = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %r = linalg.matmul {__internal_linalg_transform__ = "tile"}
+      ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
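A note on the CHECK constants above: with tile sizes 256, 128, 512 on a 1500x1700 output and reduction size 1600, peeling loop 0 splits the i-loop into a main part ending at 1280 = 5 * 256 plus a 220-row remainder (1500 - 1280), while peeling loops 1 and 2 splits the j-loop at 1664 = 13 * 128 (36-column remainder) and the k-loop at 1536 = 3 * 512 (64-wide remainder). Those splits are exactly where the c1280, c1664, and c1536 bounds and the tensor<220x?xf32> and tensor<64x128xf32> operand shapes come from.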
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -84,9 +84,13 @@
       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                      "in vector.contract form"),
       llvm::cl::init(false)};
-  Option<bool> testTileAndPadPattern{
-      *this, "test-tile-and-pad-pattern",
-      llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
+  Option<bool> testPeelTiledLoopsPattern{
+      *this, "test-peel-tiled-loops-pattern",
+      llvm::cl::desc("Test peeling of marked tiled loops"),
+      llvm::cl::init(false)};
+  Option<bool> testTilePattern{*this, "test-tile-pattern",
+                               llvm::cl::desc("Test tile pattern"),
+                               llvm::cl::init(false)};
   Option<int> testHoistPadding{*this, "test-hoist-padding",
                                llvm::cl::desc("Test hoist padding"),
                                llvm::cl::init(0)};
@@ -103,10 +107,20 @@
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                      "pad_tensor(subtensor)"),
       llvm::cl::init(false)};
-  ListOption<int64_t> tileSizesForPadding{
-      *this, "tile-sizes-for-padding",
-      llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
-      llvm::cl::MiscFlags::CommaSeparated};
+  Option<bool> markTiledLoops{
+      *this, "mark-tiled-loops",
+      llvm::cl::desc("Mark tiled loops when running test-tile-pattern"),
+      llvm::cl::init(false)};
+  Option<bool> padTiles{*this, "pad-tiles",
+                        llvm::cl::desc("Pad tiles when running test-tile-pattern"),
+                        llvm::cl::init(false)};
+  ListOption<int64_t> peeledLoops{
+      *this, "peeled-loops", llvm::cl::desc("Loops to be peeled"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<int64_t> tileSizes{
+      *this, "tile-sizes",
+      llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<unsigned> testInterchangePattern{
       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Test the interchange pattern.")};
@@ -554,18 +568,34 @@
   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
 }

-static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
+static void applyPeelTiledLoopsPattern(FuncOp funcOp,
+                                       ArrayRef<int64_t> peeledLoops) {
+  MLIRContext *context = funcOp.getContext();
+  RewritePatternSet peelingPattern(context);
+  for (int64_t loop : peeledLoops)
+    peelingPattern.add<PeelTiledLoopsPattern<scf::ForOp>>(context,
+                                                          loop);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(peelingPattern));
+  // Remove markers from tiled loops.
+  LinalgBaseTilingPattern::removeMarkers(funcOp);
+}
+
+static void applyTilePattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes,
+                             bool markTiledLoops, bool padTiles) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet tilingPattern(context);
   auto linalgTilingOptions =
-      linalg::LinalgTilingOptions()
-          .setTileSizes(tileSizes)
-          .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
-  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
+      linalg::LinalgTilingOptions().setTileSizes(tileSizes).setMarkLoops(
+          markTiledLoops);
+  if (padTiles)
+    linalgTilingOptions.setPaddingValueComputationFunction(
+        getNeutralOfLinalgOp);
+
+  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
+                    linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
                     linalg::LinalgTilingPattern<linalg::GenericOp>>(
       context, linalgTilingOptions,
-      linalg::LinalgTransformationFilter(
-          Identifier::get("tile-and-pad", context)));
+      linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }

@@ -678,8 +708,10 @@
     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
   if (testTiledLoopPeeling.hasValue())
     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling);
-  if (testTileAndPadPattern)
-    return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
+  if (testPeelTiledLoopsPattern)
+    return applyPeelTiledLoopsPattern(getFunction(), peeledLoops);
+  if (testTilePattern)
+    return applyTilePattern(getFunction(), tileSizes, markTiledLoops, padTiles);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
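The RUN lines in tile-and-peel-tensors.mlir rely on this split deliberately: tiling and peeling are separate greedy rewrites, with -canonicalize, -for-loop-canonicalization, and -cse in between so that loop bounds are simplified before the peeling pattern looks up its marker attributes. applyPeelTiledLoopsPattern then strips all remaining markers, which keeps the internal __linalg_internal_tiled_loop__ attribute from leaking out of the test pass.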