diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -285,11 +285,12 @@
   Optional<uint64_t> computeVolume() const;
 
   /// Returns true if the given point satisfies the constraints, or false
-  /// otherwise.
-  ///
-  /// Note: currently, if the relation contains local ids, the values of
-  /// the local ids must also be provided.
+  /// otherwise. Takes the values of all ids including locals.
   bool containsPoint(ArrayRef<int64_t> point) const;
+  /// Given the values of non-local ids, return a satisfying assignment to the
+  /// local ids if one exists, or an empty optional otherwise.
+  Optional<SmallVector<int64_t, 8>>
+  containsPointNoLocal(ArrayRef<int64_t> point) const;
 
   /// Find equality and pairs of inequality contraints identified by their
   /// position indices, using which an explicit representation for each local
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -41,8 +41,7 @@
 /// each id, and an extra column at the end for the constant term.
 ///
 /// Checking equality of two such functions is supported, as well as finding the
-/// value of the function at a specified point. Note that local ids in the
-/// domain are not yet supported for finding the value at a point.
+/// value of the function at a specified point.
 class MultiAffineFunction : protected IntegerPolyhedron {
 public:
   /// We use protected inheritance to avoid inheriting the whole public
@@ -114,8 +113,6 @@
 
   /// Get the value of the function at the specified point. If the point lies
   /// outside the domain, an empty optional is returned.
-  ///
-  /// Note: domains with local ids are not yet supported, and will assert-fail.
   Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
 
   void print(raw_ostream &os) const;
@@ -145,8 +142,7 @@
 /// symbolic ids.
 ///
 /// Support is provided to compare equality of two such functions as well as
-/// finding the value of the function at a point. Note that local ids in the
-/// piece are not supported for the latter.
+/// finding the value of the function at a point.
 class PWMAFunction : public PresburgerSpace {
 public:
   PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
@@ -170,8 +166,6 @@
 
   /// Return the value at the specified point and an empty optional if the
   /// point does not lie in the domain.
-  ///
-  /// Note: domains with local ids are not yet supported, and will assert-fail.
   Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
 
   /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -784,6 +784,25 @@
   return true;
 }
 
+/// Just substitute the values given and check if an integer sample exists for
+/// the local ids.
+///
+/// TODO: this could be made more efficient by handling divisions separately.
+/// Instead of finding an integer sample over all the locals, we can first
+/// compute the values of the locals that have division representations and
+/// only use the integer emptiness check for the locals that don't have this.
+/// Handling this correctly requires ordering the divs, though.
+Optional<SmallVector<int64_t, 8>>
+IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
+  assert(point.size() == getNumIds() - getNumLocalIds() &&
+         "Point should contain all ids except locals!");
+  assert(getIdKindOffset(IdKind::Local) == getNumIds() - getNumLocalIds() &&
+         "This function depends on locals being stored last!");
+  IntegerRelation copy = *this;
+  copy.setAndEliminate(0, point);
+  return copy.findIntegerSample();
+}
+
 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
   std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
   SmallVector<unsigned, 4> denominators(getNumLocalIds());
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -36,19 +36,26 @@
 
 Optional<SmallVector<int64_t, 8>>
 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
-  assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
-  assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
+  assert(point.size() == getNumDimAndSymbolIds() &&
+         "Point has incorrect dimensionality!");
 
-  if (!getDomain().containsPoint(point))
+  Optional<SmallVector<int64_t, 8>> maybeLocalValues =
+      getDomain().containsPointNoLocal(point);
+  if (!maybeLocalValues)
     return {};
 
   // The point lies in the domain, so we need to compute the output value.
+  SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
+  // The given point didn't include the values of locals which the output is a
+  // function of; we have computed one possible set of values and use them
+  // here. The function is not allowed to have local ids that take more than
+  // one possible value.
+  pointHomogenous.append(*maybeLocalValues);
   // The matrix `output` has an affine expression in the ith row, corresponding
   // to the expression for the ith value in the output vector. The last column
   // of the matrix contains the constant term. Let v be the input point with
   // a 1 appended at the end. We can see that output * v gives the desired
   // output vector.
-  SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
   pointHomogenous.push_back(1);
   SmallVector<int64_t, 8> result =
       output.postMultiplyWithColumn(pointHomogenous);
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1187,3 +1187,18 @@
       parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)"),
       /*trueVolume=*/{}, /*resultBound=*/{});
 }
+
+TEST(IntegerPolyhedronTest, containsPointNoLocal) {
+  IntegerPolyhedron poly1 = parsePoly("(x) : ((x floordiv 2) - x == 0)");
+  EXPECT_TRUE(poly1.containsPointNoLocal({0}));
+  EXPECT_FALSE(poly1.containsPointNoLocal({1}));
+
+  IntegerPolyhedron poly2 = parsePoly(
+      "(x) : (x - 2*(x floordiv 2) == 0, x - 4*(x floordiv 4) - 2 == 0)");
+  EXPECT_TRUE(poly2.containsPointNoLocal({6}));
+  EXPECT_FALSE(poly2.containsPointNoLocal({4}));
+
+  IntegerPolyhedron poly3 = parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)");
+  EXPECT_TRUE(poly3.containsPointNoLocal({0, 0}));
+  EXPECT_FALSE(poly3.containsPointNoLocal({1, 0}));
+}
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -129,16 +129,31 @@
 }
 
 TEST(PWMAFunction, valueAt) {
-  PWMAFunction nonNegPWAF = parsePWMAF(
+  PWMAFunction nonNegPWMAF = parsePWMAF(
       /*numInputs=*/2, /*numOutputs=*/2,
       {
           {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
           {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
       });
-  EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
-  EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
-  EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
-  EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
+  EXPECT_THAT(*nonNegPWMAF.valueAt({2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWMAF.valueAt({2, -3}), ElementsAre(-1, -1));
+  EXPECT_FALSE(nonNegPWMAF.valueAt({-2, -3}).hasValue());
+
+  PWMAFunction divPWMAF = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, x - 2*(x floordiv 2) == 0)",
+           {{0, 2, 1, 3}, {0, 4, 3, 5}}}, // (x, y).
+          {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
+      });
+  EXPECT_THAT(*divPWMAF.valueAt({4, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*divPWMAF.valueAt({4, -3}), ElementsAre(-1, -1));
+  EXPECT_FALSE(divPWMAF.valueAt({3, 3}).hasValue());
+  EXPECT_FALSE(divPWMAF.valueAt({3, -3}).hasValue());
+
+  EXPECT_THAT(*divPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
+  EXPECT_FALSE(divPWMAF.valueAt({-2, -3}).hasValue());
 }
 
 TEST(PWMAFunction, removeIdRangeRegressionTest) {