diff --git a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp @@ -409,15 +409,33 @@ assert(i != j && "The indices must refer to different polyhedra"); unsigned n = polyhedrons.size(); - polyhedrons[i] = polyhedrons[n - 1]; - polyhedrons[j] = polyhedrons[n - 2]; - polyhedrons.pop_back(); - polyhedrons[n - 2] = poly; - - simplices[i] = simplices[n - 1]; - simplices[j] = simplices[n - 2]; - simplices.pop_back(); - simplices[n - 2] = Simplex(poly); + if (j == n - 1) { + // This case needs special handling since position `n` - 1 is removed from + // the vector, hence the `IntegerPolyhedron` at position `n` - 2 is lost + // otherwise. + polyhedrons[i] = polyhedrons[n - 2]; + polyhedrons.pop_back(); + polyhedrons[n - 2] = poly; + + simplices[i] = simplices[n - 2]; + simplices.pop_back(); + simplices[n - 2] = Simplex(poly); + + } else { + // Other possible edge cases are correct since for `j` or `i` == `n` - 2, + // the `IntegerPolyhedron` at position `n` - 2 should be lost. The case + // `i` == `n` - 1 makes the first following statement a noop. Hence, in this + // case the same thing is done as above, but with `j` rather than `i`. + polyhedrons[i] = polyhedrons[n - 1]; + polyhedrons[j] = polyhedrons[n - 2]; + polyhedrons.pop_back(); + polyhedrons[n - 2] = poly; + + simplices[i] = simplices[n - 1]; + simplices[j] = simplices[n - 2]; + simplices.pop_back(); + simplices[n - 2] = Simplex(poly); + } } /// Given two polyhedra `a` and `b` at positions `i` and `j` in `polyhedrons` diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -645,6 +645,17 @@ expectCoalesce(3, set); } +TEST(SetTest, coalesceLastCoalesced) { + PresburgerSet set = parsePresburgerSetFromPolyStrings( + 1, { + "(x) : (x == 0)", + "(x) : (x - 1 >= 0, -x + 3 >= 0)", + "(x) : (x + 2 == 0)", + "(x) : (x - 2 >= 0, -x + 4 >= 0)", + }); + expectCoalesce(3, set); +} + TEST(SetTest, coalesceDiv) { PresburgerSet set = parsePresburgerSetFromPolyStrings(1, {