diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -133,15 +133,17 @@
 /// that some constraints are redundant. These redundant constraints are
 /// ignored.
 ///
-/// b and simplex are callee saved, i.e., their values on return are
-/// semantically equivalent to their values when the function is called.
+/// b and simplex are caller saved, i.e., their values on return may differ
+/// from their values when the function is called. A caller that needs the
+/// original state must record b.getCounts() and simplex.getSnapshot() before
+/// the call and restore it afterwards via b.truncate() and simplex.rollback().
 ///
 /// b should not have duplicate divs because this might lead to existing
 /// divs disappearing in the call to mergeLocalIds below, which cannot be
 /// handled.
 static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
                                 const PresburgerRelation &s, unsigned i,
                                 PresburgerRelation &result) {
   if (i == s.getNumDisjuncts()) {
     result.unionInPlace(b);
     return;
@@ -156,17 +158,9 @@
   // rollback b to its initial state before returning, which we will do by
   // removing all constraints beyond the original number of inequalities
   // and equalities, so we store these counts first.
-  const IntegerRelation::CountsSnapshot bCounts = b.getCounts();
+  IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
   // Similarly, we also want to rollback simplex to its original state.
-  const unsigned initialSnapshot = simplex.getSnapshot();
-
-  auto restoreState = [&]() {
-    b.truncate(bCounts);
-    simplex.rollback(initialSnapshot);
-  };
-
-  // Automatically restore the original state when we return.
-  auto stateRestorer = llvm::make_scope_exit(restoreState);
+  unsigned initialSnapshot = simplex.getSnapshot();
 
   // Find out which inequalities of sI correspond to division inequalities for
   // the local variables of sI.
@@ -176,31 +170,41 @@
   // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
   // sI's locals.
   b.mergeLocalIds(sI);
+  unsigned numLocalsAdded =
+      b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
+  // Update simplex to also include the new locals in `b` from merging.
+  simplex.appendVariable(numLocalsAdded);
+
+  // Equalities are processed by considering them as a pair of inequalities.
+  // The first sI.getNumInequalities() elements are for sI's inequalities;
+  // then a pair of inequalities occurs for each of sI's equalities.
+  // If the equality is expr == 0, the first element in the pair
+  // corresponds to expr >= 0, and the second to expr <= 0.
+  llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
+                                     2 * sI.getNumEqualities());
 
-  // Mark which inequalities of sI are division inequalities and add all such
-  // inequalities to b.
-  llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+  // Add all division inequalities to `b`.
   for (MaybeLocalRepr &maybeInequality : repr) {
     assert(maybeInequality.kind == ReprKind::Inequality &&
            "Subtraction is not supported when a representation of the local "
            "variables of the subtrahend cannot be found!");
-    auto lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
-    auto ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
+    unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
+    unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
 
     b.addInequality(sI.getInequality(lb));
     b.addInequality(sI.getInequality(ub));
 
     assert(lb != ub &&
            "Upper and lower bounds must be different inequalities!");
-    isDivInequality[lb] = true;
-    isDivInequality[ub] = true;
+
+    // We just added these inequalities to `b`, so there is no point considering
+    // the parts where these inequalities occur complemented -- such parts are
+    // empty. Therefore, we mark that these can be ignored.
+    canIgnoreIneq[lb] = true;
+    canIgnoreIneq[ub] = true;
   }
 
   unsigned offset = simplex.getNumConstraints();
-  unsigned numLocalsAdded =
-      b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds();
-  simplex.appendVariable(numLocalsAdded);
-
   unsigned snapshotBeforeIntersect = simplex.getSnapshot();
   simplex.intersectIntegerRelation(sI);
 
@@ -208,72 +212,69 @@
     // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
     // We are ignoring level i completely, so we restore the state
     // *before* going to level i + 1.
-    restoreState();
+    b.truncate(initBCounts);
+    simplex.rollback(initialSnapshot);
     subtractRecursively(b, simplex, s, i + 1, result);
-
-    // We already restored the state above and the recursive call should have
-    // restored to the same state before returning, so we don't need to restore
-    // the state again.
-    stateRestorer.release();
     return;
   }
 
   simplex.detectRedundant();
 
-  // Equalities are added to simplex as a pair of inequalities.
   unsigned totalNewSimplexInequalities =
       2 * sI.getNumEqualities() + sI.getNumInequalities();
-  llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+  // Redundant inequalities can be safely ignored. This is not required for
+  // correctness but improves performance and results in a more compact
+  // representation of the set difference. Bits are only set here, never
+  // cleared, so that the division inequalities marked above stay ignored.
   for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
-    isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
-
+    if (simplex.isMarkedRedundant(offset + j))
+      canIgnoreIneq[j] = true;
   simplex.rollback(snapshotBeforeIntersect);
 
+  SmallVector<unsigned, 8> ineqsToProcess;
+  ineqsToProcess.reserve(totalNewSimplexInequalities);
+  for (unsigned j = 0; j < totalNewSimplexInequalities; ++j)
+    if (!canIgnoreIneq[j])
+      ineqsToProcess.push_back(j);
+
   // Recurse with the part b ^ ~ineq. Note that b is modified throughout
   // subtractRecursively. At the time this function is called, the current b is
   // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
   // inequality, s_{i,j+1}. This function recurses into the next level i + 1
   // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
   auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
-    SimplexRollbackScopeExit scopeExit(simplex);
     b.addInequality(ineq);
     simplex.addInequality(ineq);
     subtractRecursively(b, simplex, s, i + 1, result);
-    b.removeInequality(b.getNumInequalities() - 1);
   };
 
   // For each inequality ineq, we first recurse with the part where ineq
   // is not satisfied, and then add the ineq to b and simplex because
   // ineq must be satisfied by all later parts.
   auto processInequality = [&](ArrayRef<int64_t> ineq) {
+    // subtractRecursively does not restore `b` and `simplex`, so save their
+    // state here and restore it after recursing on the complement.
+    unsigned snapshot = simplex.getSnapshot();
+    IntegerRelation::CountsSnapshot bCounts = b.getCounts();
     recurseWithInequality(getComplementIneq(ineq));
+    simplex.rollback(snapshot);
+    b.truncate(bCounts);
+
     b.addInequality(ineq);
     simplex.addInequality(ineq);
   };
 
-  // Process all the inequalities, ignoring redundant inequalities and division
-  // inequalities. The result is correct whether or not we ignore these, but
-  // ignoring them makes the result simpler.
-  for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
-    if (isMarkedRedundant[j])
-      continue;
-    if (isDivInequality[j])
-      continue;
-    processInequality(sI.getInequality(j));
-  }
-
-  offset = sI.getNumInequalities();
-  for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
-    ArrayRef<int64_t> coeffs = sI.getEquality(j);
-    // For each equality, process the positive and negative inequalities that
-    // make up this equality. If Simplex found an inequality to be redundant, we
-    // skip it as above to make the result simpler. Divisions are always
-    // represented in terms of inequalities and not equalities, so we do not
-    // check for division inequalities here.
-    if (!isMarkedRedundant[offset + 2 * j])
-      processInequality(coeffs);
-    if (!isMarkedRedundant[offset + 2 * j + 1])
-      processInequality(getNegatedCoeffs(coeffs));
+  for (unsigned idx : ineqsToProcess) {
+    if (idx < sI.getNumInequalities()) {
+      processInequality(sI.getInequality(idx));
+    } else {
+      idx -= sI.getNumInequalities();
+      ArrayRef<int64_t> eqCoeffs = sI.getEquality(idx / 2);
+      if (idx % 2 == 0)
+        processInequality(eqCoeffs);
+      else
+        processInequality(getNegatedCoeffs(eqCoeffs));
+    }
   }
 }
 