diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -117,7 +117,7 @@
   /// intersection with no simplification of any sort attempted.
   void append(const IntegerRelation &other);
 
-  /// Return the intersection of the two sets.
+  /// Return the intersection of the two relations.
   /// If there are locals, they will be merged.
   IntegerRelation intersect(IntegerRelation other) const;
 
@@ -582,6 +582,12 @@
   /// union of convex disjuncts.
   PresburgerRelation computeReprWithOnlyDivLocals() const;
 
+  /// Return the set difference of this set and the given set, i.e.,
+  /// return `this \ set`. All local variables in `set` must correspond
+  /// to floor divisions, but local variables in `this` need not correspond to
+  /// divisions.
+  PresburgerRelation subtract(const PresburgerRelation &set) const;
+
   void print(raw_ostream &os) const;
   void dump() const;
 
@@ -767,6 +773,14 @@
   /// first added variable.
   unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
 
+  /// Return the intersection of the two sets.
+  /// If there are locals, they will be merged.
+  IntegerPolyhedron intersect(const IntegerPolyhedron &other) const;
+
+  /// Return the set difference of this set and `other`, i.e., return
+  /// `this \ other`.
+  PresburgerSet subtract(const PresburgerSet &other) const;
+
   /// Compute the symbolic integer lexmin of the polyhedron.
   /// This finds, for every assignment to the symbols, the lexicographically
   /// minimum value attained by the dimensions. For example, the symbolic lexmin
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -55,6 +55,8 @@
     return output.getNumColumns() == domainSet.getNumVars() + 1;
   }
   const IntegerPolyhedron &getDomain() const { return domainSet; }
+  const Matrix &getOutputMatrix() const { return output; }
+  ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
   const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
 
   /// Insert `num` variables of the specified kind at position `pos`.
@@ -138,6 +140,7 @@
 
   void addPiece(const MultiAffineFunction &piece);
   void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
+  void addPiece(const PresburgerSet &domain, const Matrix &output);
 
   const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
   unsigned getNumPieces() const { return pieces.size(); }
@@ -163,10 +166,41 @@
   /// TODO: refactor so that this can be accomplished through removeVarRange.
   void truncateOutput(unsigned count);
 
+  /// Return a function defined on the union of the domains of this and func,
+  /// such that when only one of the functions is defined, it outputs the same
+  /// as that function, and if both are defined, it outputs the lexmax/lexmin of
+  /// the two outputs. On points where neither function is defined, the
+  /// returned function is not defined either.
+  ///
+  /// Currently this does not support PWMAFunctions which have pieces containing
+  /// local variables.
+  /// TODO: Support local variables in pieces.
+  PWMAFunction unionLexMin(const PWMAFunction &func);
+  PWMAFunction unionLexMax(const PWMAFunction &func);
+
   void print(raw_ostream &os) const;
   void dump() const;
 
 private:
+  /// Return a function defined on the union of the domains of `this` and
+  /// `func`, such that when only one of the functions is defined, it outputs
+  /// the same as that function, and if neither is defined, the returned
+  /// function is not defined either.
+  ///
+  /// The provided `tiebreak` function determines which of the two functions'
+  /// output should be used on inputs where both the functions are defined. More
+  /// precisely, given two `MultiAffineFunction`s `mafA` and `mafB`, `tiebreak`
+  /// returns the subset of the intersection of the two functions' domains where
+  /// the output of `mafA` should be used.
+  ///
+  /// The PresburgerSet returned by `tiebreak` should be disjoint.
+  /// TODO: Remove this constraint of returning disjoint set.
+  PWMAFunction
+  unionFunction(const PWMAFunction &func,
+                llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
+                                                 MultiAffineFunction mafB)>
+                    tiebreak) const;
+
   PresburgerSpace space;
 
   /// The list of pieces in this piece-wise MultiAffineFunction.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -241,6 +241,11 @@
   return result;
 }
 
+PresburgerRelation
+IntegerRelation::subtract(const PresburgerRelation &set) const {
+  return PresburgerRelation(*this).subtract(set);
+}
+
 unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumVarKind(kind));
 
@@ -2273,3 +2278,12 @@
          "Domain has to be zero in a set");
   return IntegerRelation::insertVar(kind, pos, num);
 }
+
+IntegerPolyhedron
+IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
+  return IntegerPolyhedron(IntegerRelation::intersect(other));
+}
+
+PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const {
+  return PresburgerSet(IntegerRelation::subtract(other));
+}
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -211,6 +211,11 @@
   addPiece(MultiAffineFunction(domain, output));
 }
 
+void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
+  for (const IntegerRelation &newDom : domain.getAllDisjuncts())
+    addPiece(IntegerPolyhedron(newDom), output);
+}
+
 void PWMAFunction::print(raw_ostream &os) const {
   os << pieces.size() << " pieces:\n";
   for (const MultiAffineFunction &piece : pieces)
@@ -218,3 +223,130 @@
 }
 
 void PWMAFunction::dump() const { print(llvm::errs()); }
+
+PWMAFunction PWMAFunction::unionFunction(
+    const PWMAFunction &func,
+    llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
+                                      MultiAffineFunction mafB)>
+        tiebreak) const {
+  assert(getNumOutputs() == func.getNumOutputs() &&
+         "Number of outputs of functions should be same.");
+  assert(getSpace().isCompatible(func.getSpace()) &&
+         "Space is not compatible.");
+
+  // The algorithm used here is as follows:
+  // - Add the output of funcB on the domain where both funcA and funcB are
+  //   defined, and the output of funcB is better than that of funcA.
+  // - Add the output of funcA, where funcB is not defined or funcA is
+  //   better than funcB.
+  // - Add the output of funcB, where funcA is not defined.
+
+  // Add parts of the common domain where funcB's output is used. Also
+  // add all the parts where funcA's output is used, both common and non-common.
+  PWMAFunction result(getSpace(), getNumOutputs());
+  for (const MultiAffineFunction &funcA : pieces) {
+    PresburgerSet dom(funcA.getDomain());
+    for (const MultiAffineFunction &funcB : func.pieces) {
+      PresburgerSet better = tiebreak(funcB, funcA);
+      // Add the output of funcB, where it is better than output of funcA.
+      // The set "better" will be disjoint as tiebreak should guarantee that.
+      result.addPiece(better, funcB.getOutputMatrix());
+      dom = dom.subtract(better);
+    }
+    // Add output of funcA, where it is better than funcB, or funcB is not
+    // defined.
+    //
+    // `dom` here will be disjoint, since we started from a single convex set
+    // i.e. disjoint. Subtracting from a disjoint set keeps it disjoint.
+    result.addPiece(dom, funcA.getOutputMatrix());
+  }
+
+  // Add parts of funcB which are not shared with funcA.
+  PresburgerSet dom = getDomain();
+  for (const MultiAffineFunction &funcB : func.pieces)
+    result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
+
+  return result;
+}
+
+/// A tiebreak function which breaks ties by comparing the outputs
+/// lexicographically. If `lexMin` is true, then the ties are broken by
+/// taking the lexicographically smaller output; otherwise, by taking the
+/// lexicographically larger output.
+template <bool lexMin>
+static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
+                                 const MultiAffineFunction &mafB) {
+  // TODO: Support local variables here.
+  assert(mafA.getDomain().getNumLocalVars() == 0 &&
+         "Local variables are not supported yet.");
+
+  PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
+  const PresburgerSpace &space = mafA.getDomain().getSpace();
+
+  // We first create the set `result`, corresponding to the set where output
+  // of mafA is lexicographically larger/smaller than mafB. This is done by
+  // creating a PresburgerSet with the following constraints:
+  //
+  //    (outA[0] > outB[0]) U
+  //    (outA[0] = outB[0], outA[1] > outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n - 1] = outB[n - 1], outA[n] > outB[n])
+  //
+  // where `n + 1` is the number of outputs.
+  // If `lexMin` is set, the strict reverse inequality is used instead:
+  //
+  //    (outA[0] < outB[0]) U
+  //    (outA[0] = outB[0], outA[1] < outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n - 1] = outB[n - 1], outA[n] < outB[n])
+  //
+  // Every inequality is strict, so the sets for different levels are disjoint.
+  PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
+  IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
+                             /*numReservedEqualities=*/mafA.getNumOutputs(),
+                             /*numReservedCols=*/space.getNumVars() + 1, space);
+  for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
+    // Create the expression `outA - outB` for this level.
+    SmallVector<int64_t, 8> subExpr =
+        subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
+
+    // Add the strict inequality for this level. For lexMax this is
+    // `outA - outB - 1 >= 0`, i.e. `outA > outB`. For lexMin this is the
+    // complement of `outA - outB >= 0`, i.e. `outB - outA - 1 >= 0`, i.e.
+    // `outA < outB`. (The complement of `outA - outB - 1 >= 0` would be the
+    // non-strict `outA <= outB`, which would wrongly admit points where
+    // `outA == outB` at this level, regardless of the later levels.)
+    if (lexMin) {
+      levelSet.addInequality(getComplementIneq(subExpr));
+    } else {
+      SmallVector<int64_t, 8> ineq = subExpr;
+      ineq.back() -= 1;
+      levelSet.addInequality(ineq);
+    }
+    // Union the set with the result.
+    result.unionInPlace(levelSet);
+    // There is only 1 inequality in `levelSet`, so the index is always 0.
+    levelSet.removeInequality(0);
+
+    // Add the equality `outA - outB == 0` for this level, so that the next
+    // level is only considered where all previous outputs are equal.
+    levelSet.addEquality(subExpr);
+  }
+
+  // We then intersect `result` with the domain of mafA and mafB, to only
+  // tiebreak on the domain where both are defined.
+  result = result.intersect(PresburgerSet(mafA.getDomain()))
+               .intersect(PresburgerSet(mafB.getDomain()));
+
+  return result;
+}
+
+PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
+  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
+}
+
+PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
+  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
+}
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -189,3 +189,329 @@
       });
   EXPECT_TRUE(pwmafA.isEqual(pwmafB));
 }
+
+TEST(PWMAFunction, unionLexMaxSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, 2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, 1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 11 >= 0)", {{0, 1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{-1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  // func1 and func2 have disjoint domains.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+}
+
+TEST(PWMAFunction, unionLexMinSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, -2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, -1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+            {"(x) : (x - 11 >= 0)", {{0, -1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{-1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+  }
+}
+
+TEST(PWMAFunction, unionLexMaxComplex) {
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (x - 10 >= 0)", {{1, 1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (20 - x >= 0)", {{1, -1, 0}}},
+        });
+
+    PWMAFunction func3 =
+        parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
+                   {{"(x, y) : (x - 10 >= 0, 20 - x >= 0, y - 1 >= 0)",
+                     {
+                         {1, 1, 0},
+                     }},
+                    {"(x, y) : (x - 21 >= 0)",
+                     {
+                         {1, 1, 0},
+                     }},
+                    {"(x, y) : (9 - x >= 0)",
+                     {
+                         {1, -1, 0},
+                     }},
+                    {"(x, y) : (x - 10 >= 0, 20 - x >= 0, -y >= 0)",
+                     {
+                         {1, -1, 0},
+                     }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+  }
+
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
+                                    {{"(x, y) : (x >= 0, y - 1 >= 0)",
+                                      {
+                                          {1, 1, 0},
+                                          {-2, 0, 4},
+                                      }},
+                                     {"(x, y) : (x >= 0, 1 - x >= 0, y == 0)",
+                                      {
+                                          {1, 1, 0},
+                                          {-2, 0, 4},
+                                      }},
+                                     {"(x, y) : (x - 2 >= 0, y == 0)",
+                                      {
+                                          {1, 0, 0},
+                                          {2, 0, -2},
+                                      }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/6, /*numOutputs=*/3,
+        {
+            {"(x, y, z, a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0, 0, 0, 0},
+              {0, 1, 0, 0, 1, 0, 0},
+              {0, 0, 1, 0, 0, 0, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/6, /*numOutputs=*/3,
+        {
+            {"(x, y, z, a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{1, 0, 0, 1, 0, 0, 0},
+              {0, 1, 0, 0, 0, 0, 0},
+              {0, 0, 1, 0, 0, 1, 0}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/6, /*numOutputs=*/3,
+        {
+            {"(x, y, z, a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{1, 0, 0, 1, 0, 0, 0},
+              {0, 1, 0, 0, 0, 0, 0},
+              {0, 0, 1, 0, 0, 1, 0}}},
+
+            {"(x, y, z, a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0, 0, 0, 0},
+              {0, 1, 0, 0, 1, 0, 0},
+              {0, 0, 1, 0, 0, 0, 0}}},
+
+            {"(x, y, z, a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 1, 0, 0, 0},
+              {0, 1, 0, 0, 0, 0, 0},
+              {0, 0, 1, 0, 0, 1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+}
+
+TEST(PWMAFunction, unionLexMinComplex) {
+  PWMAFunction func1 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
+      });
+
+  PWMAFunction func2 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
+      });
+
+  // The first outputs tie exactly when y == 0; there the second output must
+  // decide, which requires a strict inequality at every level.
+  PWMAFunction func3 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, y - 1 >= 0)", {{1, 0, 0}, {2, 0, -2}}},
+          {"(x, y) : (x >= 0, 1 - x >= 0, y == 0)", {{1, 0, 0}, {2, 0, -2}}},
+          {"(x, y) : (x - 2 >= 0, y == 0)", {{1, 1, 0}, {-2, 0, 4}}},
+      });
+
+  EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+  EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+}