[mlir][SCF] for-loop-peeling: add `skip-partial` option

Adds a `skip-partial` pass option (default: true) to -for-loop-peeling. When
set, scf.for loops nested (at any depth) inside the partial iteration of an
already peeled loop are not peeled. The partial iteration is the scf.if that
peeling creates for the last iteration.

To support this:
- peelAndCanonicalizeForLoop now returns the generated scf.if through a new
  `scf::IfOp &ifOp` out-parameter. This is an API change; callers must pass
  an IfOp to receive it.
- ForLoopPeelingPattern tags that scf.if with the `__partial_iteration__`
  marker. The pass strips both the `__peeled_loop__` and the
  `__partial_iteration__` markers from all ops after the greedy rewrite.
- A new `nested_loops` test covers both the default behavior and
  `skip-partial=false` (CHECK-NO-SKIP).

diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -21,6 +21,12 @@
     : FunctionPass<"for-loop-peeling"> {
   let summary = "Peel `for` loops at their upper bounds.";
   let constructor = "mlir::createForLoopPeelingPass()";
+  let options = [
+    Option<"skipPartial", "skip-partial", "bool",
+           /*default=*/"true",
+           "Do not peel loops inside of the last, partial iteration of another "
+           "already peeled loop.">
+  ];
   let dependentDialects = ["AffineDialect"];
 }
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -41,8 +41,8 @@
 
 /// Rewrite a for loop with bounds/step that potentially do not divide evenly
 /// into a for loop where the step divides the iteration space evenly, followed
-/// by an scf.if for the last (partial) iteration (if any). This transformation
-/// is called "loop peeling".
+/// by an scf.if for the last (partial) iteration (if any; returned via `ifOp`).
+/// This transformation is called "loop peeling".
 ///
 /// This transformation is beneficial for a wide range of transformations such
 /// as vectorization or loop tiling: It enables additional canonicalizations
@@ -81,7 +81,8 @@
 /// Note: This function rewrites the given scf.for loop in-place and creates a
 /// new scf.if operation for the last iteration. It replaces all uses of the
 /// unpeeled loop with the results of the newly generated scf.if.
-LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp);
+LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp,
+                                         scf::IfOp &ifOp);
 
 /// Tile a parallel loop of the form
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -362,9 +362,9 @@
 }
 
 LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
-                                                    ForOp forOp) {
+                                                    ForOp forOp,
+                                                    scf::IfOp &ifOp) {
   Value ub = forOp.upperBound();
-  scf::IfOp ifOp;
   Value splitBound;
   if (failed(peelForLoop(rewriter, forOp, ifOp, splitBound)))
     return failure();
@@ -383,23 +383,45 @@
 }
 
 static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
+static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
 
 namespace {
 struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
+  ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
+      : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}
 
   LogicalResult matchAndRewrite(ForOp forOp,
                                 PatternRewriter &rewriter) const override {
+    // Do not peel already peeled loops.
     if (forOp->hasAttr(kPeeledLoopLabel))
       return failure();
-    if (failed(peelAndCanonicalizeForLoop(rewriter, forOp)))
+    if (skipPartial) {
+      // No peeling of loops inside the partial iteration (scf.if) of another
+      // peeled loop.
+      Operation *op = forOp.getOperation();
+      while ((op = op->getParentOfType<scf::IfOp>())) {
+        if (op->hasAttr(kPartialIterationLabel))
+          return failure();
+      }
+    }
+    // Apply loop peeling.
+    scf::IfOp ifOp;
+    if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, ifOp)))
       return failure();
     // Apply label, so that the same loop is not rewritten a second time.
     rewriter.updateRootInPlace(forOp, [&]() {
       forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
     });
+    ifOp->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
     return success();
   }
+
+  /// If set to true, loops inside partial iterations of another peeled loop
+  /// are not peeled. This reduces the size of the generated code. Partial
+  /// iterations are not usually performance critical.
+  /// Note: Takes into account the entire chain of parent operations, not just
+  /// the direct parent.
+  bool skipPartial;
 };
 } // namespace
 
@@ -424,11 +446,14 @@
     FuncOp funcOp = getFunction();
     MLIRContext *ctx = funcOp.getContext();
     RewritePatternSet patterns(ctx);
-    patterns.add<ForLoopPeelingPattern>(ctx);
+    patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 
-    // Drop the marker.
-    funcOp.walk([](ForOp op) { op->removeAttr(kPeeledLoopLabel); });
+    // Drop the markers.
+    funcOp.walk([](Operation *op) {
+      op->removeAttr(kPeeledLoopLabel);
+      op->removeAttr(kPartialIterationLabel);
+    });
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
--- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -for-loop-peeling -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -for-loop-peeling=skip-partial=false -canonicalize -split-input-file | FileCheck %s -check-prefix=CHECK-NO-SKIP
 
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>
 //  CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0, s1, s2] -> (-(s0 - (s0 - s1) mod s2) + s0)>
@@ -223,3 +224,48 @@
   }
   return
 }
+
+// -----
+
+// CHECK: func @nested_loops
+// CHECK:   scf.for {{.*}} {
+// CHECK:     scf.for {{.*}} {
+// CHECK:     }
+// CHECK:     scf.if {{.*}} {
+// CHECK:     }
+// CHECK:   }
+// CHECK:   scf.if {{.*}} {
+// CHECK:     scf.for {{.*}} {
+// CHECK:     }
+// CHECK-NOT: scf.if
+// CHECK:   }
+
+// CHECK-NO-SKIP: func @nested_loops
+// CHECK-NO-SKIP:   scf.for {{.*}} {
+// CHECK-NO-SKIP:     scf.for {{.*}} {
+// CHECK-NO-SKIP:     }
+// CHECK-NO-SKIP:     scf.if {{.*}} {
+// CHECK-NO-SKIP:     }
+// CHECK-NO-SKIP:   }
+// CHECK-NO-SKIP:   scf.if {{.*}} {
+// CHECK-NO-SKIP:     scf.for {{.*}} {
+// CHECK-NO-SKIP:     }
+// CHECK-NO-SKIP:     scf.if {{.*}} {
+// CHECK-NO-SKIP:     }
+// CHECK-NO-SKIP:   }
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func @nested_loops(%lb0: index, %lb1 : index, %ub0: index, %ub1: index,
+                   %step: index) -> i32 {
+  %c0 = constant 0 : i32
+  %r0 = scf.for %iv0 = %lb0 to %ub0 step %step iter_args(%arg0 = %c0) -> i32 {
+    %r1 = scf.for %iv1 = %lb1 to %ub1 step %step iter_args(%arg1 = %arg0) -> i32 {
+      %s = affine.min #map(%ub1, %iv1)[%step]
+      %casted = index_cast %s : index to i32
+      %0 = addi %arg1, %casted : i32
+      scf.yield %0 : i32
+    }
+    %1 = addi %arg0, %r1 : i32
+    scf.yield %1 : i32
+  }
+  return %r0 : i32
+}