diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -524,6 +524,15 @@
     paddingValueComputationFunction = std::move(fun);
     return *this;
   }
+
+  /// Peel the specified loops.
+  SmallVector<int64_t> peeledLoops;
+
+  LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
+    peeledLoops.clear();
+    peeledLoops.append(loops.begin(), loops.end());
+    return *this;
+  }
 };
 
 /// Canonicalization patterns relevant to apply after tiling patterns. These are
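The new `peeledLoops` option composes with the existing `LinalgTilingOptions` builders. A minimal sketch of how a client might configure it (the populate function name is hypothetical; the setters, pattern, and filter types are the ones used by this patch):

// Sketch only: assumes the usual MLIR headers and namespaces are available.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

static void populateTileAndPeelPatterns(MLIRContext *context,
                                        RewritePatternSet &patterns) {
  // Tile a matmul with 256x128x512 tiles and peel the outermost tiled loop.
  auto options = linalg::LinalgTilingOptions()
                     .setTileSizes({256, 128, 512})
                     .setPeeledLoops({0});
  patterns.add<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
      context, options,
      linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
}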
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -26,6 +27,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -230,6 +232,21 @@
     : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
       options(options) {}
 
+/// Try to peel a loop `op` and return the new result.
+// TODO: Only scf.for loops are supported at the moment.
+static SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op) {
+  return llvm::TypeSwitch<Operation *, SmallVector<Value>>(op)
+      .Case<scf::ForOp>([&](scf::ForOp forOp) {
+        scf::ForOp partialIteration;
+        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
+                                                      partialIteration)))
+          return partialIteration->getResults();
+        assert(!partialIteration && "expected that loop was not peeled");
+        return forOp->getResults();
+      })
+      .Default([&](Operation *op) { return op->getResults(); });
+}
+
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
@@ -242,31 +259,38 @@
   if (!res)
     return failure();
-
-  // Setup RAII guard to return properly.
-  LinalgOp paddedOp;
-  LinalgOp tiledOp = res->op;
-  auto guard = llvm::make_scope_exit([&]() {
-    // Return relevant information to derived pattern.
-    result = *res;
-    // Update filter.
-    if (paddedOp)
-      filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
-    else
-      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
-  });
+  // Clear filter to stop recursive pattern application.
+  filter.replaceLinalgTransformationFilter(rewriter, res->op);
+
+  // Peel loops.
+  for (int64_t loop : options.peeledLoops) {
+    assert(loop < static_cast<int64_t>(res->loops.size()) &&
+           "requested peeling of non-existing loop");
+    Operation *loopOp = res->loops[loop];
+    SmallVector<Value> loopResults = peelLoop(rewriter, loopOp);
+    // The result of the loop nest may change with peeling.
+    if (res->tensorResults.size() == loopOp->getNumResults() &&
+        std::equal(res->tensorResults.begin(), res->tensorResults.end(),
+                   loopOp->getResults().begin()))
+      res->tensorResults = loopResults;
+  }
 
   // Consider padding on the fly only if the op has tensor semantics.
   if (!options.paddingValueComputationFunction ||
-      !linalgOp.hasTensorSemantics())
+      !linalgOp.hasTensorSemantics()) {
+    result = *res;
     return success();
+  }
 
   // Try to pad on the fly by rewriting res->op as a padded op. If successful,
   // `res.op` is rewritten in static form with padded operands.
+  LinalgOp paddedOp;
   if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                   options.paddingValueComputationFunction,
                                   paddedOp))) {
+    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
     res->op = paddedOp;
+    result = *res;
     // Do not perform replacement of `linalgOp`, let the derived patterns
     // do this as they see fit, from the resulting TiledLinalgOp.
     return success();
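For context on why the peeling code above re-points `tensorResults`: derived patterns are expected to replace the original linalg op with those values once the base rewrite succeeds. A rough sketch of such a derived pattern (the existing `LinalgTilingPattern` behaves along these lines; the struct below is illustrative, not the actual implementation):

// Sketch only: assumes mlir/Dialect/Linalg/Transforms/Transforms.h and
// `using namespace mlir;`.
struct TileAndReplacePattern : public linalg::LinalgBaseTilingPattern {
  using LinalgBaseTilingPattern::LinalgBaseTilingPattern;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    linalg::TiledLinalgOp tiledOp;
    if (failed(matchAndRewriteBase(op, rewriter, tiledOp)))
      return failure();
    // After peeling, `tensorResults` still refers to the results of the
    // outermost loop, so replacing the original op here remains valid.
    if (tiledOp.tensorResults.empty())
      rewriter.eraseOp(op);
    else
      rewriter.replaceOp(op, tiledOp.tensorResults);
    return success();
  }
};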
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3,4" -canonicalize | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3,4" -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
 
 // CHECK-LABEL: func @matmul_tensors(
 // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
@@ -33,7 +33,7 @@
 // CHECK: scf.yield %[[TD]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[TD2]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[TD1]] : tensor<?x?xi32>
-  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile"}
      ins(%arg0, %arg1: tensor<?x?xi8>, tensor<?x?xi8>)
     outs(%arg2: tensor<?x?xi32>)
    -> tensor<?x?xi32>
@@ -68,7 +68,7 @@
       indexing_maps = [ affine_map<(d0, d1, d2) -> ()>,
                         affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
       iterator_types = ["parallel", "parallel", "parallel"]}
-      {__internal_linalg_transform__ = "tile-and-pad"}
+      {__internal_linalg_transform__ = "tile"}
      ins(%arg1 : f32)
     outs(%arg0: tensor<?x?x?xf32>) {
     ^bb(%0: f32, %1: f32) :
@@ -87,7 +87,7 @@
 func @matmul_partially_padded_tensors(
   %arg0: tensor<?x8xi8>, %arg1: tensor<8x?xi8>, %arg2: tensor<?x?xi32>)
     -> tensor<?x?xi32> {
-  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile"}
      ins(%arg0, %arg1: tensor<?x8xi8>, tensor<8x?xi8>)
     outs(%arg2: tensor<?x?xi32>)
    -> tensor<?x?xi32>
diff --git a/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-peel-tensors.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern tile-sizes=256,128,512 peeled-loops=0" -canonicalize | \
+// RUN:   FileCheck %s -check-prefix=CHECK-PEEL-0
+
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern tile-sizes=256,128,512 peeled-loops=1,2" -canonicalize | \
+// RUN:   FileCheck %s -check-prefix=CHECK-PEEL-12
+
+// CHECK-PEEL-0: func @matmul_static_tensor
+// CHECK-PEEL-0-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-0-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-0-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-0-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-0-DAG: %[[c1280:.*]] = constant 1280 : index
+// CHECK-PEEL-0-DAG: %[[c1600:.*]] = constant 1600 : index
+// CHECK-PEEL-0-DAG: %[[c1700:.*]] = constant 1700 : index
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1280]] step %[[c256]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1700]] step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1600]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<256x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1700]] step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %[[c1600]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<220x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<220x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+
+// CHECK-PEEL-12: func @matmul_static_tensor
+// CHECK-PEEL-12-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-12-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-12-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-12-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-12-DAG: %[[c1500:.*]] = constant 1500 : index
+// CHECK-PEEL-12-DAG: %[[c1536:.*]] = constant 1536 : index
+// CHECK-PEEL-12-DAG: %[[c1600:.*]] = constant 1600 : index
+// CHECK-PEEL-12-DAG: %[[c1664:.*]] = constant 1664 : index
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1500]] step %[[c256]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1664]] step %[[c128]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1536]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x512xf32>, tensor<512x128xf32>) outs({{.*}} : tensor<?x128xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x64xf32>, tensor<64x128xf32>) outs({{.*}} : tensor<?x128xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %[[c1600]] step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x36xf32>) outs({{.*}} : tensor<?x36xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+func @matmul_static_tensor(%arg0: tensor<1500x1600xf32>, %arg1: tensor<1600x1700xf32>)
+    -> tensor<1500x1700xf32> {
+  %out = linalg.init_tensor [1500, 1700] : tensor<1500x1700xf32>
+  %r = linalg.matmul {__internal_linalg_transform__ = "tile"}
+      ins(%arg0, %arg1: tensor<1500x1600xf32>, tensor<1600x1700xf32>)
+      outs(%out: tensor<1500x1700xf32>) -> tensor<1500x1700xf32>
+  return %r : tensor<1500x1700xf32>
+}
+
+// -----
+
+// CHECK-PEEL-0: func @matmul_dynamic_tensor
+// CHECK-PEEL-0-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-0-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-0-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-0-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c256]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<256x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<256x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: scf.for %{{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c128]] {{.*}} {
+// CHECK-PEEL-0: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-0: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+// CHECK-PEEL-0: }
+
+// CHECK-PEEL-12: func @matmul_dynamic_tensor
+// CHECK-PEEL-12-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-PEEL-12-DAG: %[[c128:.*]] = constant 128 : index
+// CHECK-PEEL-12-DAG: %[[c256:.*]] = constant 256 : index
+// CHECK-PEEL-12-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c256]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c128]] {{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x512xf32>, tensor<512x128xf32>) outs({{.*}} : tensor<?x128xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: scf.for %{{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x128xf32>) outs({{.*}} : tensor<?x128xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: scf.for %{{.*}} {
+// CHECK-PEEL-12: scf.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c512]] {{.*}} {
+// CHECK-PEEL-12: linalg.matmul ins({{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs({{.*}} : tensor<?x?xf32>)
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+// CHECK-PEEL-12: }
+func @matmul_dynamic_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %out = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %r = linalg.matmul {__internal_linalg_transform__ = "tile"}
+      ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
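The constants in the CHECK lines above follow from the peeling split bound. Assuming peeling splits a zero-based loop at (ub / step) * step (integer division), the main loops run to 1280 = (1500 / 256) * 256, 1664 = (1700 / 128) * 128, and 1536 = (1600 / 512) * 512, leaving partial tiles of 220, 36, and 64 iterations, which is where the 220, 36, and 64 in the tile shapes come from. A standalone sketch of that arithmetic:

// Split-bound arithmetic assumed by the checks above (illustrative only).
#include <cstdio>

int main() {
  const long sizes[] = {1500, 1700, 1600}; // trip counts of the tiled loops
  const long steps[] = {256, 128, 512};    // tile sizes used as loop steps
  for (int i = 0; i < 3; ++i) {
    long split = (sizes[i] / steps[i]) * steps[i]; // last multiple of the step
    std::printf("main loop ub = %ld, partial tile = %ld\n", split,
                sizes[i] - split);
  }
  return 0;
}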
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -84,9 +84,9 @@
       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                      "in vector.contract form"),
       llvm::cl::init(false)};
-  Option<bool> testTileAndPadPattern{
-      *this, "test-tile-and-pad-pattern",
-      llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
+  Option<bool> testTilePattern{*this, "test-tile-pattern",
+                               llvm::cl::desc("Test tile pattern"),
+                               llvm::cl::init(false)};
   Option<int> testHoistPadding{*this, "test-hoist-padding",
                                llvm::cl::desc("Test hoist padding"),
                                llvm::cl::init(0)};
@@ -103,10 +103,17 @@
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                      "pad_tensor(subtensor)"),
       llvm::cl::init(false)};
-  ListOption<int64_t> tileSizesForPadding{
-      *this, "tile-sizes-for-padding",
-      llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
-      llvm::cl::MiscFlags::CommaSeparated};
+  Option<bool> padTiles{*this, "pad-tiles",
+                        llvm::cl::desc("Pad tiles when test-tile-pattern"),
+                        llvm::cl::init(false)};
+  ListOption<int64_t> peeledLoops{
+      *this, "peeled-loops",
+      llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<int64_t> tileSizes{
+      *this, "tile-sizes",
+      llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<unsigned> testInterchangePattern{
       *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
       llvm::cl::desc("Test the interchange pattern.")};
@@ -558,18 +565,22 @@
   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
 }
 
-static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
+static void applyTilePattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes,
+                             bool padTiles, ArrayRef<int64_t> peeledLoops) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet tilingPattern(context);
   auto linalgTilingOptions =
-      linalg::LinalgTilingOptions()
-          .setTileSizes(tileSizes)
-          .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
-  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
-                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
+      linalg::LinalgTilingOptions().setTileSizes(tileSizes).setPeeledLoops(
+          peeledLoops);
+  if (padTiles)
+    linalgTilingOptions.setPaddingValueComputationFunction(
+        getNeutralOfLinalgOp);
+
+  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
+                    linalg::LinalgTilingPattern<linalg::MatmulOp>,
+                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
       context, linalgTilingOptions,
-      linalg::LinalgTransformationFilter(
-          Identifier::get("tile-and-pad", context)));
+      linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
@@ -697,8 +708,8 @@
   if (testTiledLoopPeeling.hasValue())
     return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                         skipPartial);
-  if (testTileAndPadPattern)
-    return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
+  if (testTilePattern)
+    return applyTilePattern(getFunction(), tileSizes, padTiles, peeledLoops);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);