diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -147,6 +147,26 @@
   return builder.create<SignedDivIOp>(loc, sum, divisor);
 }
 
+/// Helper for promoting single-iteration affine.for/scf.for ops: replaces the
+/// iter_args with their init values and the loop results with yielded values.
+template <typename AffineOrSCFForOp>
+static void replaceIterArgsAndYieldResults(AffineOrSCFForOp forOp) {
+  static_assert(
+      llvm::is_one_of<AffineOrSCFForOp, AffineForOp, scf::ForOp>::value,
+      "only for affine.for and scf.for ops");
+  // Replace uses of iter arguments with iter operands (initial values).
+  auto iterOperands = forOp.getIterOperands();
+  auto iterArgs = forOp.getRegionIterArgs();
+  for (auto e : llvm::zip(iterOperands, iterArgs))
+    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
+
+  // Replace loop results with yielded values (iter_args already replaced).
+  auto outerResults = forOp.getResults();
+  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto e : llvm::zip(outerResults, innerResults))
+    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+}
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// was known to have a single iteration.
 // TODO: extend this for arbitrary affine bounds.
@@ -181,6 +201,9 @@
       }
     }
   }
+
+  replaceIterArgsAndYieldResults(forOp);
+
   // Move the loop body operations, except for its terminator, to the loop's
   // containing block.
   forOp.getBody()->back().erase();
@@ -206,17 +229,7 @@
   auto iv = forOp.getInductionVar();
   iv.replaceAllUsesWith(lbCstOp);
 
-  // Replace uses of iterArgs with iterOperands.
-  auto iterOperands = forOp.getIterOperands();
-  auto iterArgs = forOp.getRegionIterArgs();
-  for (auto e : llvm::zip(iterOperands, iterArgs))
-    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
-
-  // Replace uses of loop results with the values yielded by the loop.
-  auto outerResults = forOp.getResults();
-  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
-  for (auto e : llvm::zip(outerResults, innerResults))
-    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+  replaceIterArgsAndYieldResults(forOp);
 
   // Move the loop body operations, except for its terminator, to the loop's
   // containing block.
@@ -1127,6 +1140,22 @@
   if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
     OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
     auto cleanupForOp = cast<AffineForOp>(builder.clone(*forOp));
+
+    // Users of `forOp`'s results now use the cleanup loop's results, and the
+    // cleanup loop's iter_args are initialized with `forOp`'s results. Init
+    // operands trail the bound operands (one per result); set them by index,
+    // since replaceUsesOfWith would rewrite every operand holding the value,
+    // which is wrong when several iter_args share an init value or when an
+    // init value is also a bound operand.
+    unsigned numIterOperands = forOp->getNumResults();
+    unsigned firstIterOperand =
+        cleanupForOp->getNumOperands() - numIterOperands;
+    for (unsigned i = 0; i < numIterOperands; ++i) {
+      Value result = forOp->getResult(i);
+      result.replaceAllUsesWith(cleanupForOp->getResult(i));
+      cleanupForOp->setOperand(firstIterOperand + i, result);
+    }
+
     AffineMap cleanupMap;
     SmallVector<Value, 4> cleanupOperands;
     getCleanupLoopLowerBound(forOp, unrollFactor, cleanupMap, cleanupOperands);
@@ -1142,18 +1171,21 @@
     forOp.setUpperBound(cleanupOperands, cleanupMap);
   }
 
+  ValueRange iterArgs(forOp.getRegionIterArgs());
+  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
+
   // Scale the step of loop being unrolled by unroll factor.
   int64_t step = forOp.getStep();
   forOp.setStep(step * unrollFactor);
-  generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(), unrollFactor,
-                       [&](unsigned i, Value iv, OpBuilder b) {
-                         // iv' = iv + i * step
-                         auto d0 = b.getAffineDimExpr(0);
-                         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
-                         return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap,
-                                                        iv);
-                       },
-                       /*iterArgs=*/{}, /*yieldedValues=*/{});
+  generateUnrolledLoop(
+      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
+      [&](unsigned i, Value iv, OpBuilder b) {
+        // iv' = iv + i * step
+        auto d0 = b.getAffineDimExpr(0);
+        auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
+        return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv);
+      },
+      /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
 
   // Promote the loop body up if this has turned into a single iteration loop.
   (void)promoteIfSingleIteration(forOp);
diff --git a/mlir/test/Dialect/Affine/unroll.mlir b/mlir/test/Dialect/Affine/unroll.mlir
--- a/mlir/test/Dialect/Affine/unroll.mlir
+++ b/mlir/test/Dialect/Affine/unroll.mlir
@@ -590,3 +590,69 @@
 // UNROLL-BY-1-NEXT: %0 = "foo"(%c0) : (index) -> i32
 // UNROLL-BY-1-NEXT: return
 }
+
+// Test unrolling with affine.for iter_args.
+
+// UNROLL-BY-4-LABEL: loop_unroll_with_iter_args_and_cleanup
+func @loop_unroll_with_iter_args_and_cleanup(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
+  %cf1 = constant 1.0 : f32
+  %cf2 = constant 2.0 : f32
+  %sum:2 = affine.for %iv = 0 to 10 iter_args(%i0 = %arg0, %i1 = %arg1) -> (f32, f32) {
+    %sum0 = addf %i0, %cf1 : f32
+    %sum1 = addf %i1, %cf2 : f32
+    affine.yield %sum0, %sum1 : f32, f32
+  }
+  return %sum#0, %sum#1 : f32, f32
+  // UNROLL-BY-4: %[[SUM:.*]]:2 = affine.for {{.*}} = 0 to 8 step 4 iter_args
+  // UNROLL-BY-4-NEXT: addf
+  // UNROLL-BY-4-NEXT: addf
+  // UNROLL-BY-4-NEXT: addf
+  // UNROLL-BY-4-NEXT: addf
+  // UNROLL-BY-4-NEXT: addf
+  // UNROLL-BY-4-NEXT: addf
+  // UNROLL-BY-4-NEXT: %[[Y1:.*]] = addf
+  // UNROLL-BY-4-NEXT: %[[Y2:.*]] = addf
+  // UNROLL-BY-4-NEXT: affine.yield %[[Y1]], %[[Y2]]
+  // UNROLL-BY-4-NEXT: }
+  // UNROLL-BY-4-NEXT: %[[SUM1:.*]]:2 = affine.for {{.*}} = 8 to 10 iter_args(%[[V1:.*]] = %[[SUM]]#0, %[[V2:.*]] = %[[SUM]]#1)
+  // UNROLL-BY-4: }
+  // UNROLL-BY-4-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1
+}
+
+// The cleanup loop (epilogue) runs a single iteration and gets promoted here.
+
+// UNROLL-BY-4-LABEL: unroll_with_iter_args_and_promotion
+func @unroll_with_iter_args_and_promotion(%arg0 : f32, %arg1 : f32) -> f32 {
+  %sum = affine.for %iv = 0 to 9 iter_args(%sum_iter = %arg0) -> (f32) {
+    %next = addf %sum_iter, %arg1 : f32
+    affine.yield %next : f32
+  }
+  // UNROLL-BY-4: %[[SUM:.*]] = affine.for %{{.*}} = 0 to 8 step 4 iter_args(%[[V0:.*]] =
+  // UNROLL-BY-4-NEXT: %[[V1:.*]] = addf %[[V0]]
+  // UNROLL-BY-4-NEXT: %[[V2:.*]] = addf %[[V1]]
+  // UNROLL-BY-4-NEXT: %[[V3:.*]] = addf %[[V2]]
+  // UNROLL-BY-4-NEXT: %[[V4:.*]] = addf %[[V3]]
+  // UNROLL-BY-4-NEXT: affine.yield %[[V4]]
+  // UNROLL-BY-4-NEXT: }
+  // UNROLL-BY-4-NEXT: %[[RES:.*]] = addf %[[SUM]],
+  // UNROLL-BY-4-NEXT: return %[[RES]]
+  return %sum : f32
+}
+
+// Test that iter_args sharing the same init value are each initialized with
+// the matching result of the unrolled loop in the cleanup loop.
+
+// UNROLL-BY-4-LABEL: unroll_with_iter_args_sharing_init
+func @unroll_with_iter_args_sharing_init(%arg0 : f32) -> (f32, f32) {
+  %cf1 = constant 1.0 : f32
+  %cf2 = constant 2.0 : f32
+  %sum:2 = affine.for %iv = 0 to 10 iter_args(%i0 = %arg0, %i1 = %arg0) -> (f32, f32) {
+    %sum0 = addf %i0, %cf1 : f32
+    %sum1 = addf %i1, %cf2 : f32
+    affine.yield %sum0, %sum1 : f32, f32
+  }
+  return %sum#0, %sum#1 : f32, f32
+  // UNROLL-BY-4: %[[SUM:.*]]:2 = affine.for {{.*}} = 0 to 8 step 4 iter_args
+  // UNROLL-BY-4: %[[RES:.*]]:2 = affine.for {{.*}} = 8 to 10 iter_args(%{{.*}} = %[[SUM]]#0, %{{.*}} = %[[SUM]]#1)
+  // UNROLL-BY-4: return %[[RES]]#0, %[[RES]]#1
+}