diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -190,6 +190,11 @@
   /// union of the domains of all the pieces.
   PresburgerSet getDomain() const;
 
+  /// Return the piece-wise MultiAffineFunction as a PresburgerRelation: the
+  /// union, over all pieces, of each piece's output function restricted to
+  /// that piece's domain.
+  PresburgerRelation getAsRelation() const;
+
   /// Return the output of the function at the given point.
   std::optional<SmallVector<MPInt, 8>> valueAt(ArrayRef<MPInt> point) const;
   std::optional<SmallVector<MPInt, 8>> valueAt(ArrayRef<int64_t> point) const {
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -48,6 +48,16 @@
   return domain;
 }
 
+PresburgerRelation PWMAFunction::getAsRelation() const {
+  PresburgerRelation result = PresburgerRelation::getEmpty(space);
+  for (const Piece &piece : pieces) {
+    PresburgerRelation rel(piece.output.getAsRelation());
+    rel = rel.intersectDomain(piece.domain);
+    result.unionInPlace(rel);
+  }
+  return result;
+}
+
 void MultiAffineFunction::print(raw_ostream &os) const {
   space.print(os);
   os << "Division Representation:\n";
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -159,7 +159,9 @@
          "Domain of `this` must be compatible with range of `set`");
 
   PresburgerRelation other = set;
-  other.insertVarInPlace(VarKind::Domain, 0, getNumDomainVars());
+  // Give `other` one domain var per range var of `this`; after inverse(), its
+  // space is (domain of `this`) -> (range of `this`), matching `this`.
+  other.insertVarInPlace(VarKind::Domain, 0, getNumRangeVars());
   other.inverse();
   return intersect(other);
 }
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -436,3 +436,25 @@
     EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
   }
 }
+
+TEST(PWAFunctionTest, getAsRelation) {
+  PWMAFunction func = parsePWMAF({
+      {"(x, y)[N] : (N >= 0, x - N - 1 >= 0, y - N - 1 >= 0)",
+       "(x, y)[N] -> (x + y)"},
+      {"(x, y)[N] : (N >= 0, x - N >= 0, -y - N >= 0)", "(x, y)[N] -> (x - y)"},
+      {"(x, y)[N] : (N >= 0, -x - N - 1 >= 0, y - N >= 0)",
+       "(x, y)[N] -> (y - x)"},
+      {"(x, y)[N] : (N >= 0, -x - N - 1 >= 0, -y - N - 1 >= 0)",
+       "(x, y)[N] -> (-x - y)"},
+  });
+  PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
+      {"(x, y, a)[N] : (N >= 0, x - N - 1 >= 0, y - N - 1 >= 0, a - x - y == "
+       "0)",
+       "(x, y, a)[N] : (N >= 0, x - N >= 0, -y - N >= 0, a - x + y == 0)",
+       "(x, y, a)[N] : (N >= 0, -x - N - 1 >= 0, y - N >= 0, a + x - y == 0)",
+       "(x, y, a)[N] : (N >= 0, -x - N - 1 >= 0, -y - N - 1 >= 0, a + x + y == "
+       "0)"},
+      2);
+
+  EXPECT_TRUE(rel.isEqual(func.getAsRelation()));
+}
diff --git a/mlir/unittests/Analysis/Presburger/Parser.h b/mlir/unittests/Analysis/Presburger/Parser.h
--- a/mlir/unittests/Analysis/Presburger/Parser.h
+++ b/mlir/unittests/Analysis/Presburger/Parser.h
@@ -47,6 +47,26 @@
   return result;
 }
 
+/// Parse a list of strings, each a valid IntegerPolyhedron representation,
+/// convert the first `numDomain` set dimensions of each into domain
+/// variables, and return the union of the resulting relations. The strings
+/// are expected to describe compatible spaces.
+inline PresburgerRelation
+parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
+                                         unsigned numDomain) {
+  assert(!strs.empty() && "strs should not be empty");
+
+  IntegerRelation rel = parseIntegerPolyhedron(strs[0]);
+  rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
+  PresburgerRelation result(rel);
+  for (unsigned i = 1, e = strs.size(); i < e; ++i) {
+    rel = parseIntegerPolyhedron(strs[i]);
+    rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
+    result.unionInPlace(rel);
+  }
+  return result;
+}
+
 inline MultiAffineFunction parseMultiAffineFunction(StringRef str) {
   MLIRContext context(MLIRContext::Threading::DISABLED);
 
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -16,22 +16,6 @@
 using namespace mlir;
 using namespace presburger;
 
-static PresburgerRelation
-parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
-                                         unsigned numDomain) {
-  assert(!strs.empty() && "strs should not be empty");
-
-  IntegerRelation rel = parseIntegerPolyhedron(strs[0]);
-  rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
-  PresburgerRelation result(rel);
-  for (unsigned i = 1, e = strs.size(); i < e; ++i) {
-    rel = parseIntegerPolyhedron(strs[i]);
-    rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
-    result.unionInPlace(rel);
-  }
-  return result;
-}
-
 TEST(PresburgerRelationTest, intersectDomainAndRange) {
   PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
       {// (x, y) -> (x + N, y - N)