diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1160,8 +1160,14 @@ function_ref annotateFn) { assert(unrollFactor > 0 && "unroll factor should be positive"); - if (unrollFactor == 1) - return promoteIfSingleIteration(forOp); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); + if (unrollFactor == 1) { + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() == 1 && + failed(promoteIfSingleIteration(forOp))) + return failure(); + return success(); + } // Nothing in the loop body other than the terminator. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) @@ -1169,7 +1175,6 @@ // If the trip count is lower than the unroll factor, no unrolled body. // TODO: option to specify cleanup loop unrolling. - Optional mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollFactor) return failure(); @@ -1215,8 +1220,6 @@ scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn) { assert(unrollFactor > 0 && "expected positive unroll factor"); - if (unrollFactor == 1) - return promoteIfSingleIteration(forOp); // Return if the loop body is empty. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) @@ -1242,6 +1245,13 @@ assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 && "expected positive loop bounds and step"); int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst); + + if (unrollFactor == 1) { + if (tripCount == 1 && failed(promoteIfSingleIteration(forOp))) + return failure(); + return success(); + } + int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor); int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; assert(upperBoundUnrolledCst <= ubCst); @@ -1383,14 +1393,19 @@ uint64_t unrollJamFactor) { assert(unrollJamFactor > 0 && "unroll jam factor should be positive"); - if (unrollJamFactor == 1) - return promoteIfSingleIteration(forOp); + Optional mayBeConstantTripCount = getConstantTripCount(forOp); + if (unrollJamFactor == 1) { + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() == 1 && + failed(promoteIfSingleIteration(forOp))) + return failure(); + return success(); + } // Nothing in the loop body other than the terminator. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success(); - Optional mayBeConstantTripCount = getConstantTripCount(forOp); // If the trip count is lower than the unroll jam factor, no unroll jam. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) { diff --git a/mlir/test/Dialect/Affine/unroll.mlir b/mlir/test/Dialect/Affine/unroll.mlir --- a/mlir/test/Dialect/Affine/unroll.mlir +++ b/mlir/test/Dialect/Affine/unroll.mlir @@ -631,6 +631,17 @@ // UNROLL-BY-1-NEXT: return } +// UNROLL-BY-1-LABEL: func @unroll_by_one_should_succeed_if_no_promotion() +func @unroll_by_one_should_succeed_if_no_promotion() { + affine.for %i = 0 to 2 { + %x = "foo"(%i) : (index) -> i32 + } + return +// UNROLL-BY-1-NEXT: affine.for %[[IV:.*]] = 0 to 2 { +// UNROLL-BY-1-NEXT: %{{.*}} = "foo"(%[[IV]]) : (index) -> i32 +// UNROLL-BY-1-NEXT: } +} + // Test unrolling with affine.for iter_args. // UNROLL-BY-4-LABEL: loop_unroll_with_iter_args_and_cleanup diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir --- a/mlir/test/Transforms/scf-loop-unroll.mlir +++ b/mlir/test/Transforms/scf-loop-unroll.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s +// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1 // CHECK-LABEL: scf_loop_unroll_single func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 { @@ -42,3 +43,34 @@ // CHECK: } // CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1 } + +// UNROLL-BY-1-LABEL: scf_loop_unroll_factor_1_promote +func @scf_loop_unroll_factor_1_promote() -> () { + %step = arith.constant 1 : index + %lo = arith.constant 0 : index + %hi = arith.constant 1 : index + scf.for %i = %lo to %hi step %step { + %x = "test.foo"(%i) : (index) -> i32 + } + return + // UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index + // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32 +} + +// UNROLL-BY-1-LABEL: scf_loop_unroll_factor_1_no_promotion +func @scf_loop_unroll_factor_1_no_promotion() -> () { + %step = arith.constant 1 : index + %lo = arith.constant 0 : index + %hi = arith.constant 2 : index + scf.for %i = %lo to %hi step %step { + %x = "test.foo"(%i) : (index) -> i32 + } + return + // UNROLL-BY-1-DAG: %[[STEP:.*]] = arith.constant 1 : index + // UNROLL-BY-1-DAG: %[[LO:.*]] = arith.constant 0 : index + // UNROLL-BY-1-DAG: %[[HI:.*]] = arith.constant 2 : index + // UNROLL-BY-1-NEXT: scf.for %[[IV:.*]] = %[[LO]] to %[[HI]] step %[[STEP]] { + // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[IV]]) : (index) -> i32 + // UNROLL-BY-1-NEXT: } + // UNROLL-BY-1-NEXT: return +}