diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -210,6 +210,15 @@
                        ValueRange operands, bool eq, bool lower = true);
 
+  /// Returns the bound for the identifier at `pos` from the inequality at
+  /// `ineqPos` as a 1-d affine value map (affine map + operands). The returned
+  /// affine value map can either be a lower bound or an upper bound depending
+  /// on the sign of atIneq(ineqPos, pos). Asserts if the row at `ineqPos` does
+  /// not involve the `pos`th identifier.
+  void getIneqAsAffineValueMap(unsigned pos, unsigned ineqPos,
+                               AffineValueMap &vmap,
+                               MLIRContext *context) const;
+
   /// Returns the constraint system as an integer set. Returns a null integer
   /// set if the system has no constraints, or if an integer set couldn't be
   /// constructed as a result of a local variable's explicit representation not
@@ -452,15 +461,17 @@
   /// affine expressions involving only the symbolic identifiers. `lb` and
   /// `ub` (along with the `boundFloorDivisor`) are set to represent the lower
   /// and upper bound associated with the constant difference: `lb`, `ub` have
-  /// the coefficients, and boundFloorDivisor, their divisor.
+  /// the coefficients, and boundFloorDivisor, their divisor. `minLbPos` and
+  /// `minUbPos`, if non-null, are set to the positions of the constant lower
+  /// bound and upper bound respectively (to the same row if both come from an
+  /// equality).
   /// Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with
-  /// three symbolic identifiers, *lb = [1, 0, 1], boundDivisor = 32. See
-  /// comments at function definition for examples.
-  Optional<int64_t>
-  getConstantBoundOnDimSize(unsigned pos,
-                            SmallVectorImpl<int64_t> *lb = nullptr,
-                            int64_t *boundFloorDivisor = nullptr,
-                            SmallVectorImpl<int64_t> *ub = nullptr) const;
+  /// three symbolic identifiers, *lb = [1, 0, 1], *boundFloorDivisor = 32. See
+  /// comments at function definition for examples.
+  Optional<int64_t> getConstantBoundOnDimSize(
+      unsigned pos, SmallVectorImpl<int64_t> *lb = nullptr,
+      int64_t *boundFloorDivisor = nullptr,
+      SmallVectorImpl<int64_t> *ub = nullptr, unsigned *minLbPos = nullptr,
+      unsigned *minUbPos = nullptr) const;
 
   /// Returns the constant lower bound for the pos^th identifier if there is
   /// one; None otherwise.
@@ -482,6 +493,20 @@
                                unsigned symStartPos,
                                ArrayRef<AffineExpr> localExprs,
                                MLIRContext *context) const;
 
+  /// Gather positions of all lower and upper bounds of the identifier at
+  /// `pos`, and optionally any equalities on it. In addition, the bounds are
+  /// to be independent of identifiers in position range
+  /// [`offset`, `offset` + `num`).
+  void
+  getLowerAndUpperBoundIndices(unsigned pos,
+                               SmallVectorImpl<unsigned> *lbIndices,
+                               SmallVectorImpl<unsigned> *ubIndices,
+                               SmallVectorImpl<unsigned> *eqIndices = nullptr,
+                               unsigned offset = 0, unsigned num = 0) const;
+
+  /// Removes constraints that are independent of (i.e., do not have a
+  /// coefficient for) identifiers in the range [pos, pos + num).
+  void removeIndependentConstraints(unsigned pos, unsigned num);
+
   /// Returns true if the set can be trivially detected as being
   /// hyper-rectangular on the specified contiguous set of identifiers.
   bool isHyperRectangular(unsigned pos, unsigned num) const;
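A minimal caller-side sketch of the extended query (the FlatAffineConstraints instance `cst` and the queried position are illustrative, not part of this patch):

  SmallVector<int64_t, 4> lb, ub;
  int64_t boundFloorDivisor;
  unsigned minLbPos, minUbPos;
  if (Optional<int64_t> extent = cst.getConstantBoundOnDimSize(
          /*pos=*/0, &lb, &boundFloorDivisor, &ub, &minLbPos, &minUbPos)) {
    // *extent is the constant size of identifier 0's range; minLbPos and
    // minUbPos identify the rows (two inequalities, or the same equality)
    // that realized the bound.
  }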
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -275,6 +275,16 @@
     /// list of AffineIf is not resizable.
     void setConditional(IntegerSet set, ValueRange operands);
 
+    Block *getThenBlock() {
+      assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
+      return &thenRegion().front();
+    }
+
+    Block *getElseBlock() {
+      assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
+      return &elseRegion().front();
+    }
+
     OpBuilder getThenBodyBuilder() {
       assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
       Block &body = thenRegion().front();
@@ -401,7 +411,7 @@
     /// Get ranges as constants, may fail in dynamic case.
     Optional<SmallVector<int64_t, 8>> getConstantRanges();
-
+
     Block *getBody();
     OpBuilder getBodyBuilder();
     void setSteps(ArrayRef<int64_t> newSteps);
diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -24,6 +24,7 @@
 class FuncOp;
 class OpBuilder;
 class Value;
+class ValueRange;
 struct MemRefRegion;
 
 namespace loop {
@@ -90,10 +91,12 @@
                             bool unrollPrologueEpilogue = false);
 
 /// Tiles the specified band of perfectly nested loops creating tile-space loops
-/// and intra-tile loops. A band is a contiguous set of loops.
+/// and intra-tile loops. A band is a contiguous set of loops. `tiledNest`, when
+/// non-null, is set to the loops of the tiled nest from outermost to innermost.
 LLVM_NODISCARD
 LogicalResult tileCodeGen(MutableArrayRef<AffineForOp> band,
-                          ArrayRef<unsigned> tileSizes);
+                          ArrayRef<unsigned> tileSizes,
+                          SmallVectorImpl<AffineForOp> *tiledNest = nullptr);
 
 /// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA'
 /// and 'forOpB' are part of a perfectly nested sequence of loops.
@@ -271,6 +274,29 @@
 void gatherLoops(FuncOp func,
                  std::vector<SmallVector<AffineForOp, 2>> &depthToLoops);
 
+/// Creates an AffineForOp while ensuring that the lower and upper bounds are
+/// canonicalized, i.e., unused and duplicate operands are removed, and any
+/// constant operands are propagated/folded in.
+AffineForOp createCanonicalizedAffineForOp(OpBuilder b, Location loc,
+                                           ValueRange lbOperands,
+                                           AffineMap lbMap,
+                                           ValueRange ubOperands,
+                                           AffineMap ubMap, int64_t step = 1);
+
+/// Separates full tiles from partial tiles for a perfect nest `nest` by
+/// generating a conditional guard that selects between the full tile version
+/// and the partial tile version using an AffineIfOp. The original loop nest
+/// is replaced by this guarded two-version form.
+///
+///    affine.if (cond)
+///      // full_tile
+///    else
+///      // partial tile
+///
+LogicalResult
+separateFullTiles(MutableArrayRef<AffineForOp> nest,
+                  SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_UTILS_H
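A hedged sketch of how the two new LoopUtils entry points compose; `band` and the tile sizes are assumed to come from the caller, mirroring what the LoopTiling pass does further below:

  SmallVector<AffineForOp, 6> tiledNest;
  if (succeeded(tileCodeGen(band, /*tileSizes=*/{32, 32}, &tiledNest))) {
    // The first band.size() loops are tile-space loops; the rest are the
    // intra-tile (point) loops, which is what separateFullTiles expects.
    auto intraTileLoops =
        MutableArrayRef<AffineForOp>(tiledNest).drop_front(band.size());
    (void)separateFullTiles(intraTileLoops);
  }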
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1200,38 +1200,58 @@
   return false;
 }
 
-/// Gather all lower and upper bounds of the identifier at `pos`. The bounds
-/// are to be independent of [offset, offset + num) identifiers.
-static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
-                                         unsigned pos,
-                                         SmallVectorImpl<unsigned> *lbIndices,
-                                         SmallVectorImpl<unsigned> *ubIndices,
-                                         unsigned offset = 0,
-                                         unsigned num = 0) {
-  assert(pos < cst.getNumIds() && "invalid position");
+/// Gather all lower and upper bounds of the identifier at `pos`, and
+/// optionally any equalities on it. In addition, the bounds are to be
+/// independent of identifiers in position range [`offset`, `offset` + `num`).
+void FlatAffineConstraints::getLowerAndUpperBoundIndices(
+    unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
+    SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
+    unsigned offset, unsigned num) const {
+  assert(pos < getNumIds() && "invalid position");
+  assert(offset + num < getNumCols() && "invalid range");
 
-  // Gather all lower bounds and upper bounds of the variable. Since the
-  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
-  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
-  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
-    // The bounds are to be independent of [offset, offset + num) columns.
+  // Checks for a constraint that has a non-zero coeff for the identifiers in
+  // the position range [offset, offset + num) while ignoring `pos`.
+  auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
     unsigned c, f;
+    auto cst = isEq ? getEquality(r) : getInequality(r);
     for (c = offset, f = offset + num; c < f; ++c) {
       if (c == pos)
         continue;
-      if (cst.atIneq(r, c) != 0)
+      if (cst[c] != 0)
         break;
     }
-    if (c < f)
+    return c < f;
+  };
+
+  // Gather all lower bounds and upper bounds of the variable. Since the
+  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    // The bounds are to be independent of [offset, offset + num) columns.
+    if (containsConstraintDependentOnRange(r, /*isEq=*/false))
       continue;
-    if (cst.atIneq(r, pos) >= 1) {
+    if (atIneq(r, pos) >= 1) {
       // Lower bound.
       lbIndices->push_back(r);
-    } else if (cst.atIneq(r, pos) <= -1) {
+    } else if (atIneq(r, pos) <= -1) {
       // Upper bound.
       ubIndices->push_back(r);
     }
   }
+
+  // An equality is both a lower and upper bound. Record any equalities
+  // involving the pos^th identifier.
+  if (!eqIndices)
+    return;
+
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    if (atEq(r, pos) == 0)
+      continue;
+    if (containsConstraintDependentOnRange(r, /*isEq=*/true))
+      continue;
+    eqIndices->push_back(r);
+  }
 }
 
 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
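As an illustration of the new member (a sketch, assuming `cst` encodes 0 <= i <= N - 1 with i at position 0):

  SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
  cst.getLowerAndUpperBoundIndices(/*pos=*/0, &lbIndices, &ubIndices,
                                   &eqIndices);
  // Rows with a positive coefficient on i (here i >= 0) land in lbIndices,
  // rows with a negative coefficient (here -i + N - 1 >= 0) in ubIndices,
  // and any equality involving i in eqIndices.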
@@ -1247,7 +1267,7 @@
   assert(pos < cst.getNumIds() && "invalid position");
 
   SmallVector<unsigned, 4> lbIndices, ubIndices;
-  getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
+  cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
 
   // Check if any lower bound, upper bound pair is of the form:
   // divisor * id >= expr - (divisor - 1)    <-- Lower bound for 'id'
@@ -1376,7 +1396,7 @@
          "incorrect local exprs count");
 
   SmallVector<unsigned, 4> lbIndices, ubIndices;
-  getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices);
+  getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices);
 
   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
@@ -1872,9 +1892,23 @@
   std::copy(equalities.begin() + inputIndex,
             equalities.begin() + inputIndex + numElemsToCopy,
             equalities.begin() + outputIndex);
+  assert(equalities.size() >= numReservedCols);
   equalities.resize(equalities.size() - numReservedCols);
 }
 
+void FlatAffineConstraints::removeInequality(unsigned pos) {
+  unsigned numInequalities = getNumInequalities();
+  assert(pos < numInequalities && "invalid position");
+  unsigned outputIndex = pos * numReservedCols;
+  unsigned inputIndex = (pos + 1) * numReservedCols;
+  unsigned numElemsToCopy = (numInequalities - pos - 1) * numReservedCols;
+  std::copy(inequalities.begin() + inputIndex,
+            inequalities.begin() + inputIndex + numElemsToCopy,
+            inequalities.begin() + outputIndex);
+  assert(inequalities.size() >= numReservedCols);
+  inequalities.resize(inequalities.size() - numReservedCols);
+}
+
 /// Finds an equality that equates the specified identifier to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
@@ -1951,14 +1985,22 @@
 // ceil(s0 - 7 / 8) = floor(s0 / 8)).
 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
-    SmallVectorImpl<int64_t> *ub) const {
+    SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
+    unsigned *minUbPos) const {
   assert(pos < getNumDimIds() && "Invalid identifier position");
-  assert(getNumLocalIds() == 0);
 
   // Find an equality for 'pos'^th identifier that equates it to some function
   // of the symbolic identifiers (+ constant).
   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
   if (eqPos != -1) {
+    auto eq = getEquality(eqPos);
+    // If the equality involves a local var, punt for now.
+    // TODO: this can be handled in the future by using the explicit
+    // representation of the local vars.
+    if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
+                     [](int64_t coeff) { return coeff == 0; }))
+      return None;
+
     // This identifier can only take a single value.
     if (lb) {
       // Set lb to that symbolic value.
@@ -1979,6 +2021,10 @@
              "both lb and divisor or none should be provided");
       *boundFloorDivisor = 1;
     }
+    if (minLbPos)
+      *minLbPos = eqPos;
+    if (minUbPos)
+      *minUbPos = eqPos;
     return 1;
   }
 
@@ -1999,8 +2045,8 @@
   // the bounds can only involve symbolic (and local) identifiers. Since the
   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
-  getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices,
-                               /*offset=*/0,
+  getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
+                               /*eqIndices=*/nullptr, /*offset=*/0,
                                /*num=*/getNumDimIds());
 
   Optional<int64_t> minDiff = None;
@@ -2054,6 +2100,10 @@
       // the constant term for the lower bound.
       (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
     }
+    if (minLbPos)
+      *minLbPos = minLbPosition;
+    if (minUbPos)
+      *minUbPos = minUbPosition;
   return minDiff;
 }
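A worked example for the inequality path above (illustrative, not from this patch): given the bound pair i - i0 >= 0 (lower) and -i + i0 + 31 >= 0 (upper), the constant terms of the two rows sum to 31, and adding one for the inclusive-to-exclusive adjustment gives an extent of 32; minLbPos and minUbPos are then set to those two inequality rows.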
@@ -2726,6 +2776,51 @@
          llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
 }
 
+void FlatAffineConstraints::getIneqAsAffineValueMap(
+    unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
+    MLIRContext *context) const {
+  unsigned numDims = getNumDimIds();
+  unsigned numSyms = getNumSymbolIds();
+
+  assert(pos < numDims && "invalid position");
+  assert(ineqPos < getNumInequalities() && "invalid inequality position");
+
+  // Get expressions for local vars.
+  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
+  if (failed(computeLocalVars(*this, memo, context)))
+    assert(false &&
+           "one or more local exprs do not have an explicit representation");
+  auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
+
+  // Compute the AffineExpr lower/upper bound for this inequality.
+  ArrayRef<int64_t> inequality = getInequality(ineqPos);
+  SmallVector<int64_t, 8> bound;
+  bound.reserve(getNumCols() - 1);
+  // Everything other than the coefficient at `pos`.
+  bound.append(inequality.begin(), inequality.begin() + pos);
+  bound.append(inequality.begin() + pos + 1, inequality.end());
+
+  if (inequality[pos] > 0)
+    // Lower bound.
+    std::transform(bound.begin(), bound.end(), bound.begin(),
+                   std::negate<int64_t>());
+  else
+    // Upper bound (which is exclusive).
+    bound.back() += 1;
+
+  // Convert to AffineExpr (tree) form.
+  auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
+                                             localExprs, context);
+
+  // Get the values to bind to this affine expr (all dims and symbols).
+  SmallVector<Value, 4> operands;
+  getIdValues(0, pos, &operands);
+  SmallVector<Value, 4> trailingOperands;
+  getIdValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
+  operands.append(trailingOperands.begin(), trailingOperands.end());
+  vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
+}
+
 /// Returns true if the pos^th column is all zero for both inequalities and
 /// equalities.
 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
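A minimal sketch of getIneqAsAffineValueMap (assumes `cst` has the dimension i at position 0, one symbol N, and the row -i + N - 1 >= 0 at inequality position 1):

  AffineValueMap ubVmap;
  cst.getIneqAsAffineValueMap(/*pos=*/0, /*ineqPos=*/1, ubVmap, context);
  // The coefficient of i is negative, so this is an upper bound; with the
  // exclusive adjustment (+1 to the constant term), the resulting map is
  // ()[s0] -> (s0), i.e., i < N, bound to the Value associated with N.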
@@ -2739,7 +2834,7 @@
     // Return universal set (always true): 0 == 0.
     return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
                            getAffineConstantExpr(/*constant=*/0, context),
-                           true);
+                           /*eqFlags=*/true);
 
   // Construct local references.
   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
@@ -2778,3 +2873,52 @@
                                             numSyms, localExprs, context));
   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
 }
+
+/// Find positions of inequalities and equalities that do not have a
+/// coefficient for identifiers in [pos, pos + num).
+static void getIndependentConstraints(const FlatAffineConstraints &cst,
+                                      unsigned pos, unsigned num,
+                                      SmallVectorImpl<unsigned> &nbIneqIndices,
+                                      SmallVectorImpl<unsigned> &nbEqIndices) {
+  assert(pos < cst.getNumIds() && "invalid start position");
+  assert(pos + num <= cst.getNumIds() && "invalid limit");
+
+  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+    // The bounds are to be independent of [pos, pos + num) columns.
+    unsigned c;
+    for (c = pos; c < pos + num; ++c) {
+      if (cst.atIneq(r, c) != 0)
+        break;
+    }
+    if (c == pos + num)
+      nbIneqIndices.push_back(r);
+  }
+
+  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
+    // The bounds are to be independent of [pos, pos + num) columns.
+    unsigned c;
+    for (c = pos; c < pos + num; ++c) {
+      if (cst.atEq(r, c) != 0)
+        break;
+    }
+    if (c == pos + num)
+      nbEqIndices.push_back(r);
+  }
+}
+
+void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
+                                                         unsigned num) {
+  assert(pos + num <= getNumIds() && "invalid range");
+
+  // Remove constraints that are independent of these identifiers.
+  SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
+  getIndependentConstraints(*this, pos, num, nbIneqIndices, nbEqIndices);
+
+  // Iterate in reverse so that indices don't have to be updated.
+  // TODO: This method can be made more efficient (because removal of each
+  // inequality leads to much shifting/copying in the underlying buffer).
+  for (auto nbIndex : llvm::reverse(nbIneqIndices))
+    removeInequality(nbIndex);
+  for (auto nbIndex : llvm::reverse(nbEqIndices))
+    removeEquality(nbIndex);
+}
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -1011,14 +1011,12 @@
 
 IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
   FlatAffineConstraints fac(set);
-  MLIRContext *context = set.getContext();
   if (fac.isEmpty())
     return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
-                                   context);
+                                   set.getContext());
   fac.removeTrivialRedundancy();
 
-  auto simplifiedSet = fac.getAsIntegerSet(context);
+  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
   assert(simplifiedSet && "guaranteed to succeed while roundtripping");
-
   return simplifiedSet;
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -15,7 +15,9 @@
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
@@ -33,6 +35,12 @@
     llvm::cl::desc("Set size of cache to tile for in KiB"),
     llvm::cl::cat(clOptionsCategory));
 
+// Separate full and partial tiles.
+static llvm::cl::opt<bool>
+    clSeparate("affine-tile-separate",
+               llvm::cl::desc("Separate full and partial tiles"),
+               llvm::cl::cat(clOptionsCategory));
+
 // Tile size to use for all loops (overrides -tile-sizes if provided).
 static llvm::cl::opt<unsigned>
     clTileSize("affine-tile-size",
@@ -176,11 +184,12 @@
 /// and intra-tile loops. A band is a contiguous set of loops.
 // TODO(bondhugula): handle non hyper-rectangular spaces.
 LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
-                                ArrayRef<unsigned> tileSizes) {
+                                ArrayRef<unsigned> tileSizes,
+                                SmallVectorImpl<AffineForOp> *tiledNest) {
+  // Check if the supplied for ops are all successively nested.
   assert(!band.empty() && "no loops in band");
   assert(band.size() == tileSizes.size() && "Too few/many tile sizes");
 
-  // Check if the supplied for op's are all successively nested.
   for (unsigned i = 1, e = band.size(); i < e; i++)
     assert(band[i].getParentOp() == band[i - 1] && "not a perfect nest / band");
 
@@ -248,6 +257,9 @@
   // Erase the old loop nest.
   rootAffineForOp.erase();
 
+  if (tiledNest)
+    *tiledNest = std::move(tiledLoops);
+
   return success();
 }
 
@@ -393,8 +405,16 @@
       diag << tSize << ' ';
     diag << "]\n";
   }
-  if (failed(tileCodeGen(band, tileSizes)))
+  SmallVector<AffineForOp, 6> tiledNest;
+  if (failed(tileCodeGen(band, tileSizes, &tiledNest)))
     return signalPassFailure();
+
+  // Separate full and partial tiles.
+  if (clSeparate) {
+    auto intraTileLoops =
+        MutableArrayRef<AffineForOp>(tiledNest).drop_front(band.size());
+    separateFullTiles(intraTileLoops);
+  }
 }
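With the flag in place, separation composes with tiling on the command line, mirroring the RUN line added to the test further below (the input file name is a placeholder):

  mlir-opt input.mlir -affine-loop-tile -affine-tile-size=32 -affine-tile-separate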
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -17,10 +17,12 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -2068,3 +2070,213 @@
     depthToLoops.pop_back();
   }
 }
+
+// TODO: if necessary, this can be extended to also compose in any
+// affine.applys, fold to constant if all result dimensions of the map are
+// constant (canonicalizeMapAndOperands below already does this for single
+// result bound maps), and use simplifyMap to perform algebraic
+// simplification.
+AffineForOp mlir::createCanonicalizedAffineForOp(
+    OpBuilder b, Location loc, ValueRange lbOperands, AffineMap lbMap,
+    ValueRange ubOperands, AffineMap ubMap, int64_t step) {
+  SmallVector<Value, 4> lowerOperands(lbOperands);
+  SmallVector<Value, 4> upperOperands(ubOperands);
+
+  fullyComposeAffineMapAndOperands(&lbMap, &lowerOperands);
+  canonicalizeMapAndOperands(&lbMap, &lowerOperands);
+  fullyComposeAffineMapAndOperands(&ubMap, &upperOperands);
+  canonicalizeMapAndOperands(&ubMap, &upperOperands);
+
+  return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
+                               step);
+}
+
+/// Creates an AffineIfOp that encodes the conditional to choose between
+/// the constant trip count version and an unknown trip count version of this
+/// nest of loops. This is used to separate partial and full tiles when
+/// `loops` consists of the intra-tile loops. The affine.if op is inserted at
+/// the builder insertion point of `b`.
+static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
+                                            OpBuilder b) {
+  if (loops.empty())
+    return nullptr;
+
+  auto *context = loops[0].getContext();
+
+  FlatAffineConstraints cst;
+  getIndexSet(loops, &cst);
+
+  // Remove constraints that are independent of these loop IVs.
+  cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());
+
+  // Construct the constraint set representing the guard for full tiles. The
+  // lower bound (and upper bound) corresponding to the full tile should be
+  // larger (and resp. smaller) than any other lower (or upper bound).
+  SmallVector<int64_t, 8> fullTileLb, fullTileUb;
+  for (auto loop : loops) {
+    // TODO: generalizing to non-unit strides is straightforward.
+    assert(loop.getStep() == 1 && "point loop step expected to be one");
+    // Mark everything symbols for the purpose of finding a constant diff
+    // pair.
+    cst.setDimSymbolSeparation(/*newSymbolCount=*/cst.getNumDimAndSymbolIds() -
+                               1);
+    unsigned fullTileLbPos, fullTileUbPos;
+    if (!cst.getConstantBoundOnDimSize(0, /*lb=*/nullptr,
+                                       /*boundFloorDivisor=*/nullptr,
+                                       /*ub=*/nullptr, &fullTileLbPos,
+                                       &fullTileUbPos)) {
+      LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
+      return nullptr;
+    }
+
+    SmallVector<unsigned, 4> lbIndices, ubIndices;
+    cst.getLowerAndUpperBoundIndices(/*pos=*/0, &lbIndices, &ubIndices);
+
+    auto fLb = cst.getInequality(fullTileLbPos);
+    auto fUb = cst.getInequality(fullTileUbPos);
+    fullTileLb.assign(fLb.begin(), fLb.end());
+    fullTileUb.assign(fUb.begin(), fUb.end());
+
+    // Full tile lower bound should be >= any other lower bound.
+    for (auto lbIndex : lbIndices)
+      for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
+        cst.atIneq(lbIndex, i) = fullTileLb[i] - cst.atIneq(lbIndex, i);
+
+    // Full tile upper bound should be <= any other upper bound.
+    for (auto ubIndex : ubIndices)
+      for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
+        cst.atIneq(ubIndex, i) -= fullTileUb[i];
+
+    cst.removeId(0);
+  }
+
+  // The previous step leads to all zeros for the full tile lb and ub position
+  // itself; remove those and any other duplicates / trivial redundancies.
+  cst.removeTrivialRedundancy();
+
+  // Turn everything into dims conservatively since we earlier turned all
+  // trailing ids past point loop IV into symbols. Some of these could be outer
+  // loop IVs; we'll canonicalize anyway.
+  cst.setDimSymbolSeparation(0);
+
+  IntegerSet ifCondSet = cst.getAsIntegerSet(context);
+  // ifCondSet can be null if cst was empty -- this can happen if all loops
+  // in the nest have constant trip counts.
+  if (!ifCondSet)
+    return nullptr;
+
+  SmallVector<Value, 4> setOperands;
+  cst.getIdValues(0, cst.getNumDimAndSymbolIds(), &setOperands);
+  canonicalizeSetAndOperands(&ifCondSet, &setOperands);
+  return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
+                              /*withElseRegion=*/true);
+}
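To make the guard construction concrete (matching the separate_full_tile_2d test below): for a point loop with i0 <= i <= min(i0 + 31, N - 1), the full-tile pair is (i - i0 >= 0, -i + i0 + 31 >= 0). Subtracting the full-tile upper bound row from the other upper bound row -i + N - 1 >= 0 leaves -i0 + N - 32 >= 0, which is exactly the per-dimension condition in the #[[SEP_COND]] affine_set the test expects.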
+
+/// Create the full tile loop nest (along with its body).
+static LogicalResult
+createFullTiles(MutableArrayRef<AffineForOp> inputNest,
+                SmallVectorImpl<AffineForOp> &fullTileLoops, OpBuilder b) {
+  fullTileLoops.reserve(inputNest.size());
+
+  // For each loop in the original nest identify a lower/upper bound pair such
+  // that their difference is a constant.
+  FlatAffineConstraints cst;
+  for (auto loop : inputNest) {
+    // TODO: straightforward to generalize to a non-unit stride.
+    if (loop.getStep() != 1) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "[tile separation] non-unit stride not implemented\n");
+      return failure();
+    }
+    getIndexSet({loop}, &cst);
+    // We will mark everything other than this loop IV as symbol for getting a
+    // pair of bounds (lb, ub) with a constant difference.
+    cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - 1);
+    unsigned lbPos, ubPos;
+    if (!cst.getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr,
+                                       /*boundFloorDivisor=*/nullptr,
+                                       /*ub=*/nullptr, &lbPos, &ubPos) ||
+        lbPos == ubPos) {
+      LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
+                                 "equalities not yet handled\n");
+      return failure();
+    }
+
+    // Set all identifiers as dimensions uniformly since some of those marked
+    // as symbols above could be outer loop IVs (corresponding tile space
+    // IVs).
+    cst.setDimSymbolSeparation(/*newSymbolCount=*/0);
+
+    AffineValueMap lbVmap, ubVmap;
+    cst.getIneqAsAffineValueMap(/*pos=*/0, lbPos, lbVmap, b.getContext());
+    cst.getIneqAsAffineValueMap(/*pos=*/0, ubPos, ubVmap, b.getContext());
+    AffineForOp fullTileLoop = createCanonicalizedAffineForOp(
+        b, loop.getLoc(), lbVmap.getOperands(), lbVmap.getAffineMap(),
+        ubVmap.getOperands(), ubVmap.getAffineMap());
+    b = fullTileLoop.getBodyBuilder();
+    fullTileLoops.push_back(fullTileLoop);
+  }
+
+  // Add the body for the full tile loop nest.
+  BlockAndValueMapping operandMap;
+  for (auto loopEn : llvm::enumerate(inputNest))
+    operandMap.map(loopEn.value().getInductionVar(),
+                   fullTileLoops[loopEn.index()].getInductionVar());
+  b = fullTileLoops.back().getBodyBuilder();
+  for (auto &op : inputNest.back().getBody()->without_terminator())
+    b.clone(op, operandMap);
+  return success();
+}
+
+LogicalResult
+mlir::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
+                        SmallVectorImpl<AffineForOp> *fullTileNest) {
+  if (inputNest.empty())
+    return success();
+
+  auto firstLoop = inputNest[0];
+
+  // Each successive for op has to be nested in the previous one.
+  auto prevLoop = firstLoop;
+  for (auto loop : inputNest.drop_front(1)) {
+    assert(loop.getParentOp() == prevLoop && "input not contiguously nested");
+    prevLoop = loop;
+  }
+
+  // Create the full tile loop nest.
+  SmallVector<AffineForOp, 4> fullTileLoops;
+  OpBuilder b(firstLoop);
+  if (failed(createFullTiles(inputNest, fullTileLoops, b))) {
+    if (!fullTileLoops.empty())
+      fullTileLoops.front().erase();
+    return failure();
+  }
+
+  // Create and insert the version select right before the root of the nest.
+  b = OpBuilder(firstLoop);
+  AffineIfOp ifOp = createSeparationCondition(inputNest, b);
+  if (!ifOp) {
+    fullTileLoops.front().erase();
+    LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
+                               "separation condition\n");
+    return failure();
+  }
+
+  // Move the full tile into the then block.
+  Block *thenBlock = ifOp.getThenBlock();
+  AffineForOp outermostFullTileLoop = fullTileLoops[0];
+  thenBlock->getOperations().splice(
+      std::prev(thenBlock->end()),
+      outermostFullTileLoop.getOperation()->getBlock()->getOperations(),
+      Block::iterator(outermostFullTileLoop));
+
+  // Move the partial tile into the else block. The partial tile is the same
+  // as the original loop nest.
+  Block *elseBlock = ifOp.getElseBlock();
+  elseBlock->getOperations().splice(
+      std::prev(elseBlock->end()),
+      firstLoop.getOperation()->getBlock()->getOperations(),
+      Block::iterator(firstLoop));
+
+  if (fullTileNest)
+    *fullTileNest = std::move(fullTileLoops);
+
+  return success();
+}
diff --git a/mlir/test/Dialect/Affine/loop-tiling.mlir b/mlir/test/Dialect/Affine/loop-tiling.mlir
--- a/mlir/test/Dialect/Affine/loop-tiling.mlir
+++ b/mlir/test/Dialect/Affine/loop-tiling.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -split-input-file -affine-loop-tile -affine-tile-size=32 | FileCheck %s
 // RUN: mlir-opt %s -split-input-file -affine-loop-tile -affine-tile-cache-size=512 | FileCheck %s --check-prefix=MODEL
+// RUN: mlir-opt %s -split-input-file -affine-loop-tile -affine-tile-size=32 -affine-tile-separate | FileCheck %s --check-prefix=SEPARATE
 
 // -----
 
@@ -169,6 +170,8 @@
 
 // -----
 
+// CHECK-LABEL: func @trip_count_1
+// SEPARATE-LABEL: func @trip_count_1
 func @trip_count_1(%arg0: memref<196608x1xf32>, %arg1: memref<196608x1xf32>) -> memref<196608x1xf32> {
   affine.for %i1 = 0 to 196608 {
@@ -177,8 +180,65 @@
       affine.store %4, %arg1[%i1, %i3] : memref<196608x1xf32>
     }
   }
+  // CHECK: affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<196608x1xf32>
   return %arg1 : memref<196608x1xf32>
 }
+// SEPARATE: return
 
-// CHECK: %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<196608x1xf32>
+// -----
+
+func @separate_full_tile_2d(%M : index, %N : index) {
+  affine.for %i0 = 0 to %M {
+    affine.for %i1 = 0 to %N {
+      "foo"() : () -> ()
+    }
+  }
+  return
+}
+
+// SEPARATE-DAG: #[[SEP_COND:.*]] = affine_set<(d0, d1)[s0, s1] : (-d0 + s0 - 32 >= 0, -d1 + s1 - 32 >= 0)>
+// SEPARATE-DAG: #[[LB:.*]] = affine_map<(d0) -> (d0)>
+// SEPARATE-DAG: #[[FULL_TILE_UB:.*]] = affine_map<(d0) -> (d0 + 32)>
+// SEPARATE-DAG: #[[PART_TILE_UB:.*]] = affine_map<(d0)[s0] -> (d0 + 32, s0)>
+
+// SEPARATE: affine.for %arg2
+// SEPARATE-NEXT: affine.for %arg3
+// SEPARATE-NEXT: affine.if #[[SEP_COND]](%arg2, %arg3)[%arg0, %arg1] {
+// SEPARATE-NEXT: affine.for %arg4 = #[[LB]](%arg2) to #[[FULL_TILE_UB]](%arg2) {
+// SEPARATE-NEXT: affine.for %arg5 = #[[LB]](%arg3) to #[[FULL_TILE_UB]](%arg3) {
+// SEPARATE-NEXT: "foo"
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: } else {
+// SEPARATE-NEXT: affine.for %arg4 = #[[LB]](%arg2) to min #[[PART_TILE_UB]](%arg2)[%arg0] {
+// SEPARATE-NEXT: affine.for %arg5 = #[[LB]](%arg3) to min #[[PART_TILE_UB]](%arg3)[%arg1] {
+// SEPARATE-NEXT: "foo"
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: return
+
+// -----
+
+func @separate_full_tile_1d_max_min(%M : index, %N : index, %P : index, %Q : index) {
+  affine.for %i0 = max affine_map<(d0, d1) -> (d0, d1)> (%M, %N) to min affine_map<(d0, d1) -> (d0, d1)> (%P, %Q) {
+  }
+  return
+}
+
+// SEPARATE-DAG: #[[SEP_COND:.*]] = affine_set<(d0)[s0, s1] : (-d0 + s0 - 32 >= 0, -d0 + s1 - 32 >= 0)>
+// SEPARATE-DAG: #[[TILE_LB:.*]] = affine_map<(d0) -> (d0)>
+// SEPARATE-DAG: #[[FULL_TILE_UB:.*]] = affine_map<(d0) -> (d0 + 32)>
+// SEPARATE-DAG: #[[PARTIAL_TILE_UB:.*]] = affine_map<(d0, d1, d2) -> (d2 + 32, d0, d1)>
+
+// SEPARATE: affine.for %arg4
+// SEPARATE-NEXT: affine.if #[[SEP_COND]](%arg4)[%arg2, %arg3] {
+// SEPARATE-NEXT: affine.for %arg5 = #[[TILE_LB]](%arg4) to #[[FULL_TILE_UB]](%arg4) {
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: } else {
+// SEPARATE-NEXT: affine.for %arg5 = #[[TILE_LB]](%arg4) to min #[[PARTIAL_TILE_UB]](%arg2, %arg3, %arg4) {
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: }
+// SEPARATE-NEXT: }