diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -645,69 +645,31 @@ /// Check if the expected division representation of local variables matches the /// computed representation. The expected division representation is given as -/// a vector of expressions set in `divisions` and the corressponding -/// denominator in `denoms`. If expected denominator for a variable is -/// non-positive, the local variable is expected to not have a computed -/// representation. +/// a vector of expressions set in `expectedDividends` and the corresponding +/// denominators in `expectedDenominators`. The `denominators` and `dividends` +/// obtained through the `getLocalReprs` function are verified against +/// `expectedDenominators` and `expectedDividends` respectively. static void checkDivisionRepresentation( FlatAffineConstraints &fac, - const std::vector> &divisions, - const SmallVector &denoms) { - - assert(divisions.size() == fac.getNumLocalIds() && - "Size of expected divisions does not match number of local variables"); - assert( - denoms.size() == fac.getNumLocalIds() && - "Size of expected denominators does not match number of local variables"); - - std::vector>> res( - fac.getNumLocalIds(), llvm::None); - fac.getLocalReprs(res); - - // Check if all expected divisions are computed. 
- for (unsigned i = 0, e = fac.getNumLocalIds(); i < e; ++i) - if (denoms[i] > 0) - EXPECT_TRUE(res[i].hasValue()); - else - EXPECT_FALSE(res[i].hasValue()); - - unsigned divOffset = fac.getNumDimAndSymbolIds(); - for (unsigned i = 0, e = fac.getNumLocalIds(); i < e; ++i) { - if (!res[i]) - continue; - - // Check if the bounds are of the form: - // 0 <= expr - divisor * id <= divisor - 1 - // Rearranging, we have: - // divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id' - // -divisor * id + expr >= 0 <-- Upper bound for 'id' - // where `id = expr floordiv divisor`. - unsigned ubPos = res[i]->first, lbPos = res[i]->second; - const SmallVector &expr = divisions[i]; - - // Check if lower bound is of the correct form. - int64_t computedDivisorLb = fac.atIneq(lbPos, i + divOffset); - EXPECT_EQ(computedDivisorLb, denoms[i]); - for (unsigned c = 0, f = fac.getNumLocalIds(); c < f; ++c) { - if (c == i + divOffset) - continue; - EXPECT_EQ(fac.atIneq(lbPos, c), -expr[c]); - } - // Check if constant term of lower bound matches expected constant term. - EXPECT_EQ(fac.atIneq(lbPos, fac.getNumCols() - 1), - -expr.back() + (denoms[i] - 1)); - - // Check if upper bound is of the correct form. - int64_t computedDivisorUb = fac.atIneq(ubPos, i + divOffset); - EXPECT_EQ(computedDivisorUb, -denoms[i]); - for (unsigned c = 0, f = fac.getNumLocalIds(); c < f; ++c) { - if (c == i + divOffset) - continue; - EXPECT_EQ(fac.atIneq(ubPos, c), expr[c]); - } - // Check if constant term of upper bound matches expected constant term. 
- EXPECT_EQ(fac.atIneq(ubPos, fac.getNumCols() - 1), expr.back()); - } + const std::vector> &expectedDividends, + const SmallVectorImpl &expectedDenominators) { + + std::vector> dividends; + SmallVector denominators; + + fac.getLocalReprs(dividends, denominators); + + assert(expectedDividends.size() == dividends.size() && + "Size of expected dividends does not match the calculated dividends"); + assert(expectedDenominators.size() == denominators.size() && + "Size of expected denominators does not match the calculated " + "denominators"); + + // Check that the computed `dividends` match `expectedDividends`. + EXPECT_TRUE(expectedDividends == dividends); + + // Check that the computed `denominators` match `expectedDenominators`. + EXPECT_TRUE(expectedDenominators == denominators); } TEST(FlatAffineConstraintsTest, computeLocalReprSimple) { @@ -718,7 +680,7 @@ std::vector> divisions = {{1, 0, 0, 4}, {1, 0, 0, 100}}; - SmallVector denoms = {10, 10}; + SmallVector denoms = {10, 10}; // Check if floordivs can be computed when no other inequalities exist // and floor divs do not depend on each other. @@ -737,7 +699,7 @@ std::vector> divisions = {{0, 0, 0, 0, 0, 0, 10}, {0, 0, 0, 0, 0, 0, 99}}; - SmallVector denoms = {30, 101}; + SmallVector denoms = {30, 101}; // Check if floordivs with constant numerator can be computed. checkDivisionRepresentation(fac, divisions, denoms); @@ -756,10 +718,12 @@ fac.addInequality({1, 2, -2, 1, -5, 0, 6, 100}); fac.addInequality({1, 2, -8, 1, 3, 7, 0, -9}); - std::vector> divisions = {{0, -2, 7, 2, 0, 0, 0, 10}, - {3, 0, 9, 2, 2, 0, 0, 10}, - {0, 1, -123, 2, 0, -4, 10}}; - SmallVector denoms = {3, 5, 3}; + std::vector> divisions = { + {0, -2, 7, 2, 0, 0, 0, 10}, + {3, 0, 9, 2, 2, 0, 0, 10}, + {0, 1, -123, 2, 0, -4, 0, 10}}; + + SmallVector denoms = {3, 5, 3}; // Check if floordivs which may depend on other floordivs can be computed. checkDivisionRepresentation(fac, divisions, denoms);