diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -206,10 +206,22 @@
   auto iv = forOp.getInductionVar();
   iv.replaceAllUsesWith(lbCstOp);
 
+  // With a single iteration, every iter_arg is just its init operand.
+  auto iterOperands = forOp.getIterOperands();
+  auto iterArgs = forOp.getRegionIterArgs();
+  for (auto e : llvm::zip(iterOperands, iterArgs))
+    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
+
+  // The loop's results are the values yielded by that single iteration.
+  auto outerResults = forOp.getResults();
+  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto e : llvm::zip(outerResults, innerResults))
+    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+
   // Move the loop body operations, except for its terminator, to the loop's
   // containing block.
   auto *parentBlock = forOp.getOperation()->getBlock();
-  forOp.getBody()->back().erase();
+  forOp.getBody()->getTerminator()->erase();
   parentBlock->getOperations().splice(Block::iterator(forOp),
                                       forOp.getBody()->getOperations());
   forOp.erase();
@@ -418,10 +430,10 @@
   return success();
 }
 
-// Collect perfectly nested loops starting from `rootForOps`. Loops are
-// perfectly nested if each loop is the first and only non-terminator operation
-// in the parent loop. Collect at most `maxLoops` loops and append them to
-// `forOps`.
+/// Collect perfectly nested loops starting from `rootForOp`. Loops are
+/// perfectly nested if each loop is the first and only non-terminator operation
+/// in the parent loop. Collect at most `maxLoops` loops and append them to
+/// `forOps`.
 template <typename T>
 static void getPerfectlyNestedLoopsImpl(
     SmallVectorImpl<T> &forOps, T rootForOp,
@@ -478,9 +490,10 @@
-// Generates unrolled copies of AffineForOp or scf::ForOp 'loopBodyBlock', with
-// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
-// 'forOpIV' for each unrolled body.
+// Generates 'unrollFactor' - 1 extra copies of the body 'loopBodyBlock' of an
+// AffineForOp or scf::ForOp, remapping 'forOpIV' via 'ivRemapFn' and chaining
+// each copy's 'iterArgs' to the previous copy's 'yieldedValues'.
-static void generateUnrolledLoop(
-    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
-    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn) {
+static void
+generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+                     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
+                     ValueRange iterArgs, ValueRange yieldedValues) {
   // Builder to insert unrolled bodies just before the terminator of the body of
   // 'forOp'.
   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
@@ -490,9 +503,14 @@
   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
 
   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
+  SmallVector<Value, 4> lastYielded(yieldedValues);
+
   for (unsigned i = 1; i < unrollFactor; i++) {
     BlockAndValueMapping operandMap;
 
+    // Bind this copy's iter_args to the values yielded by the previous copy.
+    operandMap.map(iterArgs, lastYielded);
+
     // If the induction variable is used, create a remapping to the value for
     // this unrolled instance.
     if (!forOpIV.use_empty()) {
@@ -503,7 +521,14 @@
     // Clone the original body of 'forOp'.
     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
       builder.clone(*it, operandMap);
+
+    // Record this copy's yields; values it never remapped are kept as-is.
+    for (unsigned j = 0, e = lastYielded.size(); j < e; ++j)
+      lastYielded[j] = operandMap.lookupOrDefault(yieldedValues[j]);
   }
+
+  // Make the terminator yield the values produced by the last copy.
+  loopBodyBlock->getTerminator()->setOperands(lastYielded);
 }
 
 /// Unrolls this loop by the specified factor. Returns success if the loop
@@ -564,7 +589,8 @@
                          auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
                          return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap,
                                                         iv);
-                       });
+                       },
+                       {}, {});
 
   // Promote the loop body up if this has turned into a single iteration loop.
   promoteIfSingleIteration(forOp);
@@ -649,19 +675,36 @@
                               std::next(Block::iterator(forOp)));
     auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
     epilogueForOp.setLowerBound(upperBoundUnrolled);
+
+    // Update uses of loop results.
+    auto results = forOp.getResults();
+    unsigned numCtrlOperands = epilogueForOp.getNumControlOperands();
+    for (unsigned i = 0, e = results.size(); i < e; ++i) {
+      results[i].replaceAllUsesWith(epilogueForOp.getResult(i));
+      // Rebind by operand slot: an init value may also appear in another slot
+      // (a shared init or a loop bound), and replaceUsesOfWith hits them all.
+      epilogueForOp.getOperation()->setOperand(numCtrlOperands + i,
+                                               results[i]);
+    }
     promoteIfSingleIteration(epilogueForOp);
   }
 
   // Create unrolled loop.
   forOp.setUpperBound(upperBoundUnrolled);
   forOp.setStep(stepUnrolled);
-  generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(), unrollFactor,
-                       [&](unsigned i, Value iv, OpBuilder b) {
-                         // iv' = iv + step * i;
-                         auto stride = b.create<MulIOp>(
-                             loc, step, b.create<ConstantIndexOp>(loc, i));
-                         return b.create<AddIOp>(loc, iv, stride);
-                       });
+
+  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
+  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
+
+  generateUnrolledLoop(
+      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
+      [&](unsigned i, Value iv, OpBuilder b) {
+        // iv' = iv + step * i;
+        auto stride =
+            b.create<MulIOp>(loc, step, b.create<ConstantIndexOp>(loc, i));
+        return b.create<AddIOp>(loc, iv, stride);
+      },
+      iterArgs, yieldedValues);
   // Promote the loop body up if this has turned into a single iteration loop.
   promoteIfSingleIteration(forOp);
   return success();
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
+
+// CHECK-LABEL: scf_loop_unroll_single
+func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
+  %from = constant 0 : index
+  %to = constant 10 : index
+  %step = constant 1 : index
+  %sum = scf.for %iv = %from to %to step %step iter_args(%sum_iter = %arg0) -> (f32) {
+    %next = addf %sum_iter, %arg1 : f32
+    scf.yield %next : f32
+  }
+  // CHECK: %[[SUM:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[V0:.*]] =
+  // CHECK-NEXT: %[[V1:.*]] = addf %[[V0]]
+  // CHECK-NEXT: %[[V2:.*]] = addf %[[V1]]
+  // CHECK-NEXT: %[[V3:.*]] = addf %[[V2]]
+  // CHECK-NEXT: scf.yield %[[V3]]
+  // CHECK-NEXT: }
+  // CHECK-NEXT: %[[RES:.*]] = addf %[[SUM]],
+  // CHECK-NEXT: return %[[RES]]
+  return %sum : f32
+}
+
+// CHECK-LABEL: scf_loop_unroll_double_symbolic_ub
+// CHECK-SAME: (%{{.*}}: f32, %{{.*}}: f32, %[[N:.*]]: index)
+func @scf_loop_unroll_double_symbolic_ub(%arg0 : f32, %arg1 : f32, %n : index) -> (f32,f32) {
+  %from = constant 0 : index
+  %step = constant 1 : index
+  %sum:2 = scf.for %iv = %from to %n step %step iter_args(%i0 = %arg0, %i1 = %arg1) -> (f32, f32) {
+    %sum0 = addf %i0, %arg0 : f32
+    %sum1 = addf %i1, %arg1 : f32
+    scf.yield %sum0, %sum1 : f32, f32
+  }
+  return %sum#0, %sum#1 : f32, f32
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[C3:.*]] = constant 3 : index
+  // CHECK-NEXT: %[[REM:.*]] = remi_signed %[[N]], %[[C3]]
+  // CHECK-NEXT: %[[UB:.*]] = subi %[[N]], %[[REM]]
+  // CHECK-NEXT: %[[SUM:.*]]:2 = scf.for {{.*}} = %[[C0]] to %[[UB]] step %[[C3]] iter_args
+  // CHECK: }
+  // CHECK-NEXT: %[[SUM1:.*]]:2 = scf.for {{.*}} = %[[UB]] to %[[N]] step %[[C1]] iter_args(%[[V1:.*]] = %[[SUM]]#0, %[[V2:.*]] = %[[SUM]]#1)
+  // CHECK: }
+  // CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1
+}
+
+// Two iter_args share one init value: the epilogue must still be initialized
+// from the matching results of the unrolled loop, slot by slot.
+// CHECK-LABEL: scf_loop_unroll_shared_init
+// CHECK-SAME: (%[[A0:.*]]: f32, %{{.*}}: f32, %{{.*}}: index)
+func @scf_loop_unroll_shared_init(%arg0 : f32, %arg1 : f32, %n : index) -> (f32, f32) {
+  %from = constant 0 : index
+  %step = constant 1 : index
+  %res:2 = scf.for %iv = %from to %n step %step iter_args(%i0 = %arg0, %i1 = %arg0) -> (f32, f32) {
+    %add = addf %i0, %arg1 : f32
+    %mul = mulf %i1, %arg1 : f32
+    scf.yield %add, %mul : f32, f32
+  }
+  return %res#0, %res#1 : f32, f32
+  // CHECK: %[[SUM:.*]]:2 = scf.for {{.*}} iter_args(%{{.*}} = %[[A0]], %{{.*}} = %[[A0]])
+  // CHECK: }
+  // CHECK-NEXT: %[[SUM1:.*]]:2 = scf.for {{.*}} iter_args(%{{.*}} = %[[SUM]]#0, %{{.*}} = %[[SUM]]#1)
+  // CHECK: }
+  // CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1
+}