diff --git a/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir b/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir --- a/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir +++ b/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-tiled-loop-peeling=2 -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-2 // RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-tiled-loop-peeling=0,1,2 -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012 +// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns="test-tiled-loop-peeling=0,1,2 skip-partial" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012-SKIP-PARTIAL // CHECK-TILE-2-LABEL: func @tiled_loop_3d_tensor( // CHECK-TILE-2-SAME: %[[input:.*]]: tensor<?x?x?xf32>, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index @@ -63,6 +64,21 @@ // CHECK-TILE-012: } // CHECK-TILE-012-NOT: linalg.tiled_loop +// CHECK-TILE-012-SKIP-PARTIAL-LABEL: func @tiled_loop_3d_tensor( +// CHECK-TILE-012-SKIP-PARTIAL-SAME: %[[input:.*]]: tensor<?x?x?xf32> +// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c2:.*]] = constant 2 : index +// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]] +// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]] +// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]] +// CHECK-TILE-012-SKIP-PARTIAL: %[[p0:.*]] = affine.apply #{{.*}}()[%[[dim0]] +// CHECK-TILE-012-SKIP-PARTIAL: %[[p1:.*]] = affine.apply #{{.*}}()[%[[dim1]] +// CHECK-TILE-012-SKIP-PARTIAL: %[[p2:.*]] = affine.apply #{{.*}}()[%[[dim2]] +// CHECK-TILE-012-SKIP-PARTIAL: linalg.tiled_loop {{.*}} = (%[[c0]], %[[c0]], %[[c0]]) to (%[[p0]], %[[p1]], %[[p2]]) +// CHECK-TILE-012-SKIP-PARTIAL: linalg.tiled_loop 
{{.*}} = (%[[c0]], %[[c0]], %[[p2]]) to (%[[p0]], %[[p1]], %[[dim2]]) +// CHECK-TILE-012-SKIP-PARTIAL: linalg.tiled_loop {{.*}} = (%[[c0]], %[[p1]], %[[c0]]) to (%[[p0]], %[[dim1]], %[[dim2]]) +// CHECK-TILE-012-SKIP-PARTIAL: linalg.tiled_loop {{.*}} = (%[[p0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim1]], %[[dim2]]) func @tiled_loop_3d_tensor(%arg0: tensor<?x?x?xf32>, %s0: index, %s1: index, %s2: index) -> tensor<?x?x?xf32> { %cst = constant 0.000000e+00 : f32 diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -114,6 +114,10 @@ *this, "test-tiled-loop-peeling", llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; + Option<bool> skipPartial{ + *this, "skip-partial", + llvm::cl::desc("Skip loops inside partial iterations during peeling"), + llvm::cl::init(false)}; }; } // end anonymous namespace @@ -581,14 +585,16 @@ } static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; +static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; namespace { /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the /// `idx`-th loop contains only "full" iterations and a second loop for the /// remaining partial iteration (if any).
struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> { - TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx) - : OpRewritePattern<TiledLoopOp>(ctx), idx(idx) {} + TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) + : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) { + } LogicalResult matchAndRewrite(TiledLoopOp loopOp, PatternRewriter &rewriter) const override { @@ -603,6 +609,9 @@ if (llvm::find(peeledLoops, idx) != peeledLoops.end()) return failure(); } + // No peeling of loop nests with a partial iteration. + if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) + return failure(); if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx) return failure(); @@ -612,31 +621,40 @@ if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, result))) return failure(); - peeledLoops.push_back(idx); // Apply label, so that the same loop is not rewritten a second time. + peeledLoops.push_back(idx); rewriter.updateRootInPlace(loopOp, [&]() { loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); }); result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); + result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); + return success(); } /// Index of loop to peel. int64_t idx; + + /// If set to true, do not peel TiledLoopOps with a partial iteration. + bool skipPartial; }; } // namespace static void applyTiledLoopPeelingPattern(FuncOp funcOp, - ArrayRef<unsigned> loops) { + ArrayRef<unsigned> loops, + bool skipPartial) { MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); for (unsigned idx : loops) - patterns.add<TiledLoopPeelingPattern>(ctx, idx); + patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - // Drop the marker. - funcOp.walk([](TiledLoopOp op) { op->removeAttr(kPeeledLoopsLabel); }); + // Drop the markers.
+ funcOp.walk([](TiledLoopOp op) { + op->removeAttr(kPeeledLoopsLabel); + op->removeAttr(kPartialIterationLabel); + }); } /// Apply transformations specified as patterns. @@ -677,7 +695,8 @@ if (testSwapSubTensorPadTensor) return applyExtractSliceOfPadTensorSwapPattern(getFunction()); if (testTiledLoopPeeling.hasValue()) - return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling); + return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, + skipPartial); if (testTileAndPadPattern) return applyTileAndPadPattern(getFunction(), tileSizesForPadding); if (testHoistPadding) {