diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -47,7 +47,8 @@ /// annotates the Ops in each unrolled iteration by applying `annotateFn`. LogicalResult loopUnrollByFactor( AffineForOp forOp, uint64_t unrollFactor, - function_ref annotateFn = nullptr); + function_ref annotateFn = nullptr, + bool cleanUpUnroll = false); /// Unrolls this loop by the specified unroll factor or its trip count, /// whichever is lower. diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -212,6 +212,8 @@ Option<"unrollFullThreshold", "unroll-full-threshold", "unsigned", /*default=*/"1", "Unroll all loops with trip count less than or equal to this">, + Option<"cleanUpUnroll", "cleanup-unroll", "bool", /*default=*/"false", + "Force to do cleanup loop unrolling even if the trip count is smaller than unroll factor.">, ]; } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp @@ -122,14 +122,15 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { // Use the function callback if one was provided. if (getUnrollFactor) - return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); + return loopUnrollByFactor(forOp, getUnrollFactor(forOp), nullptr, + cleanUpUnroll); // Unroll completely if full loop unroll was specified. if (unrollFull) return loopUnrollFull(forOp); // Otherwise, unroll by the given unroll factor. 
if (unrollUpToFactor) return loopUnrollUpToFactor(forOp, unrollFactor); - return loopUnrollByFactor(forOp, unrollFactor); + return loopUnrollByFactor(forOp, unrollFactor, nullptr, cleanUpUnroll); } std::unique_ptr> mlir::createLoopUnrollPass( diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1052,6 +1052,34 @@ loopBodyBlock->getTerminator()->setOperands(lastYielded); } +/// Helper to generate unrolled copies of AffineForOp by 'unrollFactor'. +/// If specified, annotates the Ops in each unrolled iteration using annotateFn. +/// Promotes the loop body if the loop becomes a single-iteration loop. +static void generateUnrolledLoopByFactor( + AffineForOp forOp, + function_ref annotateFn, + unsigned unrollFactor) { + ValueRange iterArgs(forOp.getRegionIterArgs()); + auto yieldedValues = forOp.getBody()->getTerminator()->getOperands(); + + // Scale the step of loop being unrolled by unroll factor. + int64_t step = forOp.getStep(); + forOp.setStep(step * unrollFactor); + generateUnrolledLoop( + forOp.getBody(), forOp.getInductionVar(), unrollFactor, + [&](unsigned i, Value iv, OpBuilder b) { + // iv' = iv + i * step + auto d0 = b.getAffineDimExpr(0); + auto bumpMap = AffineMap::get(1, 0, d0 + i * step); + return b.create(forOp.getLoc(), bumpMap, iv); + }, + /*annotateFn=*/annotateFn, + /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues); + + // Promote the loop body up if this has turned into a single iteration loop. + (void)promoteIfSingleIteration(forOp); +} + /// Helper to generate cleanup loop for unroll or unroll-and-jam when the trip /// count is not a multiple of `unrollFactor`. static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp, @@ -1091,7 +1119,8 @@ /// is successfully unrolled. 
LogicalResult mlir::loopUnrollByFactor( AffineForOp forOp, uint64_t unrollFactor, - function_ref annotateFn) { + function_ref annotateFn, + bool cleanUpUnroll) { assert(unrollFactor > 0 && "unroll factor should be positive"); Optional mayBeConstantTripCount = getConstantTripCount(forOp); @@ -1107,9 +1136,15 @@ return success(); // If the trip count is lower than the unroll factor, no unrolled body. - // TODO: option to specify cleanup loop unrolling. - if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollFactor) + if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollFactor) { + if (cleanUpUnroll) { + // Force unroll the cleanup loop if cleanUpUnroll is specified. + generateUnrolledLoopByFactor(forOp, annotateFn, *mayBeConstantTripCount); + return success(); + } + return failure(); + } // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { @@ -1125,25 +1160,8 @@ "and upper bound maps can always be determined"); } - ValueRange iterArgs(forOp.getRegionIterArgs()); - auto yieldedValues = forOp.getBody()->getTerminator()->getOperands(); - - // Scale the step of loop being unrolled by unroll factor. - int64_t step = forOp.getStep(); - forOp.setStep(step * unrollFactor); - generateUnrolledLoop( - forOp.getBody(), forOp.getInductionVar(), unrollFactor, - [&](unsigned i, Value iv, OpBuilder b) { - // iv' = iv + i * step - auto d0 = b.getAffineDimExpr(0); - auto bumpMap = AffineMap::get(1, 0, d0 + i * step); - return b.create(forOp.getLoc(), bumpMap, iv); - }, - /*annotateFn=*/annotateFn, - /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues); + generateUnrolledLoopByFactor(forOp, annotateFn, unrollFactor); - // Promote the loop body up if this has turned into a single iteration loop. 
- (void)promoteIfSingleIteration(forOp); return success(); } diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir --- a/mlir/test/Dialect/SCF/loop-unroll.mlir +++ b/mlir/test/Dialect/SCF/loop-unroll.mlir @@ -4,6 +4,7 @@ // RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 loop-depth=1' | FileCheck %s --check-prefix UNROLL-INNER-BY-2 // RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 annotate=true' | FileCheck %s --check-prefix UNROLL-BY-2-ANNOTATE // RUN: mlir-opt %s --affine-loop-unroll='unroll-factor=6 unroll-up-to-factor=true' | FileCheck %s --check-prefix UNROLL-UP-TO +// RUN: mlir-opt %s --affine-loop-unroll='unroll-factor=5 cleanup-unroll=true' | FileCheck %s --check-prefix CLEANUP-UNROLL-BY-5 func.func @dynamic_loop_unroll(%arg0 : index, %arg1 : index, %arg2 : index, %arg3: memref) { @@ -314,3 +315,28 @@ // UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32 // UNROLL-BY-3-NEXT: } // UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32 + +// Test that the loop is fully unrolled when `cleanup-unroll` is set and its +// trip count is less than the unroll factor. 
+func.func @static_loop_unroll_by_5_with_cleanup(%arg0 : memref) { + %0 = arith.constant 7.0 : f32 + %lb = arith.constant 0 : index + %ub = arith.constant 3 : index + affine.for %i0 = %lb to %ub { + memref.store %0, %arg0[%i0] : memref + } + return +} + +// CLEANUP-UNROLL-BY-5-LABEL: func @static_loop_unroll_by_5_with_cleanup +// CLEANUP-UNROLL-BY-5-SAME: %[[MEM:.*0]]: memref +// +// CLEANUP-UNROLL-BY-5-DAG: %[[C0:.*]] = arith.constant 0 : index +// CLEANUP-UNROLL-BY-5-DAG: %[[C3:.*]] = arith.constant 3 : index +// CLEANUP-UNROLL-BY-5-NEXT: %[[V0:.*]] = affine.apply {{.*}} +// CLEANUP-UNROLL-BY-5-NEXT: memref.store %{{.*}}, %[[MEM]][%[[V0]]] : memref +// CLEANUP-UNROLL-BY-5-NEXT: %[[V1:.*]] = affine.apply {{.*}} +// CLEANUP-UNROLL-BY-5-NEXT: memref.store %{{.*}}, %[[MEM]][%[[V1]]] : memref +// CLEANUP-UNROLL-BY-5-NEXT: %[[V2:.*]] = affine.apply {{.*}} +// CLEANUP-UNROLL-BY-5-NEXT: memref.store %{{.*}}, %[[MEM]][%[[V2]]] : memref +// CLEANUP-UNROLL-BY-5-NEXT: return \ No newline at end of file