diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -206,6 +206,18 @@
   auto iv = forOp.getInductionVar();
   iv.replaceAllUsesWith(lbCstOp);
 
+  // Replace uses of iterArgs with iterOperands.
+  auto iterOperands = forOp.getIterOperands();
+  auto iterArgs = forOp.getRegionIterArgs();
+  for (auto e : llvm::zip(iterOperands, iterArgs))
+    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
+
+  // Replace uses of loop results with the values yielded by the loop.
+  auto outerResults = forOp.getResults();
+  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto e : llvm::zip(outerResults, innerResults))
+    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+
   // Move the loop body operations, except for its terminator, to the loop's
   // containing block.
   auto *parentBlock = forOp.getOperation()->getBlock();
@@ -649,6 +661,22 @@
                                       std::next(Block::iterator(forOp)));
     auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
     epilogueForOp.setLowerBound(upperBoundUnrolled);
+
+    // Update uses of loop results.
+    auto results = forOp.getResults();
+    auto epilogueResults = epilogueForOp.getResults();
+    for (auto e : llvm::zip(results, epilogueResults))
+      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+
+    // Seed the epilogue's iter_args with the unrolled loop's results. Rewrite
+    // the (trailing) iter operands by position: replaceUsesOfWith() matches by
+    // value, so it would also rewrite a bound, the step, or a sibling iter
+    // operand that happens to be the same SSA value as the initial value.
+    unsigned firstIterOperand =
+        epilogueForOp.getOperation()->getNumOperands() - results.size();
+    for (auto en : llvm::enumerate(results))
+      epilogueForOp.getOperation()->setOperand(firstIterOperand + en.index(),
+                                               en.value());
     promoteIfSingleIteration(epilogueForOp);
   }
 
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
+
+!tensor = type tensor<*xf32>
+
+// CHECK-LABEL: scf_loop_unroll
+func @scf_loop_unroll(%arg0 : !tensor, %arg1 : !tensor) -> !tensor {
+  %from = constant 0 : index
+  %to = constant 10 : index
+  %step = constant 1 : index
+  %sum = scf.for %iv = %from to %to step %step iter_args(%sum_iter = %arg0) -> (!tensor) {
+    %next = "add"(%sum_iter, %arg1) : (!tensor, !tensor) -> !tensor
+    scf.yield %next : !tensor
+  }
+  // CHECK: %[[SUM:.*]] = scf.for
+  // CHECK-NEXT: "add"
+  // CHECK-NEXT: "add"
+  // CHECK-NEXT: "add"
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: }
+  // CHECK-NEXT: %[[RES:.*]] = "add"(%[[SUM]],
+  // CHECK-NEXT: return %[[RES]]
+  return %sum : !tensor
+}
+
+// -----
+
+!tensor = type tensor<*xf32>
+
+// Two iter_args seeded with the same SSA value must each be threaded into
+// the epilogue from their own result of the unrolled loop.
+// CHECK-LABEL: scf_loop_unroll_shared_init
+func @scf_loop_unroll_shared_init(%arg0 : !tensor, %arg1 : !tensor) -> (!tensor, !tensor) {
+  %from = constant 0 : index
+  %to = constant 10 : index
+  %step = constant 1 : index
+  %res:2 = scf.for %iv = %from to %to step %step iter_args(%a = %arg0, %b = %arg0) -> (!tensor, !tensor) {
+    %next_a = "add"(%a, %arg1) : (!tensor, !tensor) -> !tensor
+    %next_b = "mul"(%b, %arg1) : (!tensor, !tensor) -> !tensor
+    scf.yield %next_a, %next_b : !tensor, !tensor
+  }
+  // CHECK: %[[SUM:.*]]:2 = scf.for
+  // CHECK: scf.yield
+  // CHECK-NEXT: }
+  // CHECK-NEXT: %[[A:.*]] = "add"(%[[SUM]]#0, %arg1)
+  // CHECK-NEXT: %[[B:.*]] = "mul"(%[[SUM]]#1, %arg1)
+  // CHECK-NEXT: return %[[A]], %[[B]]
+  return %res#0, %res#1 : !tensor, !tensor
+}