diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -22,6 +22,9 @@
 namespace mlir {
 namespace presburger {
 
+/// Kinds of comparison, as used by `MultiAffineFunction::getLexSet`.
+enum class Compare { LT, GT, GE, LE, EQ, NE };
+
 /// This class represents a multi-affine function with the domain as Z^d, where
 /// `d` is the number of domain variables of the function. For example:
 ///
@@ -64,7 +67,10 @@
   /// Get the `i^th` output expression.
   ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
-  // Remove the specified range of outputs.
+  /// Get the divisions used in this function.
+  const DivisionRepr &getDivs() const { return divs; }
+
+  /// Remove the specified range of outputs.
   void removeOutputs(unsigned start, unsigned end);
 
   /// Given a MAF `other`, merges division variables such that both functions
@@ -88,6 +94,12 @@
 
   void subtract(const MultiAffineFunction &other);
 
+  /// Return the set of points where the outputs of `this` and `other` satisfy
+  /// the given comparison lexicographically, e.g. for `LT` the set of points
+  /// where the output of `this` is lexicographically less than that of `other`.
+  /// Currently only `LT` and `GT` are supported.
+  PresburgerSet getLexSet(Compare comp, const MultiAffineFunction &other) const;
+
   /// Get this function as a relation.
   IntegerRelation getAsRelation() const;
 
@@ -180,6 +192,9 @@
     return valueAt(getMPIntVec(point));
   }
 
+  /// Return all the pieces of this piece-wise function.
+  ArrayRef<Piece> getAllPieces() const { return pieces; }
+
   /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
   /// they have the same dimensions, the same domain and they take the same
   /// value at every point in the domain.
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -90,11 +90,14 @@
                            numLocals);
   }
 
-  // Get the domain/range space of this space. The returned space is a set
-  // space.
+  /// Get the domain/range space of this space. The returned space is a set
+  /// space.
   PresburgerSpace getDomainSpace() const;
   PresburgerSpace getRangeSpace() const;
 
+  /// Return a copy of this space with all local variables removed.
+  PresburgerSpace getSpaceWithoutLocals() const;
+
   unsigned getNumDomainVars() const { return numDomain; }
   unsigned getNumRangeVars() const { return numRange; }
   unsigned getNumSetDimVars() const { return numRange; }
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -171,6 +171,93 @@
   other.assertIsConsistent();
 }
 
+PresburgerSet
+MultiAffineFunction::getLexSet(Compare comp,
+                               const MultiAffineFunction &other) const {
+  assert(getSpace().isCompatible(other.getSpace()) &&
+         "Output space of funcs should be compatible");
+
+  // Create copies of functions and merge their local space.
+  MultiAffineFunction funcA = *this;
+  MultiAffineFunction funcB = other;
+  funcA.mergeDivs(funcB);
+
+  // We build `result`, the set of points where the output of funcA compares
+  // lexicographically to the output of funcB as given by `comp`. For `GT`,
+  // this is the union of the following sets:
+  //
+  //    (outA[0] > outB[0]) U
+  //    (outA[0] = outB[0], outA[1] > outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
+  //
+  // where `n` is the number of outputs.
+  // For `LT`, the strict inequality in each set is reversed:
+  //
+  //    (outA[0] < outB[0]) U
+  //    (outA[0] = outB[0], outA[1] < outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
+  PresburgerSpace resultSpace = funcA.getDomainSpace();
+  PresburgerSet result =
+      PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals());
+  IntegerPolyhedron levelSet(
+      /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
+      /*numReservedEqualities=*/funcA.getNumOutputs(),
+      /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
+
+  // Add division inequalities to `levelSet`.
+  for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
+    levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
+                                            funcA.divs.getDenom(i),
+                                            funcA.divs.getDivOffset() + i));
+    levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
+                                            funcA.divs.getDenom(i),
+                                            funcA.divs.getDivOffset() + i));
+  }
+
+  for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
+    // Create the expression `outA - outB` for this level.
+    SmallVector<MPInt, 8> subExpr =
+        subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
+
+    // TODO: Implement all comparison cases.
+    switch (comp) {
+    case Compare::LT:
+      // For less than, we add an upper bound of -1:
+      //        outA - outB <= -1
+      //        outA <= outB - 1
+      //        outA < outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
+      break;
+    case Compare::GT:
+      // For greater than, we add a lower bound of 1:
+      //        outA - outB >= 1
+      //        outA >= outB + 1
+      //        outA > outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
+      break;
+    case Compare::GE:
+    case Compare::LE:
+    case Compare::EQ:
+    case Compare::NE:
+      llvm_unreachable("Not implemented case");
+    }
+
+    // Union the set with the result.
+    result.unionInPlace(levelSet);
+    // The last inequality in `levelSet` is the bound we inserted. We remove
+    // that for next iteration.
+    levelSet.removeInequality(levelSet.getNumInequalities() - 1);
+    // Add equality `outA - outB == 0` for this level for next iteration.
+    levelSet.addEquality(subExpr);
+  }
+
+  return result;
+}
+
 /// Two PWMAFunctions are equal if they have the same dimensionalities,
 /// the same domain, and take the same value at every point in the domain.
 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
@@ -194,6 +281,8 @@
 
 void PWMAFunction::addPiece(const Piece &piece) {
   assert(piece.isConsistent() && "Piece should be consistent");
+  assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
+         "Piece should be disjoint from the function");
   pieces.push_back(piece);
 }
 
@@ -262,85 +351,24 @@
 }
 
 /// A tiebreak function which breaks ties by comparing the outputs
-/// lexicographically. If `lexMin` is true, then the ties are broken by
-/// taking the lexicographically smaller output and otherwise, by taking the
-/// lexicographically larger output.
-template <bool lexMin>
+/// lexicographically using `comp`: `Compare::LT` prefers the lexicographically
+/// smaller output and `Compare::GT` prefers the lexicographically larger one.
+template <Compare comp>
 static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
                                  const PWMAFunction::Piece &pieceB) {
-  // TODO: Support local variables here.
-  assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) &&
-         "Pieces should be compatible");
-  assert(pieceA.domain.getSpace().getNumLocalVars() == 0 &&
-         "Local variables are not supported yet.");
-
-  PresburgerSpace compatibleSpace = pieceA.domain.getSpace();
-  const PresburgerSpace &space = pieceA.domain.getSpace();
-
-  // We first create the set `result`, corresponding to the set where output
-  // of pieceA is lexicographically larger/smaller than pieceB. This is done by
-  // creating a PresburgerSet with the following constraints:
-  //
-  //    (outA[0] > outB[0]) U
-  //    (outA[0] = outB[0], outA[1] > outA[1]) U
-  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
-  //    ...
-  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
-  //
-  // where `n` is the number of outputs.
-  // If `lexMin` is set, the complement inequality is used:
-  //
-  //    (outA[0] < outB[0]) U
-  //    (outA[0] = outB[0], outA[1] < outA[1]) U
-  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
-  //    ...
-  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
-  PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
-  IntegerPolyhedron levelSet(
-      /*numReservedInequalities=*/1,
-      /*numReservedEqualities=*/pieceA.output.getNumOutputs(),
-      /*numReservedCols=*/space.getNumVars() + 1, space);
-  for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
-
-    // Create the expression `outA - outB` for this level.
-    SmallVector<MPInt, 8> subExpr = subtractExprs(
-        pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));
-
-    if (lexMin) {
-      // For lexMin, we add an upper bound of -1:
-      //        outA - outB <= -1
-      //        outA <= outB - 1
-      //        outA < outB
-      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
-    } else {
-      // For lexMax, we add a lower bound of 1:
-      //        outA - outB >= 1
-      //        outA > outB + 1
-      //        outA > outB
-      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
-    }
-
-    // Union the set with the result.
-    result.unionInPlace(levelSet);
-    // There is only 1 inequality in `levelSet`, so the index is always 0.
-    levelSet.removeInequality(0);
-    // Add equality `outA - outB == 0` for this level for next iteration.
-    levelSet.addEquality(subExpr);
-  }
-
-  // We then intersect `result` with the domain of pieceA and pieceB, to only
-  // tiebreak on the domain where both are defined.
+  PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
+  // Only tiebreak on the domain where both pieces are defined.
   result = result.intersect(pieceA.domain).intersect(pieceB.domain);
 
   return result;
 }
 
 PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
-  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
+  return unionFunction(func, tiebreakLex<Compare::LT>);
 }
 
 PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
-  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
+  return unionFunction(func, tiebreakLex<Compare::GT>);
 }
 
 void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -22,6 +22,12 @@
   return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
 }
 
+PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
+  PresburgerSpace space = *this;
+  space.removeVarRange(VarKind::Local, 0, numLocals);
+  return space;
+}
+
 unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
   if (kind == VarKind::Domain)
     return getNumDomainVars();
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -395,3 +395,66 @@
   EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
   EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
 }
+
+TEST(PWMAFunction, unionLexMinWithDivs) {
+  {
+    PWMAFunction func1 = parsePWMAF({
+        {"(x, y) : (x mod 5 == 0)", "(x, y) -> (x, 1)"},
+    });
+
+    PWMAFunction func2 = parsePWMAF({
+        {"(x, y) : (x mod 7 == 0)", "(x, y) -> (x + y, 2)"},
+    });
+
+    PWMAFunction result = parsePWMAF({
+        {"(x, y) : (x mod 5 == 0, x mod 7 >= 1)", "(x, y) -> (x, 1)"},
+        {"(x, y) : (x mod 7 == 0, x mod 5 >= 1)", "(x, y) -> (x + y, 2)"},
+        {"(x, y) : (x mod 5 == 0, x mod 7 == 0, y >= 0)", "(x, y) -> (x, 1)"},
+        {"(x, y) : (x mod 7 == 0, x mod 5 == 0, y <= -1)",
+         "(x, y) -> (x + y, 2)"},
+    });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+  }
+
+  {
+    PWMAFunction func1 = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000)", "(x) -> (x floordiv 16)"},
+    });
+
+    PWMAFunction func2 = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000)", "(x) -> ((x + 10) floordiv 17)"},
+    });
+
+    PWMAFunction result = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000, x floordiv 16 <= (x + 10) floordiv 17)",
+         "(x) -> (x floordiv 16)"},
+        {"(x) : (x >= 0, x <= 1000, x floordiv 16 >= (x + 10) floordiv 17 + 1)",
+         "(x) -> ((x + 10) floordiv 17)"},
+    });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+  }
+}
+
+TEST(PWMAFunction, unionLexMaxWithDivs) {
+  PWMAFunction func1 = parsePWMAF({
+      {"(x, y) : (x mod 5 == 0)", "(x, y) -> (x, 1)"},
+  });
+
+  PWMAFunction func2 = parsePWMAF({
+      {"(x, y) : (x mod 7 == 0)", "(x, y) -> (x + y, 2)"},
+  });
+
+  PWMAFunction result = parsePWMAF({
+      {"(x, y) : (x mod 5 == 0, x mod 7 >= 1)", "(x, y) -> (x, 1)"},
+      {"(x, y) : (x mod 7 == 0, x mod 5 >= 1)", "(x, y) -> (x + y, 2)"},
+      {"(x, y) : (x mod 5 == 0, x mod 7 == 0, y <= -1)", "(x, y) -> (x, 1)"},
+      {"(x, y) : (x mod 7 == 0, x mod 5 == 0, y >= 0)", "(x, y) -> (x + y, 2)"},
+  });
+
+  EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+  EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+}