diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -45,6 +45,10 @@ /// whichever is lower. LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor); +/// Returns true if `loops` is a perfectly nested loop nest, where loops appear +/// in it from outermost to innermost. +bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef loops); + /// Get perfectly nested sequence of loops starting at root of loop nest /// (the first op being another AffineFor, and the second op - a terminator). /// A loop is perfectly nested iff: the first op in the loop's body is another @@ -84,10 +88,12 @@ /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. `tiledNest` when /// non-null is set to the loops of the tiled nest from outermost to innermost. +/// Loops in `input` are erased when the tiling is successful. LLVM_NODISCARD -LogicalResult tileCodeGen(MutableArrayRef band, - ArrayRef tileSizes, - SmallVectorImpl *tiledNest = nullptr); +LogicalResult +tilePerfectlyNested(MutableArrayRef input, + ArrayRef tileSizes, + SmallVectorImpl *tiledNest = nullptr); /// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA' /// and 'forOpB' are part of a perfectly nested sequence of loops. diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp @@ -93,10 +93,8 @@ // Bounds for tile space loops. for (unsigned i = 0; i < width; i++) { - auto lbOperands = origLoops[i].getLowerBoundOperands(); - auto ubOperands = origLoops[i].getUpperBoundOperands(); - SmallVector newLbOperands(lbOperands); - SmallVector newUbOperands(ubOperands); + OperandRange newLbOperands = origLoops[i].getLowerBoundOperands(); + OperandRange newUbOperands = origLoops[i].getUpperBoundOperands(); newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap()); newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap()); newLoops[i].setStep(tileSizes[i]); @@ -111,8 +109,7 @@ /*operands=*/newLoops[i].getInductionVar(), lbMap); // Set the upper bound. - if (mayBeConstantCount.hasValue() && - mayBeConstantCount.getValue() < tileSizes[i]) { + if (mayBeConstantCount && mayBeConstantCount.getValue() < tileSizes[i]) { // Trip count is less than tile size; upper bound is the trip count. auto ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue()); newLoops[width + i].setUpperBoundMap(ubMap); @@ -121,20 +118,22 @@ // Construct the upper bound map; the operands are the original operands // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. - auto ub = origLoops[i].getUpperBound(); + + // Add dim operands from original upper bound. SmallVector ubOperands; + auto ub = origLoops[i].getUpperBound(); ubOperands.reserve(ub.getNumOperands() + 1); auto origUbMap = ub.getMap(); - // Add dim operands from original upper bound. - for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j) { + for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j) ubOperands.push_back(ub.getOperand(j)); - } + // Add dim operand for new loop upper bound. ubOperands.push_back(newLoops[i].getInductionVar()); + // Add symbol operands from original upper bound. - for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j) { + for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j) ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j)); - } + SmallVector boundExprs; boundExprs.reserve(1 + origUbMap.getNumResults()); auto dim = b.getAffineDimExpr(origUbMap.getNumDims()); @@ -159,22 +158,22 @@ /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -LogicalResult mlir::tileCodeGen(MutableArrayRef band, - ArrayRef tileSizes, - SmallVectorImpl *tiledNest) { +LogicalResult +mlir::tilePerfectlyNested(MutableArrayRef input, + ArrayRef tileSizes, + SmallVectorImpl *tiledNest) { // Check if the supplied for op's are all successively nested. - assert(!band.empty() && "no loops in band"); - assert(band.size() == tileSizes.size() && "Too few/many tile sizes"); + assert(!input.empty() && "no loops in input band"); + assert(input.size() == tileSizes.size() && "Too few/many tile sizes"); - for (unsigned i = 1, e = band.size(); i < e; i++) - assert(band[i].getParentOp() == band[i - 1] && "not a perfect nest / band"); + assert(isPerfectlyNested(input) && "input loops not perfectly nested"); - auto origLoops = band; + auto origLoops = input; AffineForOp rootAffineForOp = origLoops[0]; auto loc = rootAffineForOp.getLoc(); // Note that width is at least one since band isn't empty. - unsigned width = band.size(); + unsigned width = input.size(); SmallVector tiledLoops(2 * width); @@ -209,14 +208,13 @@ } // Move the loop body of the original nest to the new one. - moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); + moveLoopBody(origLoops.back(), innermostPointLoop); SmallVector origLoopIVs; - extractForInductionVars(band, &origLoopIVs); - SmallVector, 6> ids(origLoopIVs.begin(), origLoopIVs.end()); - FlatAffineConstraints cst; - getIndexSet(band, &cst); + extractForInductionVars(input, &origLoopIVs); + FlatAffineConstraints cst; + getIndexSet(input, &cst); if (!cst.isHyperRectangular(0, width)) { llvm::dbgs() << "tiled code generation unimplemented for the " "non-hyperrectangular case, op:" @@ -258,15 +256,15 @@ getMaximalPerfectLoopNest(forOp); } -// Reduce each tile size to the largest divisor of the corresponding trip count -// (if the trip count is known). +/// Reduces each tile size to the largest divisor of the corresponding trip +/// count (if the trip count is known). static void adjustToDivisorsOfTripCounts(ArrayRef band, SmallVectorImpl *tileSizes) { assert(band.size() == tileSizes->size() && "invalid tile size count"); for (unsigned i = 0, e = band.size(); i < e; i++) { unsigned &tSizeAdjusted = (*tileSizes)[i]; auto mayConst = getConstantTripCount(band[i]); - if (!mayConst.hasValue()) + if (!mayConst) continue; // Adjust the tile size to largest factor of the trip count less than // tSize. @@ -289,8 +287,8 @@ if (band.empty()) return; - // Use tileSize for all loops if specified. - if (tileSize.hasValue()) { + // Use command-line tileSize for all loops if specified. + if (tileSize) { tileSizes->assign(band.size(), tileSize); return; } @@ -312,7 +310,7 @@ // footprint increases with the tile size linearly in that dimension (i.e., // assumes one-to-one access function). auto fp = getMemoryFootprintBytes(band[0], 0); - if (!fp.hasValue()) { + if (!fp) { // Fill with default tile sizes if footprint is unknown. std::fill(tileSizes->begin(), tileSizes->end(), LoopTiling::kDefaultTileSize); @@ -339,7 +337,7 @@ // one possible approach. Or compute a polynomial in tile sizes and solve for // it. - // For an n-d tileable band, compute n^th root of the excess. + // For an n-d tileable band, compute the n^th root of the excess. unsigned tSize = static_cast(floorl(std::pow(excessFactor, 1.0 / band.size()))); // We'll keep a running product to determine the last tile size better. @@ -375,7 +373,7 @@ diag << "]\n"; } SmallVector tiledNest; - if (failed(tileCodeGen(band, tileSizes, &tiledNest))) + if (failed(tilePerfectlyNested(band, tileSizes, &tiledNest))) return signalPassFailure(); // Separate full and partial tiles. diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -707,8 +707,8 @@ /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -static bool LLVM_ATTRIBUTE_UNUSED -isPerfectlyNested(ArrayRef loops) { +bool LLVM_ATTRIBUTE_UNUSED +mlir::isPerfectlyNested(ArrayRef loops) { assert(!loops.empty() && "no loops provided"); // We already know that the block can't be empty.