diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -480,19 +480,42 @@
 // associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 // 'forOpIV' for each unrolled body.
+// If 'iterArgs' is non-empty, they are the region iter_args of the loop: each
+// unrolled copy reads the values yielded by the previous copy in their place,
+// and the terminator is updated to yield the values of the last copy.
 static void generateUnrolledLoop(
-    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+    Block *loopBodyBlock, Value forOpIV, ArrayRef<BlockArgument> iterArgs,
+    uint64_t unrollFactor,
     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn) {
   // Builder to insert unrolled bodies just before the terminator of the body of
   // 'forOp'.
   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
 
   // Keep a pointer to the last non-terminator operation in the original block
   // so that we know what to clone (since we are doing this in-place).
   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
 
+  // Values yielded for 'iterArgs' by the original body, and by the most
+  // recently emitted copy of it. The latter feed the next copy in place of the
+  // iter_args and finally become the operands of the terminator.
+  Operation *terminator = loopBodyBlock->getTerminator();
+  assert(terminator->getNumOperands() >= iterArgs.size() &&
+         "terminator must yield a value for every iter_arg");
+  SmallVector<Value, 4> yieldedValues;
+  for (unsigned j = 0, e = iterArgs.size(); j < e; ++j)
+    yieldedValues.push_back(terminator->getOperand(j));
+  SmallVector<Value, 4> lastYielded(yieldedValues);
+
   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
   for (unsigned i = 1; i < unrollFactor; i++) {
+    // Use a fresh mapping for every copy so that mappings created while
+    // cloning one copy (including those of nested regions) do not leak into
+    // the next one.
     BlockAndValueMapping operandMap;
 
+    // Chain the loop-carried values: this copy reads what the previous copy
+    // yielded in place of the region iter_args.
+    for (auto it : llvm::zip(iterArgs, lastYielded))
+      operandMap.map(std::get<0>(it), std::get<1>(it));
+
     // If the induction variable is used, create a remapping to the value for
     // this unrolled instance.
@@ -501,10 +524,21 @@
       operandMap.map(forOpIV, ivUnroll);
     }
 
-    // Clone the original body of 'forOp'.
+    // Clone the original body of 'forOp' without the terminator.
     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
       builder.clone(*it, operandMap);
+
+    // Record what this copy yields; it feeds the next copy or, after the last
+    // copy, the terminator. Yielded values defined outside the loop body are
+    // not remapped by the clone and are forwarded unchanged.
+    for (unsigned j = 0, e = lastYielded.size(); j < e; ++j)
+      lastYielded[j] = operandMap.lookupOrDefault(yieldedValues[j]);
   }
+
+  // Make the loop yield the values produced by the last unrolled copy. With
+  // an unroll factor of 1 no copy is made and the operands are left unchanged.
+  for (unsigned j = 0, e = lastYielded.size(); j < e; ++j)
+    terminator->setOperand(j, lastYielded[j]);
 }
 
 /// Unrolls this loop by the specified factor. Returns success if the loop
@@ -558,7 +592,8 @@
   // Scale the step of loop being unrolled by unroll factor.
   int64_t step = forOp.getStep();
   forOp.setStep(step * unrollFactor);
-  generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(), unrollFactor,
+  generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(),
+                       /*iterArgs=*/{}, unrollFactor,
                        [&](unsigned i, Value iv, OpBuilder b) {
                          // iv' = iv + i * step
                          auto d0 = b.getAffineDimExpr(0);
@@ -656,13 +691,18 @@
   // Create unrolled loop.
   forOp.setUpperBound(upperBoundUnrolled);
   forOp.setStep(stepUnrolled);
-  generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(), unrollFactor,
-                       [&](unsigned i, Value iv, OpBuilder b) {
-                         // iv' = iv + step * i;
-                         auto stride = b.create<MulIOp>(
-                             loc, step, b.create<ConstantIndexOp>(loc, i));
-                         return b.create<AddIOp>(loc, iv, stride);
-                       });
+  // NOTE(review): remainder iterations peeled into an epilogue loop (if any)
+  // must start from this loop's results rather than its original init args;
+  // confirm the epilogue path chains iter_args accordingly.
+  generateUnrolledLoop(
+      forOp.getBody(), forOp.getInductionVar(), forOp.getRegionIterArgs(),
+      unrollFactor, [&](unsigned i, Value iv, OpBuilder b) {
+        // iv' = iv + step * i;
+        auto stride =
+            b.create<MulIOp>(loc, step, b.create<ConstantIndexOp>(loc, i));
+        return b.create<AddIOp>(loc, iv, stride);
+      });
+
   // Promote the loop body up if this has turned into a single iteration loop.
   promoteIfSingleIteration(forOp);
   return success();
diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -248,3 +248,35 @@
 // UNROLL-BY-3-NEXT: }
 // UNROLL-BY-3-NEXT: store %{{.*}}, %[[MEM]][%[[C9]]] : memref<?xf32>
 // UNROLL-BY-3-NEXT: return
+
+// Test that loop-carried values (iter_args) are chained through the unrolled
+// copies of the body.
+func @reduce_loop_unroll(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
+  %sum_0 = constant 0.0 : f32
+  %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) {
+    %t = load %buffer[%iv] : memref<1024xf32>
+    %sum_next = addf %sum_iter, %t : f32
+    scf.yield %sum_next : f32
+  }
+  return %sum : f32
+}
+
+// UNROLL-BY-2-LABEL: func @reduce_loop_unroll
+// UNROLL-BY-2: scf.for %[[IV:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[AR:.*]] = %{{.*}}) -> (f32) {
+// UNROLL-BY-2:   %[[LD0:.*]] = load %{{.*}}[%[[IV]]] : memref<1024xf32>
+// UNROLL-BY-2:   %[[ADD0:.*]] = addf %[[AR]], %[[LD0]] : f32
+// UNROLL-BY-2:   %[[LD1:.*]] = load %{{.*}}[%{{.*}}] : memref<1024xf32>
+// UNROLL-BY-2:   %[[ADD1:.*]] = addf %[[ADD0]], %[[LD1]] : f32
+// UNROLL-BY-2:   scf.yield %[[ADD1]] : f32
+// UNROLL-BY-2: }
+
+// UNROLL-BY-3-LABEL: func @reduce_loop_unroll
+// UNROLL-BY-3: scf.for %[[IV:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[AR:.*]] = %{{.*}}) -> (f32) {
+// UNROLL-BY-3:   %[[LD0:.*]] = load %{{.*}}[%[[IV]]] : memref<1024xf32>
+// UNROLL-BY-3:   %[[ADD0:.*]] = addf %[[AR]], %[[LD0]] : f32
+// UNROLL-BY-3:   %[[LD1:.*]] = load %{{.*}}[%{{.*}}] : memref<1024xf32>
+// UNROLL-BY-3:   %[[ADD1:.*]] = addf %[[ADD0]], %[[LD1]] : f32
+// UNROLL-BY-3:   %[[LD2:.*]] = load %{{.*}}[%{{.*}}] : memref<1024xf32>
+// UNROLL-BY-3:   %[[ADD2:.*]] = addf %[[ADD1]], %[[LD2]] : f32
+// UNROLL-BY-3:   scf.yield %[[ADD2]] : f32
+// UNROLL-BY-3: }