diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -118,7 +118,7 @@
   /// intersection with no simplification of any sort attempted.
   void append(const IntegerRelation &other);
 
-  /// Return the intersection of the two sets.
+  /// Return the intersection of the two relations.
   /// If there are locals, they will be merged.
   IntegerRelation intersect(IntegerRelation other) const;
 
@@ -608,6 +608,10 @@
   /// `PresburgerSet`, `unboundedDomain`.
   SymbolicLexMin findSymbolicIntegerLexMin() const;
 
+  /// Return the set difference of this relation and the given relation, i.e.,
+  /// return `this \ set`.
+  PresburgerRelation subtract(const PresburgerRelation &set) const;
+
   void print(raw_ostream &os) const;
   void dump() const;
 
@@ -790,6 +794,14 @@
   /// column position (i.e., not relative to the kind of variable) of the
   /// first added variable.
   unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
+
+  /// Return the intersection of the two sets.
+  /// If there are locals, they will be merged.
+  IntegerPolyhedron intersect(const IntegerPolyhedron &other) const;
+
+  /// Return the set difference of this set and the given set, i.e.,
+  /// return `this \ other`.
+  PresburgerSet subtract(const PresburgerSet &other) const;
 };
 
 } // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -54,8 +54,15 @@
   bool isConsistent() const {
     return output.getNumColumns() == domainSet.getNumVars() + 1;
   }
-  const IntegerPolyhedron &getDomain() const { return domainSet; }
+
+  /// Get the space of the input domain of this function.
   const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
+  /// Get the input domain of this function.
+  const IntegerPolyhedron &getDomain() const { return domainSet; }
+  /// Get a matrix whose i^th row represents the i^th output expression.
+  const Matrix &getOutputMatrix() const { return output; }
+  /// Get the `i^th` output expression.
+  ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
   /// Insert `num` variables of the specified kind at position `pos`.
   /// Positions are relative to the kind of variable. The coefficient columns
@@ -138,6 +145,7 @@
 
   void addPiece(const MultiAffineFunction &piece);
   void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
+  void addPiece(const PresburgerSet &domain, const Matrix &output);
 
   const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
   unsigned getNumPieces() const { return pieces.size(); }
@@ -163,10 +171,41 @@
   /// TODO: refactor so that this can be accomplished through removeVarRange.
   void truncateOutput(unsigned count);
 
+  /// Return a function defined on the union of the domains of this and func,
+  /// such that when only one of the functions is defined, it outputs the same
+  /// as that function, and if both are defined, it outputs the lexmin/lexmax of
+  /// the two outputs. On points where neither function is defined, the returned
+  /// function is not defined either.
+  ///
+  /// Currently this does not support PWMAFunctions which have pieces containing
+  /// local variables.
+  /// TODO: Support local variables in pieces.
+  PWMAFunction unionLexMin(const PWMAFunction &func);
+  PWMAFunction unionLexMax(const PWMAFunction &func);
+
   void print(raw_ostream &os) const;
   void dump() const;
 
 private:
+  /// Return a function defined on the union of the domains of `this` and
+  /// `func`, such that when only one of the functions is defined, it outputs
+  /// the same as that function, and if neither is defined, the returned
+  /// function is not defined either.
+  ///
+  /// The provided `tiebreak` function determines which of the two functions'
+  /// output should be used on inputs where both the functions are defined. More
+  /// precisely, given two `MultiAffineFunction`s `mafA` and `mafB`, `tiebreak`
+  /// returns the subset of the intersection of the two functions' domains where
+  /// the output of `mafA` should be used.
+  ///
+  /// The PresburgerSet returned by `tiebreak` should be disjoint.
+  /// TODO: Remove this constraint of returning disjoint set.
+  PWMAFunction
+  unionFunction(const PWMAFunction &func,
+                llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
+                                                 MultiAffineFunction mafB)>
+                    tiebreak) const;
+
   PresburgerSpace space;
 
   /// The list of pieces in this piece-wise MultiAffineFunction.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -252,6 +252,11 @@
   return result;
 }
 
+PresburgerRelation
+IntegerRelation::subtract(const PresburgerRelation &set) const {
+  return PresburgerRelation(*this).subtract(set);
+}
+
 unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumVarKind(kind));
 
@@ -2284,3 +2289,12 @@
          "Domain has to be zero in a set");
   return IntegerRelation::insertVar(kind, pos, num);
 }
+
+IntegerPolyhedron
+IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
+  return IntegerPolyhedron(IntegerRelation::intersect(other));
+}
+
+PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const {
+  return PresburgerSet(IntegerRelation::subtract(other));
+}
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -211,6 +211,11 @@
   addPiece(MultiAffineFunction(domain, output));
 }
 
+void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
+  for (const IntegerRelation &newDom : domain.getAllDisjuncts())
+    addPiece(IntegerPolyhedron(newDom), output);
+}
+
 void PWMAFunction::print(raw_ostream &os) const {
   os << pieces.size() << " pieces:\n";
   for (const MultiAffineFunction &piece : pieces)
@@ -218,3 +223,138 @@
 }
 
 void PWMAFunction::dump() const { print(llvm::errs()); }
+
+PWMAFunction PWMAFunction::unionFunction(
+    const PWMAFunction &func,
+    llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
+                                     MultiAffineFunction mafB)>
+        tiebreak) const {
+  assert(getNumOutputs() == func.getNumOutputs() &&
+         "Number of outputs of functions should be same.");
+  assert(getSpace().isCompatible(func.getSpace()) &&
+         "Space is not compatible.");
+
+  // The algorithm used here is as follows:
+  // - Add the output of funcB for the part of the domain where both funcA and
+  //   funcB are defined, and `tiebreak` chooses the output of funcB.
+  // - Add the output of funcA, where funcB is not defined or `tiebreak` chooses
+  //   funcA over funcB.
+  // - Add the output of funcB, where funcA is not defined.
+
+  // Add parts of the common domain where funcB's output is used. Also
+  // add all the parts where funcA's output is used, both common and non-common.
+  PWMAFunction result(getSpace(), getNumOutputs());
+  for (const MultiAffineFunction &funcA : pieces) {
+    PresburgerSet dom(funcA.getDomain());
+    for (const MultiAffineFunction &funcB : func.pieces) {
+      PresburgerSet better = tiebreak(funcB, funcA);
+      // Add the output of funcB, where it is better than output of funcA.
+      // The disjuncts in "better" will be disjoint as tiebreak should guarantee
+      // that.
+      result.addPiece(better, funcB.getOutputMatrix());
+      dom = dom.subtract(better);
+    }
+    // Add output of funcA, where it is better than funcB, or funcB is not
+    // defined.
+    //
+    // `dom` here is guaranteed to be disjoint from already added pieces
+    // because the pieces added before are either:
+    // - Subsets of the domain of other MAFs in `this`, which are guaranteed
+    //   to be disjoint from `dom`, or
+    // - They are one of the pieces added for `funcB`, and we have been
+    //   subtracting all such pieces from `dom`, so `dom` is disjoint from those
+    //   pieces as well.
+    result.addPiece(dom, funcA.getOutputMatrix());
+  }
+
+  // Add parts of funcB which are not shared with funcA.
+  PresburgerSet dom = getDomain();
+  for (const MultiAffineFunction &funcB : func.pieces)
+    result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
+
+  return result;
+}
+
+/// A tiebreak function which breaks ties by comparing the outputs
+/// lexicographically. If `lexMin` is true, then the ties are broken by
+/// taking the lexicographically smaller output and otherwise, by taking the
+/// lexicographically larger output.
+template <bool lexMin>
+static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
+                                 const MultiAffineFunction &mafB) {
+  // TODO: Support local variables here.
+  assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
+         "Domain spaces should be compatible.");
+  assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
+         "Number of outputs of both functions should be same.");
+  assert(mafA.getDomain().getNumLocalVars() == 0 &&
+         "Local variables are not supported yet.");
+
+  PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
+  const PresburgerSpace &space = mafA.getDomain().getSpace();
+
+  // We first create the set `result`, corresponding to the set where output
+  // of mafA is lexicographically larger/smaller than mafB. This is done by
+  // creating a PresburgerSet with the following constraints:
+  //
+  //    (outA[0] > outB[0]) U
+  //    (outA[0] = outB[0], outA[1] > outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
+  //
+  // where `n` is the number of outputs.
+  // If `lexMin` is set, the direction of the strict inequality is reversed:
+  //
+  //    (outA[0] < outB[0]) U
+  //    (outA[0] = outB[0], outA[1] < outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
+  PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
+  IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
+                             /*numReservedEqualities=*/mafA.getNumOutputs(),
+                             /*numReservedCols=*/space.getNumVars() + 1, space);
+  for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
+
+    // Create the expression `outA - outB` for this level.
+    SmallVector<int64_t, 8> subExpr =
+        subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
+
+    if (lexMin) {
+      // For lexMin, we add an upper bound of -1:
+      //        outA - outB <= -1
+      //        outA <= outB - 1
+      //        outA < outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1);
+    } else {
+      // For lexMax, we add a lower bound of 1:
+      //        outA - outB >= 1
+      //        outA >= outB + 1
+      //        outA > outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1);
+    }
+
+    // Union the set with the result.
+    result.unionInPlace(levelSet);
+    // There is only 1 inequality in `levelSet`, so the index is always 0.
+    levelSet.removeInequality(0);
+    // Add equality `outA - outB == 0` for this level for next iteration.
+    levelSet.addEquality(subExpr);
+  }
+
+  // We then intersect `result` with the domain of mafA and mafB, to only
+  // tiebreak on the domain where both are defined.
+  result = result.intersect(PresburgerSet(mafA.getDomain()))
+               .intersect(PresburgerSet(mafB.getDomain()));
+
+  return result;
+}
+
+PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
+  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
+}
+
+PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
+  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
+}
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -189,3 +189,331 @@
       });
   EXPECT_TRUE(pwmafA.isEqual(pwmafB));
 }
+
+TEST(PWMAFunction, unionLexMaxSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, 2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func1));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func1));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, 1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 11 >= 0)", {{0, 1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{-1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+
+  // func1 and func2 have disjoint domains.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+}
+
+TEST(PWMAFunction, unionLexMinSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, -2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func1));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func1));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, -1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+            {"(x) : (x - 11 >= 0)", {{0, -1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{-1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+  }
+}
+
+TEST(PWMAFunction, unionLexMaxComplex) {
+  // Union of functions, resulting in 4 different pieces of output.
+  //
+  // x >= 21 --> func1 (func2 not defined)
+  // x <= 9 --> func2 (func1 not defined)
+  // 10 <= x <= 20, y > 0 --> func1 (x + y  > x - y for y >  0)
+  // 10 <= x <= 20, y <= 0 --> func2 (x + y <= x - y for y <= 0)
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (x >= 10)", {{1, 1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (x <= 20)", {{1, -1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
+                                     {{"(x, y) : (x >= 10, x <= 20, y >= 1)",
+                                       {
+                                           {1, 1, 0},
+                                       }},
+                                      {"(x, y) : (x >= 21)",
+                                       {
+                                           {1, 1, 0},
+                                       }},
+                                      {"(x, y) : (x <= 9)",
+                                       {
+                                           {1, -1, 0},
+                                       }},
+                                      {"(x, y) : (x >= 10, x <= 20, y <= 0)",
+                                       {
+                                           {1, -1, 0},
+                                       }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+  }
+
+  // Functions with more than one output, with contribution from both functions.
+  //
+  // If y >= 1, func1 is better because in the first output,
+  // x + y (func1) > x (func2), when y >= 1
+  //
+  // If y == 0, the first output is same for both functions, so we look at the
+  // second output. -2x + 4 (func1) > 2x - 2 (func2) when 0 <= x <= 1, so we
+  // take func1 for this domain and func2 for the remaining.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
+                                     {{"(x, y) : (x >= 0, y >= 1)",
+                                       {
+                                           {1, 1, 0},
+                                           {-2, 0, 4},
+                                       }},
+                                      {"(x, y) : (x >= 0, x <= 1, y == 0)",
+                                       {
+                                           {1, 1, 0},
+                                           {-2, 0, 4},
+                                       }},
+                                      {"(x, y) : (x >= 2, y == 0)",
+                                       {
+                                           {1, 0, 0},
+                                           {2, 0, -2},
+                                       }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+
+  // Function with three boolean variables `a, b, c` used to control which
+  // output will be taken lexicographically.
+  //
+  // a == 1                 --> Take func2
+  // a == 0, b == 1         --> Take func1
+  // a == 0, b == 0, c == 1 --> Take func2
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/3, /*numOutputs=*/3,
+        {
+            {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/3, /*numOutputs=*/3,
+        {
+            {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c >= 0, 1 - "
+             "c >= 0)",
+             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/3, /*numOutputs=*/3,
+        {
+            {"(a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
+            {"(a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
+             {{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
+            {"(a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+}
+
+TEST(PWMAFunction, unionLexMinComplex) {
+  // Regression test checking if lexicographic tiebreak produces disjoint
+  // domains.
+  //
+  // If x == 1, func1 is better since in the first output,
+  // -x (func1) is < 0 (func2) when x == 1.
+  //
+  // If x == 0, func1 and func2 both have the same first output. So we take a
+  // look at the second output. func2 is better since in the second output,
+  // y - 1 (func2) is < y (func1) when x == 0.
+  PWMAFunction func1 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
+           {{-1, 0, 0}, {0, 1, 0}}},
+      });
+
+  PWMAFunction func2 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
+           {{0, 0, 0}, {0, 1, -1}}},
+      });
+
+  PWMAFunction result = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x == 1, y >= 0, y <= 1)", {{-1, 0, 0}, {0, 1, 0}}},
+          {"(x, y) : (x == 0, y >= 0, y <= 1)", {{0, 0, 0}, {0, 1, -1}}},
+      });
+
+  EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+  EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+}