diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -117,7 +117,7 @@
   /// intersection with no simplification of any sort attempted.
   void append(const IntegerRelation &other);
 
-  /// Return the intersection of the two sets.
+  /// Return the intersection of the two relations.
   /// If there are locals, they will be merged.
   IntegerRelation intersect(IntegerRelation other) const;
 
@@ -767,6 +767,10 @@
   /// first added variable.
   unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
 
+  /// Return the intersection of the two sets.
+  /// If there are locals, they will be merged.
+  IntegerPolyhedron intersect(const IntegerPolyhedron &other) const;
+
   /// Compute the symbolic integer lexmin of the polyhedron.
   /// This finds, for every assignment to the symbols, the lexicographically
   /// minimum value attained by the dimensions. For example, the symbolic lexmin
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -55,6 +55,8 @@
     return output.getNumColumns() == domainSet.getNumVars() + 1;
   }
   const IntegerPolyhedron &getDomain() const { return domainSet; }
+  const Matrix &getOutputs() const { return output; }
+  ArrayRef<int64_t> getOutput(unsigned i) const { return output.getRow(i); }
   const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
 
   /// Insert `num` variables of the specified kind at position `pos`.
@@ -138,6 +140,7 @@
 
   void addPiece(const MultiAffineFunction &piece);
   void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
+  void addPiece(const PresburgerSet &domain, const Matrix &output);
 
   const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
   unsigned getNumPieces() const { return pieces.size(); }
@@ -163,10 +166,40 @@
   /// TODO: refactor so that this can be accomplished through removeVarRange.
   void truncateOutput(unsigned count);
 
+  /// Return the pointwise lexicographic minimum/maximum of `this` and `func`.
+  /// Where both functions are defined, the lexicographically smaller/larger
+  /// output is taken; where only one of them is defined, the output of the
+  /// defined function is taken.
+  ///
+  /// Currently does not support PWMAFunctions which have pieces containing
+  /// local variables.
+  /// TODO: Support this.
+  PWMAFunction unionLexMin(const PWMAFunction &func);
+  PWMAFunction unionLexMax(const PWMAFunction &func);
+
   void print(raw_ostream &os) const;
   void dump() const;
 
 private:
+  /// Given a piecewise multi affine function `func`, take union of `func` with
+  /// this function. The union function includes:
+  ///   - `this` on the domain where `func` is not defined.
+  ///   - `func` on the domain where `this` is not defined.
+  ///   - `func` on the domain where both `func` and `this` are defined and the
+  ///     output of `func` is "better" than the output of `this`.
+  ///   - `this` on the rest of the domain where both are defined, i.e. where
+  ///     the output of `func` is not "better" than the output of `this`.
+  ///
+  /// The notion of "better" is defined by `sharedAndBetter`. Given two
+  /// MultiAffineFunctions, `maf1` and `maf2`, `sharedAndBetter` returns the
+  /// domain where both functions are defined and the output of `maf1` is
+  /// better than the output of `maf2`.
+  PWMAFunction unionFunction(
+      const PWMAFunction &func,
+      llvm::function_ref<PresburgerSet(const MultiAffineFunction &maf1,
+                                       const MultiAffineFunction &maf2)>
+          sharedAndBetter) const;
+
   PresburgerSpace space;
 
   /// The list of pieces in this piece-wise MultiAffineFunction.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2273,3 +2273,8 @@
          "Domain has to be zero in a set");
   return IntegerRelation::insertVar(kind, pos, num);
 }
+
+IntegerPolyhedron
+IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
+  return IntegerPolyhedron(IntegerRelation::intersect(other));
+}
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -211,6 +211,11 @@
   addPiece(MultiAffineFunction(domain, output));
 }
 
+void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
+  for (const IntegerRelation &newDom : domain.getAllDisjuncts())
+    addPiece(IntegerPolyhedron(newDom), output);
+}
+
 void PWMAFunction::print(raw_ostream &os) const {
   os << pieces.size() << " pieces:\n";
   for (const MultiAffineFunction &piece : pieces)
@@ -218,3 +223,99 @@
 }
 
 void PWMAFunction::dump() const { print(llvm::errs()); }
+
+PWMAFunction PWMAFunction::unionFunction(
+    const PWMAFunction &func,
+    llvm::function_ref<PresburgerSet(const MultiAffineFunction &maf1,
+                                     const MultiAffineFunction &maf2)>
+        sharedAndBetter) const {
+  assert(getNumOutputs() == func.getNumOutputs() &&
+         "Number of outputs of functions should be same.");
+  assert(getSpace().isCompatible(func.getSpace()) &&
+         "Space is not compatible.");
+
+  PWMAFunction res(getSpace(), getNumOutputs());
+  // Add the parts of func2 which are shared with func1 and better than it.
+  // Also add the parts of func1 where func2 is not defined or not better.
+  for (const MultiAffineFunction &func1 : pieces) {
+    PresburgerSet dom(func1.getDomain());
+    for (const MultiAffineFunction &func2 : func.pieces) {
+      // Get the domain where both are defined and the output of func2 is
+      // better than the output of func1.
+      PresburgerSet better = sharedAndBetter(func2, func1);
+      // Add this new piece.
+      res.addPiece(better, func2.getOutputs());
+      // Subtract this shared domain from set.
+      dom = dom.subtract(better);
+    }
+    res.addPiece(dom, func1.getOutputs());
+  }
+
+  // Add parts of func2 which are not shared with func1.
+  for (const MultiAffineFunction &func2 : func.pieces) {
+    PresburgerSet dom(func2.getDomain());
+    for (const MultiAffineFunction &func1 : pieces)
+      dom = dom.subtract(PresburgerSet(func1.getDomain()));
+    res.addPiece(dom, func2.getOutputs());
+  }
+
+  return res;
+}
+
+/// Return the part of the shared domain of `maf1` and `maf2` where the output
+/// of `maf1` is strictly lexicographically smaller (if `lexMin`) or larger
+/// (otherwise) than the output of `maf2`.
+///
+/// This is the union over every level `i` of the sets
+///   out1[0] == out2[0], ..., out1[i-1] == out2[i-1], out1[i] < out2[i]
+/// (`>` instead of `<` for lexmax). The comparison at level `i` has to be
+/// strict: a non-strict one would also claim points where the outputs tie at
+/// level `i` but `maf1` is worse at a later level.
+template <bool lexMin>
+static PresburgerSet sharedAndBetterLex(const MultiAffineFunction &maf1,
+                                        const MultiAffineFunction &maf2) {
+  // TODO: Support local variables here.
+  assert(maf1.getDomain().getNumLocalVars() == 0 &&
+         maf2.getDomain().getNumLocalVars() == 0 &&
+         "Local variables are not supported yet.");
+  assert(maf1.getNumOutputs() == maf2.getNumOutputs() &&
+         "Number of outputs of functions should be same.");
+
+  PresburgerSpace compatibleSpace = maf1.getDomain().getSpaceWithoutLocals();
+  const PresburgerSpace &space = maf1.getDomain().getSpace();
+
+  PresburgerSet res = PresburgerSet::getEmpty(compatibleSpace);
+  // Equalities `out1[j] == out2[j]` for every level `j` already processed.
+  IntegerPolyhedron prefixEqual(space);
+  for (unsigned level = 0; level < maf1.getNumOutputs(); ++level) {
+    // `diff > 0` exactly where `maf1` is better than `maf2` at this level.
+    SmallVector<int64_t, 8> diff =
+        lexMin ? subtract(maf2.getOutput(level), maf1.getOutput(level))
+               : subtract(maf1.getOutput(level), maf2.getOutput(level));
+
+    // Over the integers, `diff > 0` is the inequality `diff - 1 >= 0`.
+    SmallVector<int64_t, 8> ineq(diff);
+    ineq.back() -= 1;
+
+    IntegerPolyhedron levelSet = prefixEqual;
+    levelSet.addInequality(ineq);
+    res.unionInPlace(levelSet);
+
+    // Later levels are only compared where the outputs tie at this level.
+    prefixEqual.addEquality(diff);
+  }
+
+  // Restrict to the domain where both functions are defined.
+  res = res.intersect(PresburgerSet(maf1.getDomain()))
+            .intersect(PresburgerSet(maf2.getDomain()));
+
+  return res;
+}
+
+PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
+  return unionFunction(func, sharedAndBetterLex</*lexMin=*/true>);
+}
+
+PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
+  return unionFunction(func, sharedAndBetterLex</*lexMin=*/false>);
+}
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -189,3 +189,339 @@
       });
   EXPECT_TRUE(pwmafA.isEqual(pwmafB));
 }
+
+TEST(PWMAFunction, unionLexMinComplex) {
+  // Multiple outputs: a tie in the first output must be broken by the second.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
+                                    {{"(x, y) : (x >= 0, y - 1 >= 0)",
+                                      {
+                                          {1, 0, 0},
+                                          {2, 0, -2},
+                                      }},
+                                     {"(x, y) : (x >= 0, 1 - x >= 0, y == 0)",
+                                      {
+                                          {1, 0, 0},
+                                          {2, 0, -2},
+                                      }},
+                                     {"(x, y) : (x - 2 >= 0, y == 0)",
+                                      {
+                                          {1, 1, 0},
+                                          {-2, 0, 4},
+                                      }}});
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+  }
+}
+
+TEST(PWMAFunction, unionLexMaxSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, 2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, 1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 11 >= 0)", {{0, 1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{-1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  // func1 and func2 have disjoint domains.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+}
+
+TEST(PWMAFunction, unionLexMinSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, -2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, -1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+            {"(x) : (x - 11 >= 0)", {{0, -1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{-1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func3));
+  }
+}
+
+TEST(PWMAFunction, unionLexMaxComplex) {
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (x - 10 >= 0)", {{1, 1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (20 - x >= 0)", {{1, -1, 0}}},
+        });
+
+    PWMAFunction func3 =
+        parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
+                   {{"(x, y) : (x - 10 >= 0, 20 - x >= 0, y - 1 >= 0)",
+                     {
+                         {1, 1, 0},
+                     }},
+                    {"(x, y) : (x - 21 >= 0)",
+                     {
+                         {1, 1, 0},
+                     }},
+                    {"(x, y) : (9 - x >= 0)",
+                     {
+                         {1, -1, 0},
+                     }},
+                    {"(x, y) : (x - 10 >= 0, 20 - x >= 0, -y >= 0)",
+                     {
+                         {1, -1, 0},
+                     }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+  }
+
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
+                                    {{"(x, y) : (x >= 0, y - 1 >= 0)",
+                                      {
+                                          {1, 1, 0},
+                                          {-2, 0, 4},
+                                      }},
+                                     {"(x, y) : (x >= 0, 1 - x >= 0, y == 0)",
+                                      {
+                                          {1, 1, 0},
+                                          {-2, 0, 4},
+                                      }},
+                                     {"(x, y) : (x - 2 >= 0, y == 0)",
+                                      {
+                                          {1, 0, 0},
+                                          {2, 0, -2},
+                                      }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/6, /*numOutputs=*/3,
+        {
+            {"(x, y, z, a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0, 0, 0, 0},
+              {0, 1, 0, 0, 1, 0, 0},
+              {0, 0, 1, 0, 0, 0, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/6, /*numOutputs=*/3,
+        {
+            {"(x, y, z, a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{1, 0, 0, 1, 0, 0, 0},
+              {0, 1, 0, 0, 0, 0, 0},
+              {0, 0, 1, 0, 0, 1, 0}}},
+        });
+
+    PWMAFunction func3 = parsePWMAF(
+        /*numInputs=*/6, /*numOutputs=*/3,
+        {
+            {"(x, y, z, a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{1, 0, 0, 1, 0, 0, 0},
+              {0, 1, 0, 0, 0, 0, 0},
+              {0, 0, 1, 0, 0, 1, 0}}},
+
+            {"(x, y, z, a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0, 0, 0, 0},
+              {0, 1, 0, 0, 1, 0, 0},
+              {0, 0, 1, 0, 0, 0, 0}}},
+
+            {"(x, y, z, a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 1, 0, 0, 0},
+              {0, 1, 0, 0, 0, 0, 0},
+              {0, 0, 1, 0, 0, 1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func3));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func3));
+  }
+}