diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -555,14 +555,13 @@ // Compute the number of iterations the loop executes: ceildiv(ub - lb, step) // assuming the step is strictly positive. Update the bounds and the step // of the loop to go from 0 to the number of iterations, if necessary. - // TODO: introduce support for negative steps or emit dynamic asserts - // on step positivity, whatever gets implemented first. if (isZeroBased && isStepOne) return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound, /*step=*/step}; Value diff = boundsBuilder.create(loc, upperBound, lowerBound); - Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step); + Value newUpperBound = + boundsBuilder.create(loc, diff, step); Value newLowerBound = isZeroBased ? lowerBound diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir --- a/mlir/test/Dialect/Affine/loop-coalescing.mlir +++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir @@ -88,10 +88,7 @@ // Number of iterations in the outer scf. // CHECK: %[[diff_i:.*]] = arith.subi %[[orig_ub_i]], %[[orig_lb_i]] - // CHECK: %[[c1:.*]] = arith.constant 1 - // CHECK: %[[step_minus_c1:.*]] = arith.subi %[[orig_step_i]], %[[c1]] - // CHECK: %[[dividend:.*]] = arith.addi %[[diff_i]], %[[step_minus_c1]] - // CHECK: %[[numiter_i:.*]] = arith.divui %[[dividend]], %[[orig_step_i]] + // CHECK: %[[numiter_i:.*]] = arith.ceildivsi %[[diff_i]], %[[orig_step_i]] // Normalized lower bound and step for the outer scf. // CHECK: %[[lb_i:.*]] = arith.constant 0 @@ -99,7 +96,7 @@ // Number of iterations in the inner loop, the pattern is the same as above, // only capture the final result. - // CHECK: %[[numiter_j:.*]] = arith.divui {{.*}}, %[[orig_step_j]] + // CHECK: %[[numiter_j:.*]] = arith.ceildivsi {{.*}}, %[[orig_step_j]] // New bounds of the outer scf. // CHECK: %[[range:.*]] = arith.muli %[[numiter_i]], %[[numiter_j]] @@ -135,13 +132,9 @@ // Compute the number of iterations for each of the loops and the total // number of iterations. // CHECK: %[[range1:.*]] = arith.subi %[[orig_ub1]], %[[orig_lb1]] - // CHECK: %[[orig_step1_minus_1:.*]] = arith.subi %[[orig_step1]], %c1 - // CHECK: %[[dividend1:.*]] = arith.addi %[[range1]], %[[orig_step1_minus_1]] - // CHECK: %[[numiter1:.*]] = arith.divui %[[dividend1]], %[[orig_step1]] + // CHECK: %[[numiter1:.*]] = arith.ceildivsi %[[range1]], %[[orig_step1]] // CHECK: %[[range2:.*]] = arith.subi %[[orig_ub2]], %[[orig_lb2]] - // CHECK: %[[orig_step2_minus_1:.*]] = arith.subi %arg5, %c1 - // CHECK: %[[dividend2:.*]] = arith.addi %[[range2]], %[[orig_step2_minus_1]] - // CHECK: %[[numiter2:.*]] = arith.divui %[[dividend2]], %[[orig_step2]] + // CHECK: %[[numiter2:.*]] = arith.ceildivsi %[[range2]], %[[orig_step2]] // CHECK: %[[range:.*]] = arith.muli %[[numiter1]], %[[numiter2]] : index // Check that the outer loop is updated.