diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -479,20 +479,42 @@ // Generates unrolled copies of AffineForOp or scf::ForOp 'loopBodyBlock', with // associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap // 'forOpIV' for each unrolled body. -static void generateUnrolledLoop( - Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, - function_ref ivRemapFn) { +static void +generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, + function_ref ivRemapFn, + ArrayRef iterArgs, + ArrayRef iterInits) { // Builder to insert unrolled bodies just before the terminator of the body of // 'forOp'. auto builder = OpBuilder::atBlockTerminator(loopBodyBlock); - // Keep a pointer to the last non-terminator operation in the original block - // so that we know what to clone (since we are doing this in-place). - Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2); + // Keep a pointer to the last operation (the terminator) in the original + // block so that we know what to clone (since we are doing this in-place). + Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 1); + + BlockAndValueMapping operandMap; + + // If there are iterator arguments, create a remapping to the initial values. + for (auto iterArg : llvm::zip(iterArgs, iterInits)) { + Value iterArgValue = std::get<0>(iterArg); + operandMap.map(iterArgValue, std::get<1>(iterArg)); + } // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies). + Operation *last_cloned = nullptr; for (unsigned i = 1; i < unrollFactor; i++) { - BlockAndValueMapping operandMap; + // `last_cloned` points to the terminator op of the last iteration + // instance. Remap the iterator arguments to results of that iteration. 
+ if (last_cloned && last_cloned->isKnownTerminator()) { + for (auto iterArg : llvm::enumerate(iterArgs)) { + Value iterArgValue = iterArg.value(); + auto mapped_returned = + operandMap.lookup(last_cloned->getOperand(iterArg.index())); + operandMap.map(iterArgValue, mapped_returned); + } + // Remove the terminator op of the last iteration instance. + last_cloned->erase(); + } // If the induction variable is used, create a remapping to the value for // this unrolled instance. @@ -501,9 +523,10 @@ operandMap.map(forOpIV, ivUnroll); } - // Clone the original body of 'forOp'. + // Clone the original body of 'forOp', including the terminator op. The + // terminator op will be erased if there is another unrolling iteration. for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) - builder.clone(*it, operandMap); + last_cloned = builder.clone(*it, operandMap); } } @@ -565,7 +588,8 @@ auto bumpMap = AffineMap::get(1, 0, d0 + i * step); return b.create(forOp.getLoc(), bumpMap, iv); - }); + }, + /*iterArgs=*/{}, /*iterInits=*/{}); // Promote the loop body up if this has turned into a single iteration loop. promoteIfSingleIteration(forOp); @@ -656,13 +680,18 @@ // Create unrolled loop. forOp.setUpperBound(upperBoundUnrolled); forOp.setStep(stepUnrolled); - generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(), unrollFactor, - [&](unsigned i, Value iv, OpBuilder b) { - // iv' = iv + step * i; - auto stride = b.create( - loc, step, b.create(loc, i)); - return b.create(loc, iv, stride); - }); + SmallVector initArgs(forOp.initArgs().begin(), + forOp.initArgs().end()); + generateUnrolledLoop( + forOp.getBody(), forOp.getInductionVar(), unrollFactor, + [&](unsigned i, Value iv, OpBuilder b) { + // iv' = iv + step * i; + auto stride = + b.create(loc, step, b.create(loc, i)); + return b.create(loc, iv, stride); + }, + forOp.getRegionIterArgs(), initArgs); + // Promote the loop body up if this has turned into a single iteration loop. 
+ promoteIfSingleIteration(forOp); return success(); diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir --- a/mlir/test/Dialect/SCF/loop-unroll.mlir +++ b/mlir/test/Dialect/SCF/loop-unroll.mlir @@ -3,6 +3,16 @@ // RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 loop-depth=0' | FileCheck %s --check-prefix UNROLL-OUTER-BY-2 // RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2 loop-depth=1' | FileCheck %s --check-prefix UNROLL-INNER-BY-2 +func @reduce_loop_unroll(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) { + %sum_0 = constant 0.0 : f32 + %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) { + %t = load %buffer[%iv] : memref<1024xf32> + %sum_next = addf %sum_iter, %t : f32 + scf.yield %sum_next : f32 + } + return %sum : f32 +} + func @dynamic_loop_unroll(%arg0 : index, %arg1 : index, %arg2 : index, %arg3: memref) { %0 = constant 7.0 : f32