diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -61,7 +61,10 @@
 /// and no callback is provided, anything passed from the command-line (if at
 /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor).
+/// If `unrollUpToFactor` is set, a loop whose constant trip count is smaller
+/// than the unroll factor is unrolled by its trip count instead.
 std::unique_ptr<OperationPass<FuncOp>> createLoopUnrollPass(
     int unrollFactor = -1, bool unrollFull = false,
-    const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr);
+    const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr,
+    bool unrollUpToFactor = false);
 
 /// Creates a loop unroll jam pass to unroll jam by the specified factor. A
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -71,6 +71,9 @@
   let options = [
     Option<"unrollFactor", "unroll-factor", "unsigned", /*default=*/"4",
            "Use this unroll factor for all loops being unrolled">,
+    Option<"unrollUpToFactor", "unroll-up-to-factor", "bool",
+           /*default=*/"false",
+           "Allow unrolling up to the factor specified">,
     Option<"unrollFull", "unroll-full", "bool", /*default=*/"false",
            "Fully unroll loops">,
     Option<"numRepetitions", "unroll-num-reps", "unsigned", /*default=*/"1",
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
@@ -46,10 +46,12 @@
         getUnrollFactor(other.getUnrollFactor) {}
   explicit LoopUnroll(
       Optional<unsigned> unrollFactor = None, bool unrollFull = false,
-      const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
+      const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr,
+      bool unrollUpToFactor = false)
       : getUnrollFactor(getUnrollFactor) {
     if (unrollFactor)
       this->unrollFactor = *unrollFactor;
+    this->unrollUpToFactor = unrollUpToFactor;
     this->unrollFull = unrollFull;
   }
 
@@ -126,13 +128,18 @@
   if (unrollFull)
     return loopUnrollFull(forOp);
   // Otherwise, unroll by the given unroll factor.
+  // With unroll-up-to-factor, a loop whose constant trip count is smaller than
+  // the factor is unrolled by its trip count instead.
+  if (unrollUpToFactor)
+    return loopUnrollUpToFactor(forOp, unrollFactor);
   return loopUnrollByFactor(forOp, unrollFactor);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass(
     int unrollFactor, bool unrollFull,
-    const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
+    const std::function<unsigned(AffineForOp)> &getUnrollFactor,
+    bool unrollUpToFactor) {
   return std::make_unique<LoopUnroll>(
       unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), unrollFull,
-      getUnrollFactor);
+      getUnrollFactor, unrollUpToFactor);
 }
diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -2,6 +2,7 @@
 // RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=3' | FileCheck %s --check-prefix UNROLL-BY-3
 // RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 loop-depth=0' | FileCheck %s --check-prefix UNROLL-OUTER-BY-2
 // RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 loop-depth=1' | FileCheck %s --check-prefix UNROLL-INNER-BY-2
+// RUN: mlir-opt %s --affine-loop-unroll='unroll-factor=6 unroll-up-to-factor=true' | FileCheck %s --check-prefix UNROLL-UP-TO
 
 func @dynamic_loop_unroll(%arg0 : index, %arg1 : index, %arg2 : index,
                           %arg3: memref<?xf32>) {
@@ -248,3 +249,24 @@
 // UNROLL-BY-3-NEXT: }
 // UNROLL-BY-3-NEXT: store %{{.*}}, %[[MEM]][%[[C9]]] : memref<?xf32>
 // UNROLL-BY-3-NEXT: return
+
+// Test unroll-up-to-factor: the constant trip count (2) is smaller than the
+// unroll factor (6), so the loop is unrolled by its trip count.
+func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
+  %0 = constant 7.0 : f32
+  %lb = constant 0 : index
+  %ub = constant 2 : index
+  affine.for %i0 = %lb to %ub {
+    store %0, %arg0[%i0] : memref<?xf32>
+  }
+  return
+}
+// UNROLL-UP-TO-LABEL: func @static_loop_unroll_up_to_factor
+// UNROLL-UP-TO-SAME: %[[MEM:.*0]]: memref<?xf32>
+// UNROLL-UP-TO-DAG: %[[C0:.*]] = constant 0 : index
+// UNROLL-UP-TO-DAG: %[[C2:.*]] = constant 2 : index
+// UNROLL-UP-TO-NEXT: %[[V0:.*]] = affine.apply {{.*}}
+// UNROLL-UP-TO-NEXT: store %{{.*}}, %[[MEM]][%[[V0]]] : memref<?xf32>
+// UNROLL-UP-TO-NEXT: %[[V1:.*]] = affine.apply {{.*}}
+// UNROLL-UP-TO-NEXT: store %{{.*}}, %[[MEM]][%[[V1]]] : memref<?xf32>
+// UNROLL-UP-TO-NEXT: return