diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1182,8 +1182,14 @@ function_ref annotateFn) { assert(unrollFactor > 0 && "unroll factor should be positive"); - if (unrollFactor == 1) - return promoteIfSingleIteration(forOp); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); + if (unrollFactor == 1) { + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() == 1 && + failed(promoteIfSingleIteration(forOp))) + return failure(); + return success(); + } // Nothing in the loop body other than the terminator. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) @@ -1191,7 +1197,6 @@ // If the trip count is lower than the unroll factor, no unrolled body. // TODO: option to specify cleanup loop unrolling. - Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollFactor) return failure(); @@ -1237,8 +1242,6 @@ scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn) { assert(unrollFactor > 0 && "expected positive unroll factor"); - if (unrollFactor == 1) - return promoteIfSingleIteration(forOp); // Return if the loop body is empty. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) @@ -1264,6 +1267,13 @@ assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 && "expected positive loop bounds and step"); int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst); + + if (unrollFactor == 1) { + if (tripCount == 1 && failed(promoteIfSingleIteration(forOp))) + return failure(); + return success(); + } + int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor); int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; assert(upperBoundUnrolledCst <= ubCst); @@ -1403,14 +1413,19 @@ uint64_t unrollJamFactor) { assert(unrollJamFactor > 0 && "unroll jam factor should be positive"); - if (unrollJamFactor == 1) - return promoteIfSingleIteration(forOp); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); + if (unrollJamFactor == 1) { + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() == 1 && + failed(promoteIfSingleIteration(forOp))) + return failure(); + return success(); + } // Nothing in the loop body other than the terminator. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success(); - Optional mayBeConstantTripCount = getConstantTripCount(forOp); // If the trip count is lower than the unroll jam factor, no unroll jam. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) { diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir --- a/mlir/test/Transforms/scf-loop-unroll.mlir +++ b/mlir/test/Transforms/scf-loop-unroll.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s +// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1 // CHECK-LABEL: scf_loop_unroll_single func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 { @@ -42,3 +43,16 @@ // CHECK: } // CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1 } + +// UNROLL-BY-1-LABEL: scf_loop_unroll_factor_1_promote +func @scf_loop_unroll_factor_1_promote() -> () { + %step = arith.constant 1 : index + %lo = arith.constant 0 : index + %hi = arith.constant 1 : index + scf.for %i = %lo to %hi step %step { + %x = "test.foo"(%i) : (index) -> i32 + } + return + // UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index + // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32 +}