diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -206,10 +206,22 @@
   auto iv = forOp.getInductionVar();
   iv.replaceAllUsesWith(lbCstOp);
+  // Replace uses of iterArgs with iterOperands.
+  auto iterOperands = forOp.getIterOperands();
+  auto iterArgs = forOp.getRegionIterArgs();
+  for (auto e : llvm::zip(iterOperands, iterArgs))
+    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
+
+  // Replace uses of loop results with the values yielded by the loop.
+  auto outerResults = forOp.getResults();
+  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto e : llvm::zip(outerResults, innerResults))
+    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+
   // Move the loop body operations, except for its terminator, to the loop's
   // containing block.
   auto *parentBlock = forOp.getOperation()->getBlock();
-  forOp.getBody()->back().erase();
+  forOp.getBody()->getTerminator()->erase();
   parentBlock->getOperations().splice(Block::iterator(forOp),
                                       forOp.getBody()->getOperations());
   forOp.erase();
@@ -418,10 +430,559 @@
   return success();
 }
 
-// Collect perfectly nested loops starting from `rootForOps`. Loops are
-// perfectly nested if each loop is the first and only non-terminator operation
-// in the parent loop. Collect at most `maxLoops` loops and append them to
-// `forOps`.
+/// Checks the legality of tiling of a hyper-rectangular loop nest by simply
+/// checking if there is a 'negative' dependence in the memrefs present in
+/// the loop nest. If there is, then tiling is invalid.
+static bool
+checkTilingLegalityImpl(MutableArrayRef<AffineForOp> origLoops) {
+  assert(!origLoops.empty() && "no original loops provided");
+
+  // We first find out all dependences we intend to check.
+  SmallVector<Operation *, 8> loadAndStoreOps;
+  origLoops[0].getOperation()->walk([&](Operation *op) {
+    if (isa<AffineLoadOp, AffineStoreOp>(op))
+      loadAndStoreOps.push_back(op);
+  });
+
+  unsigned numOps = loadAndStoreOps.size();
+  unsigned numLoops = origLoops.size();
+  FlatAffineConstraints dependenceConstraints;
+  for (unsigned d = 1; d <= numLoops + 1; ++d) {
+    for (unsigned i = 0; i < numOps; ++i) {
+      Operation *srcOp = loadAndStoreOps[i];
+      MemRefAccess srcAccess(srcOp);
+      for (unsigned j = 0; j < numOps; ++j) {
+        Operation *dstOp = loadAndStoreOps[j];
+        MemRefAccess dstAccess(dstOp);
+
+        SmallVector<DependenceComponent, 2> depComps;
+        dependenceConstraints.reset();
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
+
+        // Skip if there is no dependence in this case.
+        if (!hasDependence(result))
+          continue;
+
+        // Check whether there is any negative direction vector in the
+        // dependence components found above, which means that the dependence
+        // is violated by the default hyper-rectangular tiling method.
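+        // For example, in the nest below, iteration (%i, %j) stores to
+        // %A[%i, %j] and iteration (%i + 1, %j - 1) loads it back, a
+        // dependence with direction vector (1, -1); a negative trailing
+        // component of this kind is what invalidates rectangular tiling,
+        // since a tile executing later in tile space would have to produce
+        // a value needed by an earlier one:
+        //
+        //   affine.for %i = 1 to %N {
+        //     affine.for %j = 0 to %M {
+        //       %v = affine.load %A[%i - 1, %j + 1] : memref<?x?xf32>
+        //       affine.store %v, %A[%i, %j] : memref<?x?xf32>
+        //     }
+        //   }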
+        LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality is "
+                                   "violated for dependence at depth: "
+                                << Twine(d) << " between:\n");
+        LLVM_DEBUG(srcAccess.opInst->dump());
+        LLVM_DEBUG(dstAccess.opInst->dump());
+        for (unsigned k = 0, e = depComps.size(); k < e; k++) {
+          DependenceComponent depComp = depComps[k];
+          if (depComp.lb.hasValue() && depComp.ub.hasValue() &&
+              depComp.lb.getValue() < depComp.ub.getValue() &&
+              depComp.ub.getValue() < 0) {
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Dependence component lb = "
+                       << Twine(depComp.lb.getValue())
+                       << " ub = " << Twine(depComp.ub.getValue())
+                       << " is negative at depth: " << Twine(d)
+                       << " and thus violates the legality rule.\n");
+            return false;
+          }
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
+/// Checks whether hyper-rectangular loop tiling of the nest represented by
+/// `origLoops` is valid. The validity condition is from Irigoin and Triolet,
+/// which states that two tiles cannot depend on each other. We simplify this
+/// condition to just checking whether there is any negative dependence
+/// direction, since we have the prior knowledge that the tiling results will
+/// be hyper-rectangles, which are scheduled in lexicographically increasing
+/// order on the vector of loop indices. This function returns failure when
+/// any dependence component is negative along any of `origLoops`.
+LogicalResult
+checkTilingLegality(MutableArrayRef<AffineForOp> origLoops) {
+  return success(checkTilingLegalityImpl(origLoops));
+}
+
+/// Checks whether the input is valid and whether the tiled code will be legal.
+template <typename T>
+void performPreTilingChecks(MutableArrayRef<AffineForOp> input,
+                            ArrayRef<T> tileSizes) {
+  // Check if the supplied `for` ops are all successively nested.
+  assert(!input.empty() && "no loops in input band");
+  assert(input.size() == tileSizes.size() && "too few/many tile sizes");
+
+  assert(isPerfectlyNested(input) && "input loops not perfectly nested");
+
+  // Perform the tiling legality test.
+  if (failed(checkTilingLegality(input)))
+    input[0].emitRemark("tiled code is illegal due to dependences");
+}
+
+/// Move the loop body of AffineForOp 'src' from 'src' into the specified
+/// location in destination's body, ignoring the terminator.
+static void moveLoopBodyImpl(AffineForOp src, AffineForOp dest,
+                             Block::iterator loc) {
+  auto &ops = src.getBody()->getOperations();
+  dest.getBody()->getOperations().splice(loc, ops, ops.begin(),
+                                         std::prev(ops.end()));
+}
+
+/// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's
+/// body.
+void moveLoopBody(AffineForOp src, AffineForOp dest) {
+  moveLoopBodyImpl(src, dest, dest.getBody()->begin());
+}
+
+/// Constructs a tiled loop nest, without setting the loop bounds, and moves
+/// the body of the original loop nest to the innermost loop of the tiled
+/// nest.
+void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
+                            AffineForOp rootAffineForOp, unsigned width,
+                            MutableArrayRef<AffineForOp> tiledLoops) {
+  Location loc = rootAffineForOp.getLoc();
+
+  // The outermost loop of the nest so far; each new loop is created around it.
+  Operation *topLoop = rootAffineForOp.getOperation();
+  AffineForOp innermostPointLoop;
+
+  // Add intra-tile (or point) loops.
+  for (unsigned i = 0; i < width; i++) {
+    OpBuilder b(topLoop);
+    // Loop bounds will be set later.
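+    // The loops are created inside out around whatever is currently the
+    // topmost loop; the 0..0 bounds are placeholders. After both phases,
+    // `tiledLoops` holds, e.g. for width == 2:
+    //   tiledLoops[0] = %ii (tile space, outermost)
+    //   tiledLoops[1] = %jj (tile space)
+    //   tiledLoops[2] = %i  (intra-tile)
+    //   tiledLoops[3] = %j  (intra-tile, innermost)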
+    AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0);
+    pointLoop.getBody()->getOperations().splice(
+        pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
+        topLoop);
+    tiledLoops[2 * width - 1 - i] = pointLoop;
+    topLoop = pointLoop.getOperation();
+    if (i == 0)
+      innermostPointLoop = pointLoop;
+  }
+
+  // Add tile space loops.
+  for (unsigned i = width; i < 2 * width; i++) {
+    OpBuilder b(topLoop);
+    // Loop bounds will be set later.
+    AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
+    tileSpaceLoop.getBody()->getOperations().splice(
+        tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
+        topLoop);
+    tiledLoops[2 * width - i - 1] = tileSpaceLoop;
+    topLoop = tileSpaceLoop.getOperation();
+  }
+
+  // Move the loop body of the original nest to the new one.
+  moveLoopBody(origLoops.back(), innermostPointLoop);
+}
+
+/// Checks whether a loop nest is hyper-rectangular or not.
+LogicalResult checkIfHyperRectangular(MutableArrayRef<AffineForOp> input,
+                                      AffineForOp rootAffineForOp,
+                                      unsigned width) {
+  FlatAffineConstraints cst;
+  SmallVector<Operation *, 8> ops(input.begin(), input.end());
+  getIndexSet(ops, &cst);
+  if (!cst.isHyperRectangular(0, width)) {
+    rootAffineForOp.emitError("tiled code generation unimplemented for the "
+                              "non-hyperrectangular case");
+    return failure();
+  }
+  return success();
+}
+
+/// Set lower and upper bounds of intra-tile loops for parametric tiling.
+// TODO: Handle non-constant lower bounds.
+static void setIntraTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
+                                         AffineForOp newInterTileLoop,
+                                         AffineForOp newIntraTileLoop,
+                                         Value tileSize) {
+  // The lower bound of the intra-tile loop is represented by an affine map
+  // as (%i, %t0) -> ((%i - %origlb) * %t0 + %origlb). Similarly, the upper
+  // bound is represented as
+  // (%i, %t0) -> ((%i - %origlb) * %t0 + %t0 * %origLoopStep + %origlb),
+  // where %i is the loop IV of the corresponding inter-tile loop, %t0 is the
+  // corresponding tiling parameter, %origlb is the lower bound and
+  // %origLoopStep is the step of the original loop.
+
+  assert(origLoop.hasConstantLowerBound() &&
+         "expected input loops to have constant lower bound.");
+
+  // Get the lower bound of the original loop as an affine expression.
+  AffineExpr origLowerBoundExpr;
+  origLowerBoundExpr =
+      b.getAffineConstantExpr(origLoop.getConstantLowerBound());
+
+  // Add dim operands from the original lower/upper bound.
+  SmallVector<Value, 4> lbOperands, ubOperands;
+  AffineBound lb = origLoop.getLowerBound();
+  AffineBound ub = origLoop.getUpperBound();
+  lbOperands.reserve(lb.getNumOperands() + 2);
+  ubOperands.reserve(ub.getNumOperands() + 2);
+  AffineMap origLbMap = lb.getMap();
+  AffineMap origUbMap = ub.getMap();
+  for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
+    lbOperands.push_back(lb.getOperand(j));
+  for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
+    ubOperands.push_back(ub.getOperand(j));
+
+  // Add a new dim operand in lb/ubOperands corresponding to the new
+  // inter-tile loop IV.
+  lbOperands.push_back(newInterTileLoop.getInductionVar());
+  ubOperands.push_back(newInterTileLoop.getInductionVar());
+
+  // Get the loop IV as an affine expression for the lower/upper bound. The
+  // size of lb/ubOperands is guaranteed to be at least one.
+  AffineExpr lbLoopIvExpr = b.getAffineDimExpr(lbOperands.size() - 1);
+  AffineExpr ubLoopIvExpr = b.getAffineDimExpr(ubOperands.size() - 1);
+
+  // Add symbol operands from the original lower/upper bound.
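+  // (The dim operands of an affine bound precede its symbol operands, which
+  // is why the symbols below are appended only after the extra IV dim added
+  // above.)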
+  for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
+    lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
+  for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
+    ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
+
+  // Add a new symbol operand which is the tile size for this loop.
+  lbOperands.push_back(tileSize);
+  ubOperands.push_back(tileSize);
+
+  SmallVector<AffineExpr, 4> lbBoundExprs;
+  SmallVector<AffineExpr, 4> ubBoundExprs;
+  lbBoundExprs.reserve(origLbMap.getNumResults());
+  ubBoundExprs.reserve(origUbMap.getNumResults());
+
+  // Get the tiling parameter as an affine expression for the lb/ub.
+  AffineExpr lbTileParameter = b.getAffineSymbolExpr(origLbMap.getNumSymbols());
+  AffineExpr ubTileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
+
+  // Insert lb as (inter-tile loop IV - origlb) * tilingParameter + origlb.
+  lbBoundExprs.push_back(
+      ((lbLoopIvExpr - origLowerBoundExpr) * lbTileParameter) +
+      origLowerBoundExpr);
+
+  // Get the origLoopStep as an affine expression.
+  AffineExpr origLoopStep = b.getAffineConstantExpr(origLoop.getStep());
+
+  // Insert ub as (inter-tile loop IV - origlb) * tilingParameter +
+  // (tilingParameter * origLoopStep) + origlb.
+  ubBoundExprs.push_back(
+      ((ubLoopIvExpr - origLowerBoundExpr) * ubTileParameter) +
+      (ubTileParameter * origLoopStep) + origLowerBoundExpr);
+
+  ubBoundExprs.append(origUbMap.getResults().begin(),
+                      origUbMap.getResults().end());
+
+  AffineMap lbMap =
+      AffineMap::get(origLbMap.getNumDims() + 1, origLbMap.getNumSymbols() + 1,
+                     lbBoundExprs, b.getContext());
+  newIntraTileLoop.setLowerBound(lbOperands, lbMap);
+
+  AffineMap ubMap =
+      AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols() + 1,
+                     ubBoundExprs, b.getContext());
+  newIntraTileLoop.setUpperBound(ubOperands, ubMap);
+
+  // The original loop step must be preserved.
+  newIntraTileLoop.setStep(origLoop.getStep());
+}
+
+/// Set lower and upper bounds of inter-tile loops for parametric tiling.
+// TODO: Handle non-constant lower bounds.
+static void setInterTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
+                                         AffineForOp newLoop, Value tileSize) {
+  OperandRange newLbOperands = origLoop.getLowerBoundOperands();
+
+  // The lower bounds for inter-tile loops are the same as the corresponding
+  // lower bounds of the original loops.
+  newLoop.setLowerBound(newLbOperands, origLoop.getLowerBoundMap());
+
+  // The new upper bound map for inter-tile loops, assuming constant lower
+  // bounds, is now originalLowerBound + ceildiv(originalUpperBound -
+  // originalLowerBound, tilingParameter), where the tiling parameter is the
+  // respective tile size for that loop. For example, if the original ubmap
+  // was () -> (1024), the new map will be
+  // ()[s0] -> (lb + (1024 - lb) ceildiv s0), where s0 is the tiling
+  // parameter. Therefore, a new symbol operand is inserted in the map and
+  // the result expression is overwritten.
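+  // Using ceildiv rather than floordiv means the inter-tile loop also covers
+  // a final partial tile: e.g. with lb = 0, ub = 10 and tile parameter
+  // %t = 4, the inter-tile loop becomes affine.for %ii = 0 to 3, i.e.
+  // ceildiv(10, 4) = 3 tiles, the last of which is a partial tile clamped by
+  // the intra-tile loop's min with the original upper bound.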
+  assert(origLoop.hasConstantLowerBound() &&
+         "expected input loops to have constant lower bound.");
+
+  // Get the lower bound of the original loop as an affine expression.
+  AffineExpr origLowerBoundExpr;
+  origLowerBoundExpr =
+      b.getAffineConstantExpr(origLoop.getConstantLowerBound());
+
+  // Add dim operands from the original upper bound.
+  SmallVector<Value, 4> ubOperands;
+  AffineBound ub = origLoop.getUpperBound();
+  ubOperands.reserve(ub.getNumOperands() + 1);
+  AffineMap origUbMap = ub.getMap();
+  for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
+    ubOperands.push_back(ub.getOperand(j));
+
+  // Add symbol operands from the original upper bound.
+  for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
+    ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
+
+  // Add a new symbol operand which is the tile size for this loop.
+  ubOperands.push_back(tileSize);
+
+  // Get the tiling parameter as an affine expression.
+  AffineExpr tileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
+
+  SmallVector<AffineExpr, 4> boundExprs;
+  boundExprs.reserve(origUbMap.getNumResults());
+  int64_t origUpperBound;
+  AffineExpr origUpperBoundExpr;
+
+  // If the upper bound of the original loop is constant, then the constant
+  // can be obtained as an affine expression straight away.
+  if (origLoop.hasConstantUpperBound()) {
+    origUpperBound = origLoop.getConstantUpperBound();
+
+    // Get the original constant upper bound as an affine expression.
+    origUpperBoundExpr = b.getAffineConstantExpr(origUpperBound);
+
+    // Insert the bound as originalLowerBound + ceildiv(originalUpperBound -
+    // originalLowerBound, tilingParameter).
+    boundExprs.push_back(
+        origLowerBoundExpr +
+        (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
+  } else {
+    // If the upper bound of the original loop is not constant, then two
+    // cases are possible, although their handling is the same:
+    // 1) The ubmap has only one result expression, e.g.
+    //      affine.for %i = 5 to %ub
+    //    A symbol operand is added to represent the tiling parameter. The
+    //    new loop bound here will be ()[s0, s1] -> ((s0 - 5) ceildiv s1 + 5),
+    //    where 's0' is the original upper bound and 's1' is the tiling
+    //    parameter.
+    // 2) The ubmap has more than one result expression, e.g.
+    //      #map0 = affine_map<()[s0, s1] -> (s0, s1)>
+    //      affine.for %i = 5 to min #map0()[%s0, %s1]
+    //    A symbol operand is added to represent the tiling parameter. The
+    //    new loop bounds will be ()[s0, s1, s2] -> ((s0 - 5) ceildiv s2 + 5,
+    //    (s1 - 5) ceildiv s2 + 5), where s2 is the tiling parameter.
+
+    // Insert the bounds as originalLowerBound + ceildiv(originalUpperBound -
+    // originalLowerBound, tilingParameter).
+    for (AffineExpr origUpperBoundExpr : origUbMap.getResults())
+      boundExprs.push_back(
+          origLowerBoundExpr +
+          (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
+  }
+
+  AffineMap ubMap =
+      AffineMap::get(origUbMap.getNumDims(), origUbMap.getNumSymbols() + 1,
+                     boundExprs, b.getContext());
+  newLoop.setUpperBound(ubOperands, ubMap);
+
+  // The original loop step must be preserved.
+  newLoop.setStep(origLoop.getStep());
+}
+
+/// Constructs and sets new loop bounds after tiling for the case of
+/// hyper-rectangular index sets, where the bounds of one dimension do not
+/// depend on other dimensions and the tiling parameters are captured from
+/// SSA values. The bounds of each dimension can thus be treated
+/// independently, and deriving the new bounds is much simpler and faster
+/// than for the case of tiling arbitrary polyhedral shapes.
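+/// The `newLoops` array is expected to hold the inter-tile loops in its
+/// first half and the intra-tile loops in its second half, in the order
+/// produced by constructTiledLoopNest above.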
+static void constructParametricallyTiledIndexSetHyperRect(
+    MutableArrayRef<AffineForOp> origLoops,
+    MutableArrayRef<AffineForOp> newLoops, ArrayRef<Value> tileSizes) {
+  assert(!origLoops.empty() && "expected at least one loop in band");
+  assert(origLoops.size() == tileSizes.size() &&
+         "expected tiling parameter for each loop in band.");
+
+  OpBuilder b(origLoops[0].getOperation());
+  unsigned width = origLoops.size();
+
+  // Set bounds for tile space loops.
+  for (unsigned i = 0; i < width; ++i) {
+    setInterTileBoundsParametric(b, origLoops[i], newLoops[i], tileSizes[i]);
+  }
+
+  // Set bounds for intra-tile loops.
+  for (unsigned i = 0; i < width; ++i) {
+    setIntraTileBoundsParametric(b, origLoops[i], newLoops[i],
+                                 newLoops[i + width], tileSizes[i]);
+  }
+}
+
+/// Constructs and sets new loop bounds after tiling for the case of
+/// hyper-rectangular index sets, where the bounds of one dimension do not
+/// depend on other dimensions. The bounds of each dimension can thus be
+/// treated independently, and deriving the new bounds is much simpler and
+/// faster than for the case of tiling arbitrary polyhedral shapes.
+static void
+constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
+                                MutableArrayRef<AffineForOp> newLoops,
+                                ArrayRef<unsigned> tileSizes) {
+  assert(!origLoops.empty());
+  assert(origLoops.size() == tileSizes.size());
+
+  OpBuilder b(origLoops[0].getOperation());
+  unsigned width = origLoops.size();
+
+  // Bounds for tile space loops.
+  for (unsigned i = 0; i < width; i++) {
+    OperandRange newLbOperands = origLoops[i].getLowerBoundOperands();
+    OperandRange newUbOperands = origLoops[i].getUpperBoundOperands();
+    newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
+    newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
+    newLoops[i].setStep(tileSizes[i]);
+  }
+  // Bounds for intra-tile loops.
+  for (unsigned i = 0; i < width; i++) {
+    int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
+    Optional<uint64_t> mayBeConstantCount = getConstantTripCount(origLoops[i]);
+    // The lower bound is just the tile-space loop.
+    AffineMap lbMap = b.getDimIdentityMap();
+    newLoops[width + i].setLowerBound(
+        /*operands=*/newLoops[i].getInductionVar(), lbMap);
+
+    // Set the upper bound.
+    if (mayBeConstantCount && mayBeConstantCount.getValue() < tileSizes[i]) {
+      // The trip count is less than the tile size: the upper bound is the
+      // lower bound plus the trip count.
+      AffineMap ubMap =
+          b.getSingleDimShiftAffineMap(mayBeConstantCount.getValue());
+      newLoops[width + i].setUpperBound(
+          /*operands=*/newLoops[i].getInductionVar(), ubMap);
+    } else if (largestDiv % tileSizes[i] != 0) {
+      // The intra-tile loop ii goes from i to min(i + tileSize, ub_i).
+      // Construct the upper bound map; the operands are the original
+      // operands with 'i' (the tile-space loop IV) appended to them. The new
+      // upper bound map is the original one with an additional expression
+      // i + tileSize appended.
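+      // For example, if the original upper bound is min(%N, 512) (i.e.
+      // origUbMap = ()[s0] -> (s0, 512)) and the tile size is 32, the new
+      // map is (d0)[s0] -> (d0 + 32, s0, 512), applied to the operands
+      // (%ii)[%N]: effectively min(%ii + 32, %N, 512).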
+      // Add dim operands from the original upper bound.
+      SmallVector<Value, 4> ubOperands;
+      AffineBound ub = origLoops[i].getUpperBound();
+      ubOperands.reserve(ub.getNumOperands() + 1);
+      AffineMap origUbMap = ub.getMap();
+      for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
+        ubOperands.push_back(ub.getOperand(j));
+
+      // Add a dim operand for the new loop upper bound.
+      ubOperands.push_back(newLoops[i].getInductionVar());
+
+      // Add symbol operands from the original upper bound.
+      for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
+        ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
+
+      SmallVector<AffineExpr, 4> boundExprs;
+      boundExprs.reserve(1 + origUbMap.getNumResults());
+      AffineExpr dim = b.getAffineDimExpr(origUbMap.getNumDims());
+      // The new upper bound map is the original one with an additional
+      // expression i + tileSize appended.
+      boundExprs.push_back(dim + tileSizes[i]);
+      boundExprs.append(origUbMap.getResults().begin(),
+                        origUbMap.getResults().end());
+      AffineMap ubMap =
+          AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(),
+                         boundExprs, b.getContext());
+      newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
+    } else {
+      // No need for the min expression.
+      AffineExpr dim = b.getAffineDimExpr(0);
+      AffineMap ubMap = AffineMap::get(1, 0, dim + tileSizes[i]);
+      newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap);
+    }
+  }
+}
+
+/// Tiles the specified band of perfectly nested loops, creating tile-space
+/// loops and intra-tile loops. A band is a contiguous set of loops.
+// TODO: handle non hyper-rectangular spaces.
+LogicalResult
+mlir::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
+                          ArrayRef<unsigned> tileSizes,
+                          SmallVectorImpl<AffineForOp> *tiledNest) {
+  performPreTilingChecks(input, tileSizes);
+
+  MutableArrayRef<AffineForOp> origLoops = input;
+  AffineForOp rootAffineForOp = origLoops[0];
+  // Note that width is at least one since the band isn't empty.
+  unsigned width = input.size();
+  SmallVector<AffineForOp, 6> tiledLoops(2 * width);
+
+  // Construct a tiled loop nest without setting the loop bounds; the bounds
+  // are set later.
+  constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
+
+  SmallVector<Value, 8> origLoopIVs;
+  extractForInductionVars(input, &origLoopIVs);
+
+  if (failed(checkIfHyperRectangular(input, rootAffineForOp, width)))
+    return failure();
+
+  // Set loop bounds for the tiled loop nest.
+  constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes);
+
+  // Replace original IVs with intra-tile loop IVs.
+  for (unsigned i = 0; i < width; i++)
+    origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
+
+  // Erase the old loop nest.
+  rootAffineForOp.erase();
+
+  if (tiledNest)
+    *tiledNest = std::move(tiledLoops);
+
+  return success();
+}
+
+/// Tiles the specified band of perfectly nested loops, creating tile-space
+/// loops and intra-tile loops, using SSA values as tiling parameters. A band
+/// is a contiguous set of loops.
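+/// For example (with a lower bound of 0 and unit step), tiling
+/// `affine.for %i = 0 to %N` by an SSA tile size %t produces an inter-tile
+/// loop over tile indices 0 to ceildiv(%N, %t) and an intra-tile loop over
+/// iterations %ii * %t to min((%ii + 1) * %t, %N).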
+// TODO: handle non hyper-rectangular spaces.
+LogicalResult
+mlir::tilePerfectlyNestedParametric(MutableArrayRef<AffineForOp> input,
+                                    ArrayRef<Value> tileSizes,
+                                    SmallVectorImpl<AffineForOp> *tiledNest) {
+  performPreTilingChecks(input, tileSizes);
+
+  MutableArrayRef<AffineForOp> origLoops = input;
+  AffineForOp rootAffineForOp = origLoops[0];
+  // Note that width is at least one since the band isn't empty.
+  unsigned width = input.size();
+  SmallVector<AffineForOp, 6> tiledLoops(2 * width);
+
+  // Construct a tiled loop nest without setting the loop bounds; the bounds
+  // are set later.
+  constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
+
+  SmallVector<Value, 8> origLoopIVs;
+  extractForInductionVars(input, &origLoopIVs);
+
+  if (failed(checkIfHyperRectangular(input, rootAffineForOp, width)))
+    return failure();
+
+  // Set loop bounds for the tiled loop nest.
+  constructParametricallyTiledIndexSetHyperRect(origLoops, tiledLoops,
+                                                tileSizes);
+
+  // Replace original IVs with intra-tile loop IVs.
+  for (unsigned i = 0; i < width; i++)
+    origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
+
+  // Erase the old loop nest.
+  rootAffineForOp.erase();
+
+  if (tiledNest)
+    *tiledNest = std::move(tiledLoops);
+
+  return success();
+}
+
+/// Collect perfectly nested loops starting from `rootForOps`. Loops are
+/// perfectly nested if each loop is the first and only non-terminator
+/// operation in the parent loop. Collect at most `maxLoops` loops and append
+/// them to `forOps`.
 template <typename T>
 static void getPerfectlyNestedLoopsImpl(
     SmallVectorImpl<T> &forOps, T rootForOp,
@@ -452,6 +1013,20 @@
   getPerfectlyNestedLoopsImpl(nestedLoops, root);
 }
 
+/// Identify valid and profitable bands of loops to tile. This is currently
+/// just a temporary placeholder to test the mechanics of tiled code
+/// generation. Returns all maximal outermost perfect loop nests to tile.
+void mlir::getTileableBands(FuncOp f,
+                            std::vector<SmallVector<AffineForOp, 6>> *bands) {
+  // Get maximal perfect nests of 'affine.for' ops starting from the roots
+  // (inclusive).
+  for (AffineForOp forOp : f.getOps<AffineForOp>()) {
+    SmallVector<AffineForOp, 6> band;
+    getPerfectlyNestedLoops(band, forOp);
+    bands->push_back(band);
+  }
+}
+
 /// Unrolls this loop completely.
 LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
@@ -478,9 +1053,10 @@
 // Generates unrolled copies of AffineForOp or scf::ForOp 'loopBodyBlock', with
 // associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 // 'forOpIV' for each unrolled body.
-static void generateUnrolledLoop(
-    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
-    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn) {
+static void
+generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+                     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
+                     ValueRange iterArgs, ValueRange yieldedValues) {
   // Builder to insert unrolled bodies just before the terminator of the body
   // of 'forOp'.
   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
@@ -490,9 +1066,14 @@
   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
 
   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
+  SmallVector<Value, 4> lastYielded(yieldedValues);
+
   for (unsigned i = 1; i < unrollFactor; i++) {
     BlockAndValueMapping operandMap;
 
+    // Prepare the operand map: the iter args are remapped to the values
+    // yielded by the previous (unrolled) iteration.
+    operandMap.map(iterArgs, lastYielded);
+
     // If the induction variable is used, create a remapping to the value for
     // this unrolled instance.
     if (!forOpIV.use_empty()) {
@@ -503,7 +1084,14 @@
     // Clone the original body of 'forOp'.
     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
      builder.clone(*it, operandMap);
+
+    // Update the yielded values to those produced by this clone.
+    for (unsigned i = 0; i < lastYielded.size(); i++)
+      lastYielded[i] = operandMap.lookup(yieldedValues[i]);
   }
+
+  // Update the operands of the yield terminator to the values produced by
+  // the last unrolled iteration.
+  loopBodyBlock->getTerminator()->setOperands(lastYielded);
 }
 
 /// Unrolls this loop by the specified factor. Returns success if the loop
@@ -564,7 +1152,8 @@
         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
         return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv);
-      });
+      },
+      /*iterArgs=*/{}, /*yieldedValues=*/{});
 
   // Promote the loop body up if this has turned into a single iteration loop.
   promoteIfSingleIteration(forOp);
@@ -649,19 +1238,36 @@
                                      std::next(Block::iterator(forOp)));
     auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
     epilogueForOp.setLowerBound(upperBoundUnrolled);
+
+    // Update uses of loop results.
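+    // The epilogue loop must take over all external uses of the unrolled
+    // loop's results, and must itself consume those results as its iter
+    // operands, so that the accumulated values flow from the unrolled loop
+    // into the epilogue.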
+    auto results = forOp.getResults();
+    auto epilogueResults = epilogueForOp.getResults();
+    auto epilogueIterOperands = epilogueForOp.getIterOperands();
+
+    for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) {
+      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+      epilogueForOp.getOperation()->replaceUsesOfWith(std::get<2>(e),
+                                                      std::get<0>(e));
+    }
     promoteIfSingleIteration(epilogueForOp);
   }
 
   // Create unrolled loop.
   forOp.setUpperBound(upperBoundUnrolled);
   forOp.setStep(stepUnrolled);
-  generateUnrolledLoop(forOp.getBody(), forOp.getInductionVar(), unrollFactor,
-                       [&](unsigned i, Value iv, OpBuilder b) {
-                         // iv' = iv + step * i;
-                         auto stride = b.create<MulIOp>(
-                             loc, step, b.create<ConstantIndexOp>(loc, i));
-                         return b.create<AddIOp>(loc, iv, stride);
-                       });
+
+  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
+  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
+
+  generateUnrolledLoop(
+      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
+      [&](unsigned i, Value iv, OpBuilder b) {
+        // iv' = iv + step * i;
+        auto stride =
+            b.create<MulIOp>(loc, step, b.create<ConstantIndexOp>(loc, i));
+        return b.create<AddIOp>(loc, iv, stride);
+      },
+      iterArgs, yieldedValues);
   // Promote the loop body up if this has turned into a single iteration loop.
   promoteIfSingleIteration(forOp);
   return success();
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
+
+// CHECK-LABEL: scf_loop_unroll_single
+func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
+  %from = constant 0 : index
+  %to = constant 10 : index
+  %step = constant 1 : index
+  %sum = scf.for %iv = %from to %to step %step iter_args(%sum_iter = %arg0) -> (f32) {
+    %next = addf %sum_iter, %arg1 : f32
+    scf.yield %next : f32
+  }
+  // CHECK:      %[[SUM:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[V0:.*]] =
+  // CHECK-NEXT:   %[[V1:.*]] = addf %[[V0]]
+  // CHECK-NEXT:   %[[V2:.*]] = addf %[[V1]]
+  // CHECK-NEXT:   %[[V3:.*]] = addf %[[V2]]
+  // CHECK-NEXT:   scf.yield %[[V3]]
+  // CHECK-NEXT: }
+  // CHECK-NEXT: %[[RES:.*]] = addf %[[SUM]],
+  // CHECK-NEXT: return %[[RES]]
+  return %sum : f32
+}
+
+// CHECK-LABEL: scf_loop_unroll_double_symbolic_ub
+// CHECK-SAME:  (%{{.*}}: f32, %{{.*}}: f32, %[[N:.*]]: index)
+func @scf_loop_unroll_double_symbolic_ub(%arg0 : f32, %arg1 : f32, %n : index) -> (f32, f32) {
+  %from = constant 0 : index
+  %step = constant 1 : index
+  %sum:2 = scf.for %iv = %from to %n step %step iter_args(%i0 = %arg0, %i1 = %arg1) -> (f32, f32) {
+    %sum0 = addf %i0, %arg0 : f32
+    %sum1 = addf %i1, %arg1 : f32
+    scf.yield %sum0, %sum1 : f32, f32
+  }
+  return %sum#0, %sum#1 : f32, f32
+  // CHECK:      %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[C3:.*]] = constant 3 : index
+  // CHECK-NEXT: %[[REM:.*]] = remi_signed %[[N]], %[[C3]]
+  // CHECK-NEXT: %[[UB:.*]] = subi %[[N]], %[[REM]]
+  // CHECK-NEXT: %[[SUM:.*]]:2 = scf.for {{.*}} = %[[C0]] to %[[UB]] step %[[C3]] iter_args
+  // CHECK:      }
+  // CHECK-NEXT: %[[SUM1:.*]]:2 = scf.for {{.*}} = %[[UB]] to %[[N]] step %[[C1]] iter_args(%[[V1:.*]] = %[[SUM]]#0, %[[V2:.*]] = %[[SUM]]#1)
+  // CHECK:      }
+  // CHECK-NEXT: return %[[SUM1]]#0, %[[SUM1]]#1
+}