diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp @@ -50,6 +50,10 @@ llvm::cl::cat(clOptionsCategory)); namespace { + +// TODO: this is really a test pass and should be moved out of dialect +// transforms. + /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a /// full unroll threshold was specified, in which case, fully unrolls all loops /// with trip count less than the specified threshold. The latter is for testing @@ -76,44 +80,32 @@ }; } // end anonymous namespace -void LoopUnroll::runOnFunction() { - // Gathers all innermost loops through a post order pruned walk. - struct InnermostLoopGatherer { - // Store innermost loops as we walk. - std::vector<AffineForOp> loops; - - void walkPostOrder(FuncOp f) { - for (auto &b : f) - walkPostOrder(b.begin(), b.end()); - } - - bool walkPostOrder(Block::iterator Start, Block::iterator End) { - bool hasInnerLoops = false; - // We need to walk all elements since all innermost loops need to be - // gathered as opposed to determining whether this list has any inner - // loops or not. - while (Start != End) - hasInnerLoops |= walkPostOrder(&(*Start++)); - return hasInnerLoops; - } - bool walkPostOrder(Operation *opInst) { - bool hasInnerLoops = false; - for (auto &region : opInst->getRegions()) - for (auto &block : region) - hasInnerLoops |= walkPostOrder(block.begin(), block.end()); - if (isa<AffineForOp>(opInst)) { - if (!hasInnerLoops) - loops.push_back(cast<AffineForOp>(opInst)); - return true; - } - return hasInnerLoops; - } - }; +/// Returns true if no other affine.for ops are nested within. +static bool isInnermostAffineForOp(AffineForOp forOp) { + // Only for the innermost affine.for op's. + bool isInnermost = true; + forOp.walk([&](AffineForOp thisForOp) { + // Since this is a post order walk, we are able to conclude here.
+ isInnermost = (thisForOp == forOp); + return WalkResult::interrupt(); + }); + return isInnermost; +} +/// Gathers loops that have no affine.for's nested within. +static void gatherInnermostLoops(FuncOp f, + SmallVectorImpl<AffineForOp> &loops) { + f.walk([&](AffineForOp forOp) { + if (isInnermostAffineForOp(forOp)) + loops.push_back(forOp); + }); +} + +void LoopUnroll::runOnFunction() { if (clUnrollFull.getNumOccurrences() > 0 && clUnrollFullThreshold.getNumOccurrences() > 0) { // Store short loops as we walk. - std::vector<AffineForOp> loops; + SmallVector<AffineForOp, 4> loops; // Gathers all loops with trip count <= minTripCount. Do a post order walk // so that loops are gathered from innermost to outermost (or else unrolling @@ -133,10 +125,10 @@ : 1; // If the call back is provided, we will recurse until no loops are found. FuncOp func = getFunction(); + SmallVector<AffineForOp, 4> loops; for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { - InnermostLoopGatherer ilg; - ilg.walkPostOrder(func); - auto &loops = ilg.loops; + loops.clear(); + gatherInnermostLoops(func, loops); if (loops.empty()) break; bool unrolled = false;