diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -118,7 +118,7 @@ /// intersection with no simplification of any sort attempted. void append(const IntegerRelation &other); - /// Return the intersection of the two sets. + /// Return the intersection of the two relations. /// If there are locals, they will be merged. IntegerRelation intersect(IntegerRelation other) const; @@ -608,6 +608,10 @@ /// `PresburgerSet`, `unboundedDomain`. SymbolicLexMin findSymbolicIntegerLexMin() const; + /// Return the set difference of this set and the given set, i.e., + /// return `this \ set`. + PresburgerRelation subtract(const PresburgerRelation &set) const; + void print(raw_ostream &os) const; void dump() const; @@ -790,6 +794,12 @@ /// column position (i.e., not relative to the kind of variable) of the /// first added variable. unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override; + + /// Return the intersection of the two relations. + /// If there are locals, they will be merged. 
+ IntegerPolyhedron intersect(const IntegerPolyhedron &other) const; + + PresburgerSet subtract(const PresburgerSet &other) const; }; } // namespace presburger diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -55,6 +55,8 @@ return output.getNumColumns() == domainSet.getNumVars() + 1; } const IntegerPolyhedron &getDomain() const { return domainSet; } + const Matrix &getOutputMatrix() const { return output; } + ArrayRef getOutputExpr(unsigned i) const { return output.getRow(i); } const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); } /// Insert `num` variables of the specified kind at position `pos`. @@ -138,6 +140,7 @@ void addPiece(const MultiAffineFunction &piece); void addPiece(const IntegerPolyhedron &domain, const Matrix &output); + void addPiece(const PresburgerSet &domain, const Matrix &output); const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; } unsigned getNumPieces() const { return pieces.size(); } @@ -163,10 +166,41 @@ /// TODO: refactor so that this can be accomplished through removeVarRange. void truncateOutput(unsigned count); + // Return a function defined on the union of the domains of this and func, + // such that when only one of the functions is defined, it outputs the same as + // that function, and if both are defined, it outputs the lexmax/lexmin of the + // two outputs. On points where neither function is defined, the returned + // function is not defined either. + // + // Currently this does not support PWMAFunctions which have pieces containing + // local variables. + // TODO: Support local variables in pieces. 
+ PWMAFunction unionLexMin(const PWMAFunction &func); + PWMAFunction unionLexMax(const PWMAFunction &func); + void print(raw_ostream &os) const; void dump() const; private: + // Return a function defined on the union of the domains of `this` and `func`, + // such that when only one of the functions is defined, it outputs the same as + // that function, and if neither is defined, the returned function is not + // defined either. + // + // The provided `tiebreak` function determines which of the two functions' + // output should be used on inputs where both the functions are defined. More + // precisely, given two `MultiAffineFunction`s `mafA` and `mafB`, `tiebreak` + // returns the subset of the intersection of the two functions' domains where + // the output of `mafA` should be used. + // + // The PresburgerSet returned by `tiebreak` should be disjoint. + // TODO: Remove this constraint of returning disjoint set. + PWMAFunction + unionFunction(const PWMAFunction &func, + llvm::function_ref + tiebreak) const; + PresburgerSpace space; /// The list of pieces in this piece-wise MultiAffineFunction. 
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -252,6 +252,11 @@ return result; } +PresburgerRelation +IntegerRelation::subtract(const PresburgerRelation &set) const { + return PresburgerRelation(*this).subtract(set); +} + unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) { assert(pos <= getNumVarKind(kind)); @@ -2284,3 +2289,11 @@ "Domain has to be zero in a set"); return IntegerRelation::insertVar(kind, pos, num); } +IntegerPolyhedron +IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const { + return IntegerPolyhedron(IntegerRelation::intersect(other)); +} + +PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const { + return PresburgerSet(IntegerRelation::subtract(other)); +} diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -211,6 +211,11 @@ addPiece(MultiAffineFunction(domain, output)); } +void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) { + for (const IntegerRelation &newDom : domain.getAllDisjuncts()) + addPiece(IntegerPolyhedron(newDom), output); +} + void PWMAFunction::print(raw_ostream &os) const { os << pieces.size() << " pieces:\n"; for (const MultiAffineFunction &piece : pieces) @@ -218,3 +223,138 @@ } void PWMAFunction::dump() const { print(llvm::errs()); } + +PWMAFunction PWMAFunction::unionFunction( + const PWMAFunction &func, + llvm::function_ref + tiebreak) const { + assert(getNumOutputs() == func.getNumOutputs() && + "Number of outputs of functions should be same."); + assert(getSpace().isCompatible(func.getSpace()) && + "Space is not compatible."); + + // The algorithm used here is as follows: + // - Add the output of 
funcB for the part of the domain where both funcA and + // funcB are defined, and `tiebreak` chooses the output of funcB. + // - Add the output of funcA, where funcB is not defined or `tiebreak` chooses + // funcA over funcB. + // - Add the output of funcB, where funcA is not defined. + + // Add parts of the common domain where funcB's output is used. Also + // add all the parts where funcA's output is used, both common and non-common. + PWMAFunction result(getSpace(), getNumOutputs()); + for (const MultiAffineFunction &funcA : pieces) { + PresburgerSet dom(funcA.getDomain()); + for (const MultiAffineFunction &funcB : func.pieces) { + PresburgerSet better = tiebreak(funcB, funcA); + // Add the output of funcB, where it is better than output of funcA. + // The disjuncts in "better" will be disjoint as tiebreak should guarantee + // that. + result.addPiece(better, funcB.getOutputMatrix()); + dom = dom.subtract(better); + } + // Add output of funcA, where it is better than funcB, or funcB is not + // defined. + // + // `dom` here is guaranteed to be disjoint from already added pieces + // because the pieces added before are either: + // - Subsets of the domain of other MAFs in `this`, which are guaranteed + // to be disjoint from `dom`, or + // - They are one of the pieces added for `funcB`, and we have been + // subtracting all such pieces from `dom`, so `dom` is disjoint from those + // pieces as well. + result.addPiece(dom, funcA.getOutputMatrix()); + } + + // Add parts of funcB which are not shared with funcA. + PresburgerSet dom = getDomain(); + for (const MultiAffineFunction &funcB : func.pieces) + result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix()); + + return result; +} + +/// A tiebreak function which breaks ties by comparing the outputs +/// lexicographically. If `lexMin` is true, then the ties are broken by +/// taking the lexicographically smaller output and otherwise, by taking the +/// lexicographically larger output. 
+template +static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA, + const MultiAffineFunction &mafB) { + // TODO: Support local variables here. + assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) && + "Domain spaces should be compatible."); + assert(mafA.getNumOutputs() == mafB.getNumOutputs() && + "Number of outputs of both functions should be same."); + assert(mafA.getDomain().getNumLocalVars() == 0 && + "Local variables are not supported yet."); + + PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals(); + const PresburgerSpace &space = mafA.getDomain().getSpace(); + + // We first create the set `result`, corresponding to the set where output + // of mafA is lexicographically larger/smaller than mafB. This is done by + // creating a PresburgerSet with the following constraints: + // + // (outA[0] > outB[0]) U + // (outA[0] = outB[0], outA[1] > outB[1]) U + // (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U + // ... + // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1]) + // + // where `n` is the number of outputs. + // If `lexMin` is set, the opposite strict inequality is used: + // + // (outA[0] < outB[0]) U + // (outA[0] = outB[0], outA[1] < outB[1]) U + // (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U + // ... + // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) + PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace); + IntegerPolyhedron levelSet(/*numReservedInequalities=*/1, + /*numReservedEqualities=*/mafA.getNumOutputs(), + /*numReservedCols=*/space.getNumVars() + 1, space); + for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) { + + // Create the expression `outA - outB` for this level. 
+ SmallVector subExpr = + subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level)); + + if (lexMin) { + // For lexMin, we add an upper bound of -1: + // outA - outB <= -1 + // outA <= outB - 1 + // outA < outB + levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1); + } else { + // For lexMax, we add a lower bound of 1: + // outA - outB >= 1 + // outA >= outB + 1 + // outA > outB + levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1); + } + + // Union the set with the result. + result.unionInPlace(levelSet); + // There is only 1 inequality in `levelSet`, so the index is always 0. + levelSet.removeInequality(0); + // Add equality `outA - outB == 0` for this level for next iteration. + levelSet.addEquality(subExpr); + } + + // We then intersect `result` with the domain of mafA and mafB, to only + // tiebreak on the domain where both are defined. + result = result.intersect(PresburgerSet(mafA.getDomain())) + .intersect(PresburgerSet(mafB.getDomain())); + + return result; +} + +PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) { + return unionFunction(func, tiebreakLex); +} + +PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { + return unionFunction(func, tiebreakLex); +} diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp --- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp @@ -189,3 +189,328 @@ }); EXPECT_TRUE(pwmafA.isEqual(pwmafB)); } + +TEST(PWMAFunction, unionLexMaxSimple) { + // func2 is better than func1, but func2's domain is empty. 
+ { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{0, 1}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (1 == 0)", {{0, 2}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{0, 1}}}, + }); + + EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3)); + } + + // func2 is better than func1 on a subset of func1. + { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{0, 1}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (-1 - x >= 0)", {{0, 1}}}, + {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}}, + {"(x) : (x - 11 >= 0)", {{0, 1}}}, + }); + + EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3)); + } + + // func1 and func2 are defined over the whole domain with different outputs. + { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{1, 0}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{-1, 0}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x >= 0)", {{1, 0}}}, + {"(x) : (-1 - x >= 0)", {{-1, 0}}}, + }); + + EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3)); + } + + // func1 and func2 have disjoint domains. 
+ { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}}, + {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}}, + {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}}, + {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}}, + {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}}, + {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}}, + }); + + EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3)); + } +} + +TEST(PWMAFunction, unionLexMinSimple) { + // func2 is better than func1, but func2's domain is empty. + { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{0, -1}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (1 == 0)", {{0, -2}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{0, -1}}}, + }); + + EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3)); + } + + // func2 is better than func1 on a subset of func1. 
+ { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{0, -1}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (-1 - x >= 0)", {{0, -1}}}, + {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}}, + {"(x) : (x - 11 >= 0)", {{0, -1}}}, + }); + + EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3)); + } + + // func1 and func2 are defined over the whole domain with different outputs. + { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{-1, 0}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : ()", {{1, 0}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/1, /*numOutputs=*/1, + { + {"(x) : (x >= 0)", {{-1, 0}}}, + {"(x) : (-1 - x >= 0)", {{1, 0}}}, + }); + + EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3)); + } +} + +TEST(PWMAFunction, unionLexMaxComplex) { + { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/1, + { + {"(x, y) : (x - 10 >= 0)", {{1, 1, 0}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/1, + { + {"(x, y) : (20 - x >= 0)", {{1, -1, 0}}}, + }); + + PWMAFunction func3 = + parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1, + {{"(x, y) : (x - 10 >= 0, 20 - x >= 0, y - 1 >= 0)", + { + {1, 1, 0}, + }}, + {"(x, y) : (x - 21 >= 0)", + { + {1, 1, 0}, + }}, + {"(x, y) : (9 - x >= 0)", + { + {1, -1, 0}, + }}, + {"(x, y) : (x - 10 >= 0, 20 - x >= 0, -y >= 0)", + { + {1, -1, 0}, + }}}); + + EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3)); + } + + { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}}, + }); + + 
PWMAFunction func2 = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}}, + }); + + PWMAFunction func3 = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2, + {{"(x, y) : (x >= 0, y - 1 >= 0)", + { + {1, 1, 0}, + {-2, 0, 4}, + }}, + {"(x, y) : (x >= 0, 1 - x >= 0, y == 0)", + { + {1, 1, 0}, + {-2, 0, 4}, + }}, + {"(x, y) : (x - 2 >= 0, y == 0)", + { + {1, 0, 0}, + {2, 0, -2}, + }}}); + + EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3)); + } + + { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/6, /*numOutputs=*/3, + { + {"(x, y, z, a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c " + ">= 0, 1 - c >= 0)", + {{1, 0, 0, 0, 0, 0, 0}, + {0, 1, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 0}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/6, /*numOutputs=*/3, + { + {"(x, y, z, a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c " + ">= 0, 1 - c >= 0)", + {{1, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 1, 0}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/6, /*numOutputs=*/3, + { + {"(x, y, z, a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c " + ">= 0, 1 - c >= 0)", + {{1, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 1, 0}}}, + + {"(x, y, z, a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)", + {{1, 0, 0, 0, 0, 0, 0}, + {0, 1, 0, 0, 1, 0, 0}, + {0, 0, 1, 0, 0, 0, 0}}}, + + {"(x, y, z, a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)", + {{1, 0, 0, 1, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 1, 0}}}, + }); + + EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3)); + } +} + +TEST(PWMAFunction, unionLexMinComplex) { + PWMAFunction func1 = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (x >= 0, 1 >= x, y >= 0, 1 >= y)", + {{-1, 0, 0}, {0, 1, 0}}}, + }); + + PWMAFunction func2 = parsePWMAF( + /*numInputs=*/2, 
/*numOutputs=*/2, + { + {"(x, y) : (x >= 0, 1 >= x, y >= 0, 1 >= y)", + {{0, 0, 0}, {0, 1, -1}}}, + }); + + PWMAFunction func3 = parsePWMAF( + /*numInputs=*/2, /*numOutputs=*/2, + { + {"(x, y) : (x == 1, y >= 0, 1 >= y)", {{-1, 0, 0}, {0, 1, 0}}}, + {"(x, y) : (x == 0, y >= 0, 1 >= y)", {{0, 0, 0}, {0, 1, -1}}}, + }); + + EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3)); + EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3)); +}