diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -236,21 +236,33 @@ // such inequalities to b. llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() + 2 * sI.getNumEqualities()); - for (MaybeLocalRepr &maybeInequality : repr) { + for (MaybeLocalRepr &maybeRepr : repr) { assert( - maybeInequality.kind == ReprKind::Inequality && + maybeRepr && "Subtraction is not supported when a representation of the local " "variables of the subtrahend cannot be found!"); - unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx; - unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx; - b.addInequality(sI.getInequality(lb)); - b.addInequality(sI.getInequality(ub)); - - assert(lb != ub && - "Upper and lower bounds must be different inequalities!"); - canIgnoreIneq[lb] = true; - canIgnoreIneq[ub] = true; + if (maybeRepr.kind == ReprKind::Inequality) { + unsigned lb = maybeRepr.repr.inequalityPair.lowerBoundIdx; + unsigned ub = maybeRepr.repr.inequalityPair.upperBoundIdx; + + b.addInequality(sI.getInequality(lb)); + b.addInequality(sI.getInequality(ub)); + + assert(lb != ub && + "Upper and lower bounds must be different inequalities!"); + canIgnoreIneq[lb] = true; + canIgnoreIneq[ub] = true; + } else { + assert(maybeRepr.kind == ReprKind::Equality && + "ReprKind isn't inequality so should be equality"); + unsigned idx = maybeRepr.repr.equalityIdx; + b.addEquality(sI.getEquality(idx)); + // We can ignore both inequalities corresponding to this equality. + unsigned offset = sI.getNumInequalities(); + canIgnoreIneq[offset + 2 * idx] = true; + canIgnoreIneq[offset + 2 * idx + 1] = true; + } } unsigned offset = simplex.getNumConstraints(); diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -455,6 +455,11 @@ PresburgerSet setA{parsePoly("(x) : (-x >= 0)")}; PresburgerSet setB{parsePoly("(x) : (x floordiv 2 - 4 >= 0)")}; EXPECT_TRUE(setA.subtract(setB).isEqual(setA)); + + IntegerPolyhedron evensDefByEquality(PresburgerSpace::getSetSpace( + /*numDims=*/1, /*numSymbols=*/0, /*numLocals=*/1)); + evensDefByEquality.addEquality({1, -2, 0}); + expectEqual(evens, PresburgerSet(evensDefByEquality)); } TEST(SetTest, subtractDuplicateDivsRegression) {