diff --git a/mlir/include/mlir/Analysis/PresburgerSet.h b/mlir/include/mlir/Analysis/PresburgerSet.h
--- a/mlir/include/mlir/Analysis/PresburgerSet.h
+++ b/mlir/include/mlir/Analysis/PresburgerSet.h
@@ -67,17 +67,18 @@
   void print(raw_ostream &os) const;
   void dump() const;
 
-  /// Return the complement of this set. Computing the complement of a set
-  /// containing divisions is not yet supported.
+  /// Return the complement of this set. All local variables in the set must
+  /// correspond to floor divisions.
   PresburgerSet complement() const;
 
   /// Return the set difference of this set and the given set, i.e.,
-  /// return `this \ set`. Subtracting when either set contains divisions is not
-  /// yet supported.
+  /// return `this \ set`. All local variables in `set` must correspond
+  /// to floor divisions, but local variables in `this` need not correspond to
+  /// divisions.
   PresburgerSet subtract(const PresburgerSet &set) const;
 
   /// Return true if this set is equal to the given set, and false otherwise.
-  /// Checking equality when either set contains divisions is not yet supported.
+  /// All local variables in both sets must correspond to floor divisions.
   bool isEqual(const PresburgerSet &set) const;
 
   /// Return a universe set of the specified type that contains all points.
diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -106,16 +106,20 @@
 //
 // We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
 // as (S_1 and T_1) or (S_1 and T_2) or ...
+//
+// If S_i or T_j have local variables, then (S_i and T_j) contains the local
+// variables of both.
 PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
   assertDimensionsCompatible(set, *this);
 
   PresburgerSet result(nDim, nSym);
   for (const FlatAffineConstraints &csA : flatAffineConstraints) {
     for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
-      FlatAffineConstraints intersection(csA);
-      intersection.append(csB);
-      if (!intersection.isEmpty())
-        result.unionFACInPlace(std::move(intersection));
+      FlatAffineConstraints csACopy = csA, csBCopy = csB;
+      csACopy.mergeLocalIds(csBCopy);
+      csACopy.append(std::move(csBCopy));
+      if (!csACopy.isEmpty())
+        result.unionFACInPlace(std::move(csACopy));
     }
   }
   return result;
@@ -160,6 +164,17 @@
 /// returning the union of the results. Each equality is handled as a
 /// conjunction of two inequalities.
 ///
+/// Note that the same approach works even if an inequality involves a floor
+/// division. For example, the complement of x <= 7*floor(x/7) is still
+/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i
+/// (or the complements of those inequalities), b \ s_i may contain the
+/// divisions present in both b and s_i. Therefore, we need to add the local
+/// division variables of both b and s_i to each part in the result. This means
+/// adding the local variables of both b and s_i, as well as the corresponding
+/// division inequalities to each part. Since the division inequalities are
+/// added to each part, we can skip the parts where the complement of any
+/// division inequality is added, as these parts will become empty anyway.
+///
 /// As a heuristic, we try adding all the constraints and check if simplex
 /// says that the intersection is empty. If it is, then subtracting this FAC is
 /// a no-op and we just skip it. Also, in the process we find out that some
@@ -174,27 +189,80 @@
     result.unionFACInPlace(b);
     return;
   }
-  const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
-  assert(sI.getNumLocalIds() == 0 &&
-         "Subtracting sets with divisions is not yet supported!");
+  FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
+
+  // Below, we append some additional ids and constraints to b, starting with
+  // the division inequalities of sI. We want to rollback b to its initial
+  // state before returning, so we store its initial sizes before modifying it.
+  unsigned bInitNumLocals = b.getNumLocalIds();
+  unsigned bInitNumIneqs = b.getNumInequalities();
+  unsigned bInitNumEqs = b.getNumEqualities();
+
+  // Find out which inequalities of sI correspond to division inequalities for
+  // the local variables of sI.
+  std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
+      sI.getNumLocalIds());
+  sI.getLocalReprLbUbPairs(repr);
+
+  // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
+  // sI's locals.
+  b.mergeLocalIds(sI);
+
+  // Mark which inequalities of sI are division inequalities and add all such
+  // inequalities to b.
+  llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+  for (Optional<std::pair<unsigned, unsigned>> &maybePair : repr) {
+    assert(maybePair &&
+           "Subtraction is not supported when a representation of the local "
+           "variables of the subtrahend cannot be found!");
+
+    b.addInequality(sI.getInequality(maybePair->first));
+    b.addInequality(sI.getInequality(maybePair->second));
+
+    assert(maybePair->first != maybePair->second &&
+           "Upper and lower bounds must be different inequalities!");
+    isDivInequality[maybePair->first] = true;
+    isDivInequality[maybePair->second] = true;
+  }
+
   unsigned initialSnapshot = simplex.getSnapshot();
   unsigned offset = simplex.getNumConstraints();
+  unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
+  simplex.appendVariable(numLocalsAdded);
+
+  // Restore b and simplex to their states on entry to this function. This
+  // removes all the locals and constraints added to b at this level, including
+  // the division inequalities of sI.
+  auto restoreState = [&]() {
+    b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+                    b.getNumLocalIds());
+    b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
+    b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
+    simplex.rollback(initialSnapshot);
+  };
+
+  unsigned snapshotBeforeIntersect = simplex.getSnapshot();
   simplex.intersectFlatAffineConstraints(sI);
 
   if (simplex.isEmpty()) {
     /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
-    simplex.rollback(initialSnapshot);
+    // Level i is skipped entirely, so restore the state from before level i
+    // first; the recursive call restores this same state before returning.
+    restoreState();
     subtractRecursively(b, simplex, s, i + 1, result);
     return;
   }
 
   simplex.detectRedundant();
-  llvm::SmallBitVector isMarkedRedundant;
-  for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
-       j++)
-    isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
 
-  simplex.rollback(initialSnapshot);
+  // Equalities are added to simplex as a pair of inequalities.
+  unsigned totalNewSimplexInequalities =
+      2 * sI.getNumEqualities() + sI.getNumInequalities();
+  llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+  for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
+    isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
+
+  simplex.rollback(snapshotBeforeIntersect);
 
   // Recurse with the part b ^ ~ineq. Note that b is modified throughout
   // subtractRecursively. At the time this function is called, the current b is
@@ -223,20 +291,28 @@
   // rollback b to its initial state before returning, which we will do by
   // removing all constraints beyond the original number of inequalities
-  // and equalities, so we store these counts first.
-  unsigned originalNumIneqs = b.getNumInequalities();
-  unsigned originalNumEqs = b.getNumEqualities();
+  // and equalities. This is done by restoreState(), which uses the counts
+  // recorded on entry, before the division inequalities of sI were added to b,
+  // so that those inequalities are removed as well.
 
+  // Process all the inequalities, ignoring redundant inequalities and division
+  // inequalities. The result is correct whether or not we ignore these, but
+  // ignoring them makes the result simpler.
   for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
     if (isMarkedRedundant[j])
       continue;
+    if (isDivInequality[j])
+      continue;
     processInequality(sI.getInequality(j));
   }
 
   offset = sI.getNumInequalities();
   for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
-    const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
-    // Same as the above loop for inequalities, done once each for the positive
-    // and negative inequalities that make up this equality.
+    ArrayRef<int64_t> coeffs = sI.getEquality(j);
+    // For each equality, process the positive and negative inequalities that
+    // make up this equality. If Simplex found an inequality to be redundant, we
+    // skip it as above to make the result simpler. Divisions are always
+    // represented in terms of inequalities and not equalities, so we do not
+    // check for division inequalities here.
     if (!isMarkedRedundant[offset + 2 * j])
       processInequality(coeffs);
     if (!isMarkedRedundant[offset + 2 * j + 1])
@@ -244,11 +320,5 @@
   }
 
   // Rollback b and simplex to their initial states.
-  for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
-    b.removeInequality(i - 1);
-
-  for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
-    b.removeEquality(i - 1);
-
-  simplex.rollback(initialSnapshot);
+  restoreState();
 }
@@ -261,8 +331,6 @@
 PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
                                               const PresburgerSet &set) {
   assertDimensionsCompatible(fac, set);
-  assert(fac.getNumLocalIds() == 0 &&
-         "Subtracting sets with divisions is not yet supported!");
   if (fac.isEmptyByGCDTest())
     return PresburgerSet::getEmptySet(fac.getNumDimIds(),
                                       fac.getNumSymbolIds());
diff --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
--- a/mlir/unittests/Analysis/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -80,12 +80,17 @@
 }
 
 /// Construct a FlatAffineConstraints from a set of inequality and
-/// equality constraints.
+/// equality constraints. `numIds` is the total number of ids, of which
+/// `numLocals` is the number of local ids.
 static FlatAffineConstraints
-makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
-                       ArrayRef<SmallVector<int64_t, 4>> eqs) {
-  FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims,
-                            /*numSymbols=*/0, /*numLocals=*/0);
+makeFACFromConstraints(unsigned numIds, ArrayRef<SmallVector<int64_t, 4>> ineqs,
+                       ArrayRef<SmallVector<int64_t, 4>> eqs,
+                       unsigned numLocals = 0) {
+  FlatAffineConstraints fac(/*numReservedInequalities=*/ineqs.size(),
+                            /*numReservedEqualities=*/eqs.size(),
+                            /*numReservedCols=*/numIds + 1,
+                            /*numDims=*/numIds - numLocals,
+                            /*numSymbols=*/0, numLocals);
   for (const SmallVector<int64_t, 4> &eq : eqs)
     fac.addEquality(eq);
   for (const SmallVector<int64_t, 4> &ineq : ineqs)
@@ -93,14 +98,22 @@
   return fac;
 }
 
+/// Construct a FlatAffineConstraints having `numDims` dimensions from the given
+/// set of inequality constraints. This is a convenience function to be used
+/// when the FAC to be constructed does not have any local ids and does not have
+/// equalities.
 static FlatAffineConstraints
-makeFACFromIneqs(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
-  return makeFACFromConstraints(dims, ineqs, {});
+makeFACFromIneqs(unsigned numDims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
+  return makeFACFromConstraints(numDims, ineqs, /*eqs=*/{});
 }
 
-static PresburgerSet makeSetFromFACs(unsigned dims,
+/// Construct a PresburgerSet having `numDims` dimensions and no symbols from
+/// the given list of FlatAffineConstraints. Each FAC in `facs` should also have
+/// `numDims` dimensions and no symbols, although it can have any number of
+/// local ids.
+static PresburgerSet makeSetFromFACs(unsigned numDims,
                                      ArrayRef<FlatAffineConstraints> facs) {
-  PresburgerSet set = PresburgerSet::getEmptySet(dims);
+  PresburgerSet set = PresburgerSet::getEmptySet(numDims);
   for (const FlatAffineConstraints &fac : facs)
     set.unionFACInPlace(fac);
   return set;
@@ -592,4 +605,62 @@
   EXPECT_FALSE(rect.complement().isEqual(square.complement()));
 }
 
+static void expectEqual(const PresburgerSet &s, const PresburgerSet &t) {
+  EXPECT_TRUE(s.isEqual(t));
+}
+
+static void expectEmpty(const PresburgerSet &s) {
+  EXPECT_TRUE(s.isIntegerEmpty());
+}
+
+TEST(SetTest, divisions) {
+  // Note: we currently need to add the equalities as inequalities to the FAC
+  // since detecting divisions based on equalities is not yet supported.
+
+  // evens = {x : exists q, x = 2q}.
+  PresburgerSet evens{
+      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, 0}}, 1)};
+  // odds = {x : exists q, x = 2q + 1}.
+  PresburgerSet odds{
+      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, -1}}, 1)};
+  // multiples3 = {x : exists q, x = 3q}.
+  PresburgerSet multiples3{
+      makeFACFromConstraints(2, {{1, -3, 0}, {-1, 3, 2}}, {{1, -3, 0}}, 1)};
+  // multiples6 = {x : exists q, x = 6q}.
+  PresburgerSet multiples6{
+      makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, 0}}, 1)};
+
+  // evens /\ odds = empty.
+  expectEmpty(evens.intersect(odds));
+  // evens U odds = universe.
+  expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
+  expectEqual(evens.complement(), odds);
+  expectEqual(odds.complement(), evens);
+  // even multiples of 3 = multiples of 6.
+  expectEqual(multiples3.intersect(evens), multiples6);
+
+  // Subtracting a set with divisions that is disjoint from the minuend must
+  // leave the minuend unchanged. In particular, the division inequalities
+  // temporarily added to the minuend during subtraction must be removed.
+  // nonPositives = {x : x <= 0}.
+  PresburgerSet nonPositives{makeFACFromIneqs(1, {{-1, 0}})};
+  // atLeast8 = {x : floor(x/2) >= 4}.
+  PresburgerSet atLeast8{makeFACFromConstraints(
+      2, {{1, -2, 0}, {-1, 2, 1}, {0, 1, -4}}, /*eqs=*/{}, 1)};
+  expectEqual(nonPositives.subtract(atLeast8), nonPositives);
+
+  // Taking the complement of a union of sets with divisions subtracts them at
+  // nested levels of recursion; each level must remove the division
+  // inequalities it added before returning to the previous one.
+  // oneMod6 = {x : exists q, x = 6q + 1}.
+  PresburgerSet oneMod6{
+      makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, -1}}, 1)};
+  // fiveMod6 = {x : exists q, x = 6q + 5}.
+  PresburgerSet fiveMod6{
+      makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, -5}}, 1)};
+  // Integers that are neither even nor multiples of 3 are 1 or 5 modulo 6.
+  expectEqual(evens.unionSet(multiples3).complement(),
+              oneMod6.unionSet(fiveMod6));
+}
+
 } // namespace mlir