diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -111,56 +111,66 @@ bool replaceIterOperandsUsesInLoop) { if (loopNest.empty()) return {}; - SmallVector newLoopNest(loopNest.size()); - - newLoopNest.back() = replaceLoopWithNewYields( - builder, loopNest.back(), newIterOperands, newYieldValueFn); - - for (unsigned loopDepth : - llvm::reverse(llvm::seq(0, loopNest.size() - 1))) { - NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc, - ArrayRef innerNewBBArgs) { - SmallVector newYields( - newLoopNest[loopDepth + 1]->getResults().take_back( - newIterOperands.size())); - return newYields; - }; - newLoopNest[loopDepth] = - replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands, - fn, replaceIterOperandsUsesInLoop); - if (!replaceIterOperandsUsesInLoop) { - /// The yield is expected to producer the following structure - /// ``` - /// %0 = scf.for ... iter_args(%arg0 = %init) { - /// %1 = scf.for ... iter_args(%arg1 = %arg0) { - /// scf.yield %yield - /// } - /// } - /// ``` - /// - /// since the yield is propagated from inside out, after the inner - /// loop is processed the IR is in this form - /// - /// ``` - /// scf.for ... iter_args { - /// %1 = scf.for ... iter_args(%arg1 = %init) { - /// scf.yield %yield - /// } - /// ``` - /// - /// If `replaceIterOperandUsesInLoops` is true, there is nothing to do. - /// `%init` will be replaced with `%arg0` when it is created for the - /// outer loop. But without that this has to be done explicitly. - unsigned subLen = newIterOperands.size(); - unsigned subStart = - newLoopNest[loopDepth + 1].getNumIterOperands() - subLen; - auto resetOperands = - newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart, - subLen); - resetOperands.assign( - newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen)); - } + // This method is recursive (to make it more readable). Adding an + // assertion here to limit the recursion. (See + // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235) + assert(loopNest.size() <= 6 && + "exceeded recursion limit when yielding value from loop nest"); + + // To yield a value from a perfectly nested loop nest, the following + // pattern needs to be created, i.e. starting with + // + // ```mlir + // scf.for .. { + // scf.for .. { + // scf.for .. { + // %value = ... + // } + // } + // } + // ``` + // + // needs to be modified to + // + // ```mlir + // %0 = scf.for .. iter_args(%arg0 = %init) { + // %1 = scf.for .. iter_args(%arg1 = %arg0) { + // %2 = scf.for .. iter_args(%arg2 = %arg1) { + // %value = ... + // scf.yield %value + // } + // scf.yield %2 + // } + // scf.yield %1 + // } + // ``` + // + // The inner most loop is handled using the `replaceLoopWithNewYields` + // that works on a single loop. + if (loopNest.size() == 1) { + auto innerMostLoop = replaceLoopWithNewYields( + builder, loopNest.back(), newIterOperands, newYieldValueFn, + replaceIterOperandsUsesInLoop); + return {innerMostLoop}; } + // The outer loops are modified by calling this method recursively + // - The return value of the inner loop is the value yielded by this loop. + // - The region iter args of this loop are the init_args for the inner loop. + SmallVector newLoopNest; + NewYieldValueFn fn = + [&](OpBuilder &innerBuilder, Location loc, + ArrayRef innerNewBBArgs) -> SmallVector { + newLoopNest = replaceLoopNestWithNewYields(builder, loopNest.drop_front(), + innerNewBBArgs, newYieldValueFn, + replaceIterOperandsUsesInLoop); + return llvm::to_vector(llvm::map_range( + newLoopNest.front().getResults().take_back(innerNewBBArgs.size()), + [](OpResult r) -> Value { return r; })); + }; + scf::ForOp outerMostLoop = + replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn, + replaceIterOperandsUsesInLoop); + newLoopNest.insert(newLoopNest.begin(), outerMostLoop); return newLoopNest; }