diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -108,11 +108,15 @@
 bool isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
                                        ArrayRef<unsigned> loopPermMap);
 
-/// Performs a sequence of loop interchanges on perfectly nested 'loops', as
-/// specified by permutation 'loopPermMap' (loop 'i' in 'loops' is mapped to
-/// location 'j = 'loopPermMap[i]' after the loop interchange).
-unsigned interchangeLoops(ArrayRef<AffineForOp> loops,
-                          ArrayRef<unsigned> loopPermMap);
+/// Performs a loop permutation on a perfectly nested loop nest `inputNest`
+/// (where the contained loops appear from outer to inner) as specified by the
+/// permutation `permMap`: loop 'i' in `inputNest` is mapped to location
+/// 'permMap[i]', where positions 0, 1, ... are from the outermost position
+/// to inner. Returns the position in `inputNest` of the AffineForOp that
+/// becomes the new outermost loop of this nest. This method always succeeds;
+/// it asserts on invalid input / specifications.
+unsigned permuteLoops(ArrayRef<AffineForOp> inputNest,
+                      ArrayRef<unsigned> permMap);
 
 // Sinks all sequential loops to the innermost levels (while preserving
 // relative order among them) and moves all parallel loops to the
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -700,19 +700,56 @@
   return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
 }
 
-/// Performs a sequence of loop interchanges of loops in perfectly nested
-/// sequence of loops in 'loops', as specified by permutation in 'loopPermMap'.
-unsigned mlir::interchangeLoops(ArrayRef<AffineForOp> loops,
-                                ArrayRef<unsigned> loopPermMap) {
+/// Returns true if `loops` (ordered from outermost to innermost) form a
+/// perfect nest, i.e., the body of every loop but the innermost consists of
+/// just the next loop and the terminator. `loops` must be non-empty.
+static bool isPerfectlyNested(ArrayRef<AffineForOp> loops) {
+  auto outerLoop = loops.front();
+  for (auto loop : loops.drop_front()) {
+    auto parentForOp = dyn_cast<AffineForOp>(loop.getParentOp());
+    // parentForOp's body should be just this loop and the terminator.
+    if (parentForOp != outerLoop ||
+        parentForOp.getBody()->getOperations().size() != 2)
+      return false;
+    outerLoop = loop;
+  }
+  return true;
+}
+
+// input[i] should move from position i -> permMap[i]. Returns the position in
+// `input` that becomes the new outermost loop.
+unsigned mlir::permuteLoops(ArrayRef<AffineForOp> input,
+                            ArrayRef<unsigned> permMap) {
+  assert(input.size() == permMap.size() && "invalid permutation map size");
+  // Check whether the permutation spec is valid. This is a small vector - we'll
+  // just sort and check if it's iota.
+  SmallVector<unsigned, 4> checkPermMap(permMap.begin(), permMap.end());
+  llvm::sort(checkPermMap);
+  if (llvm::any_of(llvm::enumerate(checkPermMap),
+                   [](const auto &en) { return en.value() != en.index(); }))
+    assert(false && "invalid permutation map");
+
+  // Nothing to do.
+  if (input.size() < 2)
+    return 0;
+
+  assert(isPerfectlyNested(input) && "input not perfectly nested");
+
   Optional<unsigned> loopNestRootIndex;
-  for (int i = loops.size() - 1; i >= 0; --i) {
-    int permIndex = static_cast<int>(loopPermMap[i]);
+  for (int i = input.size() - 1; i >= 0; --i) {
+    int permIndex = static_cast<int>(permMap[i]);
     // Store the index of the for loop which will be the new loop nest root.
     if (permIndex == 0)
       loopNestRootIndex = i;
-    if (permIndex > i) {
-      // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
-      sinkLoop(loops[i], permIndex - i);
+    // Invariant: loops 0 .. i are still at their original positions, while
+    // loops i + 1 .. n - 1 form the innermost part of the nest and are already
+    // in their relative target order. Sink loop 'i' below each of those loops
+    // that must end up outside of it to extend that ordering to loop 'i'.
+    unsigned numLevels =
+        llvm::count_if(permMap.drop_front(i + 1),
+                       [&](unsigned pos) { return pos < permMap[i]; });
+    if (numLevels > 0) {
+      sinkLoop(input[i], numLevels);
     }
   }
   assert(loopNestRootIndex.hasValue());
@@ -770,7 +807,7 @@
   if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
     return forOp;
   // Perform loop interchange according to permutation 'loopPermMap'.
-  unsigned loopNestRootIndex = interchangeLoops(loops, loopPermMap);
+  unsigned loopNestRootIndex = permuteLoops(loops, loopPermMap);
   return loops[loopNestRootIndex];
 }
 