diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -24,6 +24,9 @@
 namespace mlir {
 namespace presburger {
 
+class IntegerRelation;
+class IntegerPolyhedron;
+
 /// An IntegerRelation represents the set of points from a PresburgerSpace that
 /// satisfy a list of affine constraints. Affine constraints can be inequalities
 /// or equalities in the form:
@@ -496,6 +499,12 @@
     space.setDimSymbolSeparation(newSymbolCount);
   }
 
+  /// Return a set corresponding to all points in the domain of the relation.
+  IntegerPolyhedron getDomainSet() const;
+
+  /// Return a set corresponding to all points in the range of the relation.
+  IntegerPolyhedron getRangeSet() const;
+
   void print(raw_ostream &os) const;
   void dump() const;
 
@@ -643,6 +652,24 @@
                           /*numReservedEqualities=*/0,
                           /*numReservedCols=*/space.getNumIds() + 1, space) {}
 
+  /// Construct a set from an IntegerRelation. The relation should have
+  /// no domain ids.
+  explicit IntegerPolyhedron(const IntegerRelation &rel)
+      : IntegerRelation(rel) {
+    assert(space.getNumDomainIds() == 0 &&
+           "Number of domain id's should be zero in Set kind space.");
+  }
+
+  /// Construct a set from an IntegerRelation by moving from it instead of
+  /// copying it. The relation should have no domain ids.
+  explicit IntegerPolyhedron(IntegerRelation &&rel)
+      : IntegerRelation(std::move(rel)) {
+    // `rel` is a named lvalue here; without std::move the base class would be
+    // copy-constructed, defeating the purpose of this overload.
+    assert(space.getNumDomainIds() == 0 &&
+           "Number of domain id's should be zero in Set kind space.");
+  }
+
   /// Return a system with no constraints, i.e., one which is satisfied by all
   /// points.
   static IntegerPolyhedron getUniverse(const PresburgerSpace &space) {
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2075,6 +2075,35 @@
   removeEquality(nbIndex);
 }
 
+IntegerPolyhedron IntegerRelation::getDomainSet() const {
+  IntegerRelation copyRel = *this;
+
+  // Turn the range ids into local ids so that they no longer appear as
+  // dimensions of the result.
+  copyRel.convertIdKind(IdKind::Range, 0, copyRel.getNumIdKind(IdKind::Range),
+                        IdKind::Local);
+
+  // With no range ids left, move the domain ids into the SetDim (range)
+  // position so the result is a set over the former domain.
+  copyRel.convertIdKind(IdKind::Domain, 0,
+                        copyRel.getNumIdKind(IdKind::Domain), IdKind::SetDim);
+
+  return IntegerPolyhedron(std::move(copyRel));
+}
+
+IntegerPolyhedron IntegerRelation::getRangeSet() const {
+  IntegerRelation copyRel = *this;
+
+  // Turn the domain ids into local ids so that they no longer appear as
+  // dimensions of the result.
+  copyRel.convertIdKind(IdKind::Domain, 0,
+                        copyRel.getNumIdKind(IdKind::Domain), IdKind::Local);
+
+  // Range ids already occupy the SetDim position, so they need no change.
+
+  return IntegerPolyhedron(std::move(copyRel));
+}
+
 void IntegerRelation::printSpace(raw_ostream &os) const {
   space.print(os);
   os << getNumConstraints() << " constraints\n";
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRPresburgerTests
   IntegerPolyhedronTest.cpp
+  IntegerRelationTest.cpp
   LinearTransformTest.cpp
   MatrixTest.cpp
   PresburgerSetTest.cpp
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -0,0 +1,43 @@
+//===- IntegerRelationTest.cpp - Tests for IntegerRelation class ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "./Utils.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+static IntegerRelation parseRelationFromSet(StringRef set, unsigned numDomain) {
+  IntegerRelation rel = parsePoly(set);
+  // Reinterpret the first `numDomain` set dimensions as domain ids.
+  rel.convertIdKind(IdKind::SetDim, 0, numDomain, IdKind::Domain);
+
+  return rel;
+}
+
+TEST(IntegerRelationTest, getDomainAndRangeSet) {
+  IntegerRelation rel = parseRelationFromSet(
+      "(x, xr)[N] : (xr - x - 10 == 0, xr >= 0, N - xr >= 0)", 1);
+
+  IntegerPolyhedron domainSet = rel.getDomainSet();
+
+  IntegerPolyhedron expectedDomainSet =
+      parsePoly("(x)[N] : (x + 10 >= 0, N - x - 10 >= 0)");
+
+  EXPECT_TRUE(domainSet.isEqual(expectedDomainSet));
+
+  IntegerPolyhedron rangeSet = rel.getRangeSet();
+
+  IntegerPolyhedron expectedRangeSet =
+      parsePoly("(x)[N] : (x >= 0, N - x >= 0)");
+
+  EXPECT_TRUE(rangeSet.isEqual(expectedRangeSet));
+}