diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp @@ -101,134 +101,5 @@ return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor); } -LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp, - uint64_t unrollJamFactor) { - Optional mayBeConstantTripCount = getConstantTripCount(forOp); - - if (mayBeConstantTripCount.hasValue() && - mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forOp, unrollJamFactor); -} - -/// Unrolls and jams this loop by the specified factor. -LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, - uint64_t unrollJamFactor) { - // Gathers all maximal sub-blocks of operations that do not themselves - // include a for op (a operation could have a descendant for op though - // in its tree). Ignore the block terminators. - struct JamBlockGatherer { - // Store iterators to the first and last op of each sub-block found. - std::vector> subBlocks; - - // This is a linear time walk. - void walk(Operation *op) { - for (auto ®ion : op->getRegions()) - for (auto &block : region) - walk(block); - } - void walk(Block &block) { - for (auto it = block.begin(), e = std::prev(block.end()); it != e;) { - auto subBlockStart = it; - while (it != e && !isa(&*it)) - ++it; - if (it != subBlockStart) - subBlocks.push_back({subBlockStart, std::prev(it)}); - // Process all for insts that appear next. 
- while (it != e && isa(&*it)) - walk(&*it++); - } - } - }; - - assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - - if (unrollJamFactor == 1) - return promoteIfSingleIteration(forOp); - - if (forOp.getBody()->empty() || - forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) - return failure(); - - // Loops where both lower and upper bounds are multi-result maps won't be - // unrolled (since the trip can't be expressed as an affine function in - // general). - // TODO(mlir-team): this may not be common, but we could support the case - // where the lower bound is a multi-result map and the ub is a single result - // one. - if (forOp.getLowerBoundMap().getNumResults() != 1) - return failure(); - - Optional mayBeConstantTripCount = getConstantTripCount(forOp); - // If the trip count is lower than the unroll jam factor, no unroll jam. - if (mayBeConstantTripCount.hasValue() && - mayBeConstantTripCount.getValue() < unrollJamFactor) - return failure(); - - auto *forInst = forOp.getOperation(); - - // Gather all sub-blocks to jam upon the loop being unrolled. - JamBlockGatherer jbg; - jbg.walk(forInst); - auto &subBlocks = jbg.subBlocks; - - // Generate the cleanup loop if trip count isn't a multiple of - // unrollJamFactor. - if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) { - // Insert the cleanup loop right after 'forOp'. - OpBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto cleanupAffineForOp = cast(builder.clone(*forInst)); - // Adjust the lower bound of the cleanup loop; its upper bound is the same - // as the original loop's upper bound. - AffineMap cleanupMap; - SmallVector cleanupOperands; - getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap, - &cleanupOperands, builder); - cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap); - - // Promote the cleanup loop if it has turned into a single iteration loop. 
- promoteIfSingleIteration(cleanupAffineForOp); - - // Adjust the upper bound of the original loop - it will be the same as the - // cleanup loop's lower bound. Its lower bound remains unchanged. - forOp.setUpperBound(cleanupOperands, cleanupMap); - } - - // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forOp.getStep(); - forOp.setStep(step * unrollJamFactor); - - auto forOpIV = forOp.getInductionVar(); - // Unroll and jam (appends unrollJamFactor - 1 additional copies). - for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { - // Operand map persists across all sub-blocks. - BlockAndValueMapping operandMapping; - for (auto &subBlock : subBlocks) { - // Builder to insert unroll-jammed bodies. Insert right at the end of - // sub-block. - OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second)); - - // If the induction variable is used, create a remapping to the value for - // this unrolled instance. - if (!forOpIV.use_empty()) { - // iv' = iv + i, i = 1 to unrollJamFactor-1. - auto d0 = builder.getAffineDimExpr(0); - auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); - auto ivUnroll = - builder.create(forInst->getLoc(), bumpMap, forOpIV); - operandMapping.map(forOpIV, ivUnroll); - } - // Clone the sub-block being unroll-jammed. - for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { - builder.clone(*it, operandMapping); - } - } - } - - // Promote the loop body up if this has turned into a single iteration loop. 
- promoteIfSingleIteration(forOp); - return success(); -} - static PassRegistration pass("affine-loop-unroll-jam", "Unroll and jam loops"); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -486,6 +486,135 @@ return success(); } +LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp, + uint64_t unrollJamFactor) { + Optional mayBeConstantTripCount = getConstantTripCount(forOp); + + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollJamFactor) + return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor); +} + +/// Unrolls and jams this loop by the specified factor (must be >= 1). +LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, + uint64_t unrollJamFactor) { + // Gathers all maximal sub-blocks of operations that do not themselves + // include a for op (an operation could have a descendant for op though + // in its tree). Ignore the block terminators. + struct JamBlockGatherer { + // Store iterators to the first and last op of each sub-block found. + std::vector> subBlocks; + + // This is a linear time walk. + void walk(Operation *op) { + for (auto ®ion : op->getRegions()) + for (auto &block : region) + walk(block); + } + void walk(Block &block) { + for (auto it = block.begin(), e = std::prev(block.end()); it != e;) { + auto subBlockStart = it; + while (it != e && !isa(&*it)) + ++it; + if (it != subBlockStart) + subBlocks.push_back({subBlockStart, std::prev(it)}); + // Process all for ops that appear next. 
+ while (it != e && isa(&*it)) + walk(&*it++); + } + } + }; + + assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); + + if (unrollJamFactor == 1) + return promoteIfSingleIteration(forOp); + + if (forOp.getBody()->empty() || + forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) + return failure(); + + // Loops whose lower bound is a multi-result map won't be unroll-jammed + // (since the trip count can't be expressed as an affine function in + // general). + // TODO(mlir-team): this may not be common, but we could support the case + // where the lower bound is a multi-result map and the ub is a single result + // one. + if (forOp.getLowerBoundMap().getNumResults() != 1) + return failure(); + + Optional mayBeConstantTripCount = getConstantTripCount(forOp); + // If the trip count is lower than the unroll jam factor, no unroll jam. + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollJamFactor) + return failure(); + + auto *forInst = forOp.getOperation(); + + // Gather all sub-blocks to jam upon the loop being unrolled. + JamBlockGatherer jbg; + jbg.walk(forInst); + auto &subBlocks = jbg.subBlocks; + + // Generate the cleanup loop if trip count isn't a multiple of + // unrollJamFactor. + if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) { + // Insert the cleanup loop right after 'forOp'. + OpBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); + auto cleanupAffineForOp = cast(builder.clone(*forInst)); + // Adjust the lower bound of the cleanup loop; its upper bound is the same + // as the original loop's upper bound. + AffineMap cleanupMap; + SmallVector cleanupOperands; + getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap, + &cleanupOperands, builder); + cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap); + + // Promote the cleanup loop if it has turned into a single iteration loop. 
+ promoteIfSingleIteration(cleanupAffineForOp); + + // Adjust the upper bound of the original loop - it will be the same as the + // cleanup loop's lower bound. Its lower bound remains unchanged. + forOp.setUpperBound(cleanupOperands, cleanupMap); + } + + // Scale the step of loop being unroll-jammed by the unroll-jam factor. + int64_t step = forOp.getStep(); + forOp.setStep(step * unrollJamFactor); + + auto forOpIV = forOp.getInductionVar(); + // Unroll and jam (appends unrollJamFactor - 1 additional copies). + for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { + // Operand map persists across all sub-blocks. + BlockAndValueMapping operandMapping; + for (auto &subBlock : subBlocks) { + // Builder to insert unroll-jammed bodies. Insert right at the end of + // sub-block. + OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second)); + + // If the induction variable is used, create a remapping to the value for + // this unrolled instance. + if (!forOpIV.use_empty()) { + // iv' = iv + i * step, i = 1 to unrollJamFactor-1. + auto d0 = builder.getAffineDimExpr(0); + auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); + auto ivUnroll = + builder.create(forInst->getLoc(), bumpMap, forOpIV); + operandMapping.map(forOpIV, ivUnroll); + } + // Clone the sub-block being unroll-jammed. + for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { + builder.clone(*it, operandMapping); + } + } + } + + // Promote the loop body up if this has turned into a single iteration loop. + promoteIfSingleIteration(forOp); + return success(); +} + /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is /// nested within 'forOpA' as the only non-terminator operation in its block. void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {