diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -80,27 +80,39 @@ /// Helper classes/functions for testing Merger. /// -/// Simple recursive data structure used to match expressions in Mergers. +/// Simple recursive data structure used to match expressions in `Merger`. +struct Pattern; +/// Since the patterns we need are rather small and short-lived, we use +/// `Pattern const&` for "pointers" to patterns, rather than using +/// something more elaborate like `std::shared_ptr const&`. +/// (But since we use a typedef rather than spelling it out everywhere, +/// that's easy enough to swap out if we need something more elaborate +/// in the future.) +using PatternRef = const Pattern &; struct Pattern { + struct Children { + Children(PatternRef e0, PatternRef e1) : e0(e0), e1(e1) {} + PatternRef e0; + PatternRef e1; + }; + TensorExp::Kind kind; - /// Expressions representing tensors simply have a tensor number. - unsigned tensorNum; + union { + /// Expressions representing tensors simply have a tensor number. + TensorId tid; - /// Tensor operations point to their children. - std::shared_ptr e0; - std::shared_ptr e1; + /// Tensor operations point to their children. + Children children; + }; /// Constructors. /// Rather than using these, please use the readable helper constructor /// functions below to make tests more readable. 
- Pattern(unsigned tensorNum) - : kind(TensorExp::Kind::kTensor), tensorNum(tensorNum) {} - Pattern(TensorExp::Kind kind, const std::shared_ptr &e0, - const std::shared_ptr &e1) - : kind(kind), e0(e0), e1(e1) { + explicit Pattern(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {} + Pattern(TensorExp::Kind kind, PatternRef e0, PatternRef e1) + : kind(kind), children(e0, e1) { assert(kind >= TensorExp::Kind::kMulF); - assert(e0 && e1); } }; @@ -109,15 +121,12 @@ /// These should be preferred over the actual constructors. /// -static std::shared_ptr tensorPattern(unsigned tensorNum) { - return std::make_shared(tensorNum); -} +static Pattern tensorPattern(TensorId tid) { return Pattern(tid); } #define IMPL_BINOP_PATTERN(OP, KIND) \ - LLVM_ATTRIBUTE_UNUSED static std::shared_ptr OP##Pattern( \ - const std::shared_ptr &e0, \ - const std::shared_ptr &e1) { \ - return std::make_shared(KIND, e0, e1); \ + LLVM_ATTRIBUTE_UNUSED static Pattern OP##Pattern(PatternRef e0, \ + PatternRef e1) { \ + return Pattern(KIND, e0, e1); \ } FOREVERY_BINOP(IMPL_BINOP_PATTERN) @@ -127,19 +136,31 @@ class MergerTestBase : public ::testing::Test { protected: MergerTestBase(unsigned numTensors, unsigned numLoops) - : numTensors(numTensors), numLoops(numLoops), - merger(numTensors, numLoops, /*numFilterLoops=*/0) {} + : merger(numTensors, numLoops, /*numFilterLoops=*/0) { + tensors.reserve(numTensors); + for (unsigned t = 0; t < numTensors; t++) + tensors.push_back(merger.addExp(TensorExp::Kind::kTensor, tid(t))); + } /// /// Expression construction helpers.
/// - unsigned tensor(unsigned tensor) { - return merger.addExp(TensorExp::Kind::kTensor, tensor); + TensorId tid(unsigned t) const { + assert(t < merger.getNumTensors()); + return t; + } + LoopId lid(unsigned i) const { + assert(i < merger.getNumLoops()); + return i; + } + ExprId tensor(unsigned t) const { + assert(t < tensors.size()); + return tensors[t]; } #define IMPL_BINOP_EXPR(OP, KIND) \ - LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) { \ + LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \ return merger.addExp(KIND, e0, e1); \ } @@ -151,83 +172,77 @@ /// Comparison helpers. /// - /// For readability of tests. - unsigned lat(unsigned lat) { return lat; } - - /// Returns true if a lattice point with an expression matching the given - /// pattern and bits matching the given bits is present in lattice points - /// [p, p+n) of lattice set s. This is useful for testing partial ordering - /// constraints between lattice points. We generally know how contiguous - /// groups of lattice points should be ordered with respect to other groups, - /// but there is no required ordering within groups. - /// If simple is true, then compare the lat.simple field instead to test the - /// result after optimization - bool latPointWithinRange(unsigned s, unsigned p, unsigned n, - const std::shared_ptr &pattern, - const BitVector &bits, bool simple) { - for (unsigned i = p; i < p + n; ++i) { - if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && - compareBits(s, i, bits, simple)) + /// Returns true if any lattice point with an expression matching + /// the given `pattern` and bits matching the given `bits` is present + /// in the `[lo, lo+n)` slice of the lattice set `s`. This is useful + /// for testing partial ordering constraints between lattice points. + /// We generally know how contiguous groups of lattice points should + /// be ordered with respect to other groups, but there is no required + /// ordering within groups. 
If `simple` is true, then compare the + /// `lat.simple` field instead to test the result after optimization. + bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n, + PatternRef pattern, const BitVector &bits, + bool simple) { + for (unsigned k = lo, hi = lo + n; k < hi; ++k) { + if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) && + compareBits(s, k, bits, simple)) return true; } return false; } /// Wrapper over latPointWithinRange for readability of tests. - void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, - const std::shared_ptr &pattern, - const BitVector &bits, bool simple = false) { - EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple)); + void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n, + PatternRef pattern, const BitVector &bits, + bool simple = false) { + EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple)); } /// Wrapper over expectLatPointWithinRange for a single lat point. - void expectLatPoint(unsigned s, unsigned p, - const std::shared_ptr &pattern, + void expectLatPoint(LatSetId s, unsigned lo, PatternRef pattern, const BitVector &bits, bool simple = false) { - EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple)); + EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple)); } /// Converts a vector of (loop, tensor) pairs to a bitvector with the /// corresponding bits set. - BitVector - loopsToBits(const std::vector> &loops) { - BitVector testBits = BitVector(numTensors + 1, false); - for (auto l : loops) { - auto loop = std::get<0>(l); - auto tensor = std::get<1>(l); + BitVector loopsToBits(const std::vector> &loops) { + // NOTE: this `numTensors` includes both the output- and synthetic-tensors. 
+ const auto numTensors = merger.getNumTensors(); + BitVector testBits = BitVector(numTensors * merger.getNumLoops(), false); + for (auto [loop, tensor] : loops) testBits.set(numTensors * loop + tensor); - } return testBits; } - /// Returns true if the bits of lattice point p in set s match the given bits. - /// If simple is true, then compare the lat.simple field instead to test the - /// result after optimization - bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) { - if (simple) - return merger.lat(merger.set(s)[p]).simple == bits; - return merger.lat(merger.set(s)[p]).bits == bits; + /// Returns true if the bits of the `k`th point in set `s` match + /// the given `bits`. If `simple` is true, then compares the `lat.simple` + /// field instead, to test the result after optimization. + bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) { + const auto &point = merger.lat(merger.set(s)[k]); + return (simple ? point.simple : point.bits) == bits; } /// Check that there are n lattice points in set s. - void expectNumLatPoints(unsigned s, unsigned n) { + void expectNumLatPoints(LatSetId s, unsigned n) { EXPECT_THAT(merger.set(s).size(), n); } /// Compares expressions for equality. Equality is defined recursively as: /// - Operations are equal if they have the same kind and children. /// - Leaf tensors are equal if they refer to the same tensor. - bool compareExpression(unsigned e, const std::shared_ptr &pattern) { - auto tensorExp = merger.exp(e); - if (tensorExp.kind != pattern->kind) + bool compareExpression(ExprId e, PatternRef pattern) { + const auto &tensorExp = merger.exp(e); + if (tensorExp.kind != pattern.kind) return false; switch (tensorExp.kind) { // Leaf.
case TensorExp::Kind::kTensor: - return tensorExp.tensor == pattern->tensorNum; + return tensorExp.tensor == pattern.tid; case TensorExp::Kind::kInvariant: - case TensorExp::Kind::kLoopVar: llvm_unreachable("invariant not handled yet"); + case TensorExp::Kind::kLoopVar: + llvm_unreachable("loop-variables not handled yet"); // Unary operations. case TensorExp::Kind::kAbsF: case TensorExp::Kind::kAbsC: @@ -263,7 +278,7 @@ case TensorExp::Kind::kSelect: case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kUnary: - return compareExpression(tensorExp.children.e0, pattern->e0); + return compareExpression(tensorExp.children.e0, pattern.children.e0); // Binary operations. case TensorExp::Kind::kMulF: case TensorExp::Kind::kMulC: @@ -286,72 +301,51 @@ case TensorExp::Kind::kShlI: case TensorExp::Kind::kBinary: case TensorExp::Kind::kReduce: - return compareExpression(tensorExp.children.e0, pattern->e0) && - compareExpression(tensorExp.children.e1, pattern->e1); + return compareExpression(tensorExp.children.e0, pattern.children.e0) && + compareExpression(tensorExp.children.e1, pattern.children.e1); } llvm_unreachable("unexpected kind"); } - unsigned numTensors; - unsigned numLoops; + // This field is public for convenience. Merger merger; + +private: + // This field is private to prevent mutation after the ctor. + SmallVector tensors; }; /// /// Tests with all sparse inputs. /// +/// Three tensors (two inputs, one output); and a single loop. class MergerTest3T1L : public MergerTestBase { protected: - // Our three tensors (two inputs, one output). - const unsigned t0 = 0, t1 = 1, t2 = 2; - - // Our single loop. - const unsigned l0 = 0; - MergerTest3T1L() : MergerTestBase(3, 1) { - EXPECT_TRUE(merger.getOutTensorID() == t2); - + EXPECT_TRUE(merger.getOutTensorID() == tid(2)); // Tensor 0: sparse input vector. 
- merger.addExp(TensorExp::Kind::kTensor, t0, -1u); - merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed); - + merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed); // Tensor 1: sparse input vector. - merger.addExp(TensorExp::Kind::kTensor, t1, -1u); - merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed); - + merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed); // Tensor 2: dense output vector. - merger.addExp(TensorExp::Kind::kTensor, t2, -1u); - merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense); + merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense); } }; +/// Four tensors (three inputs, one output); and a single loop. class MergerTest4T1L : public MergerTestBase { protected: - // Our four tensors (three inputs, one output). - const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; - - // Our single loop. - const unsigned l0 = 0; - MergerTest4T1L() : MergerTestBase(4, 1) { - EXPECT_TRUE(merger.getOutTensorID() == t3); - + EXPECT_TRUE(merger.getOutTensorID() == tid(3)); // Tensor 0: sparse input vector. - merger.addExp(TensorExp::Kind::kTensor, t0, -1u); - merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed); - + merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed); // Tensor 1: sparse input vector. - merger.addExp(TensorExp::Kind::kTensor, t1, -1u); - merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed); - + merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed); // Tensor 2: sparse input vector - merger.addExp(TensorExp::Kind::kTensor, t2, -1u); - merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed); - + merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed); // Tensor 3: dense output vector - merger.addExp(TensorExp::Kind::kTensor, t3, -1u); - merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense); + merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense); } }; @@ -359,28 +353,17 @@ /// Tests with both sparse and dense input. 
/// +/// Three tensors (two inputs, one output); and a single loop. class MergerTest3T1LD : public MergerTestBase { protected: - // Our three tensors (two inputs, one output). - const unsigned t0 = 0, t1 = 1, t2 = 2; - - // Our single loop. - const unsigned l0 = 0; - MergerTest3T1LD() : MergerTestBase(3, 1) { - EXPECT_TRUE(merger.getOutTensorID() == t2); - + EXPECT_TRUE(merger.getOutTensorID() == tid(2)); // Tensor 0: sparse input vector. - merger.addExp(TensorExp::Kind::kTensor, t0, -1u); - merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed); - + merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed); // Tensor 1: dense input vector. - merger.addExp(TensorExp::Kind::kTensor, t1, -1u); - merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense); - + merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Dense); // Tensor 2: dense output vector. - merger.addExp(TensorExp::Kind::kTensor, t2, -1u); - merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense); + merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense); } }; @@ -388,32 +371,19 @@ /// Tests with both undef and dense input. /// +/// Four tensors (three inputs, one output); and a single loop. class MergerTest4T1LU : public MergerTestBase { protected: - // Our three tensors (three inputs, one output). - const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; - - // Our single loop. - const unsigned l0 = 0; - MergerTest4T1LU() : MergerTestBase(4, 1) { - EXPECT_TRUE(merger.getOutTensorID() == t3); - + EXPECT_TRUE(merger.getOutTensorID() == tid(3)); // Tensor 0: undef input vector.
- merger.addExp(TensorExp::Kind::kTensor, t2, -1u); - merger.setLevelAndType(t2, l0, 0, DimLevelType::Undef); - + merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Undef); // Tensor 3: dense output vector. - merger.addExp(TensorExp::Kind::kTensor, t3, -1u); - merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense); + merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense); } }; @@ -421,31 +391,19 @@ /// Tests with operation on sparse output. /// +/// Three tensors (two inputs, one output, one synthetic); and a single loop. class MergerTest3T1LSo : public MergerTestBase { protected: - // Our three tensors (two inputs, one output, one synthetic). - const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; - - // Our single loop. - const unsigned l0 = 0; - MergerTest3T1LSo() : MergerTestBase(3, 1) { - EXPECT_TRUE(merger.getOutTensorID() == t2); - EXPECT_TRUE(merger.getSynTensorID() == t3); - + EXPECT_TRUE(merger.getOutTensorID() == tid(2)); + EXPECT_TRUE(merger.getSynTensorID() == tid(3)); merger.setHasSparseOut(true); - // Tensor 0: undef input vector. - merger.addExp(TensorExp::Kind::kTensor, t0, -1u); - merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef); - + merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Undef); // Tensor 1: undef input vector. - merger.addExp(TensorExp::Kind::kTensor, t1, -1u); - merger.setLevelAndType(t1, l0, 0, DimLevelType::Undef); - + merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Undef); // Tensor 2: sparse output vector. 
- merger.addExp(TensorExp::Kind::kTensor, t2, -1u); - merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed); + merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed); } }; @@ -464,18 +422,22 @@ /// } #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \ TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \ - auto em = CONJ1##Expr(t0, t1); \ - auto e = CONJ2##Expr(em, t2); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ - auto p2 = tensorPattern(t2); \ + const auto em = CONJ1##Expr(tensor(0), tensor(1)); \ + const auto e = CONJ2##Expr(em, tensor(2)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const auto t2 = tid(2); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ + const PatternRef p2 = tensorPattern(t2); \ auto s = merger.buildLattices(e, l0); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t1}}), true); \ } @@ -496,18 +458,23 @@ /// } #define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \ TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \ - auto em = CONJ1##Expr(t0, t1); \ - auto e = CONJ2##Expr(em, t2); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ - auto p2 = tensorPattern(t2); \ + const auto em = CONJ1##Expr(tensor(0), tensor(1)); \ + const auto e = CONJ2##Expr(em, tensor(2)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const auto t2 = tid(2); \ + const auto t3 = tid(3); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ + const 
PatternRef p2 = tensorPattern(t2); \ auto s = merger.buildLattices(e, l0); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t3}}), true); \ } @@ -532,25 +499,26 @@ /// } #define IMPL_MERGER_TEST_DISJ(OP) \ TEST_F(MergerTest3T1L, vector_##OP) { \ - auto e = OP##Expr(tensor(t0), tensor(t1)); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ + const auto e = OP##Expr(tensor(0), tensor(1)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ auto s = merger.buildLattices(e, l0); \ \ expectNumLatPoints(s, 3); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \ - expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \ + expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \ + expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \ \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 3); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}}), true); \ - expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \ - true); \ - expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \ - true); \ + expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true); \ + expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true); \ } 
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ) @@ -565,18 +533,21 @@ /// } #define IMPL_MERGER_TEST_CONJ(OP) \ TEST_F(MergerTest3T1L, vector_##OP) { \ - auto e = OP##Expr(t0, t1); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ + const auto e = OP##Expr(tensor(0), tensor(1)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ auto s = merger.buildLattices(e, l0); \ \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}}), true); \ } @@ -594,27 +565,31 @@ /// } #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \ TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \ - auto em = CONJ##Expr(t0, t1); \ - auto e = DISJ##Expr(em, t2); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ - auto p2 = tensorPattern(t2); \ + const auto em = CONJ##Expr(tensor(0), tensor(1)); \ + const auto e = DISJ##Expr(em, tensor(2)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const auto t2 = tid(2); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ + const PatternRef p2 = tensorPattern(t2); \ auto s = merger.buildLattices(e, l0); \ \ expectNumLatPoints(s, 3); \ - expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \ + expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 2, p2, 
loopsToBits({{l0, t2}})); \ + expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \ \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 3); \ - expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \ + expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \ + expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \ } FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ) @@ -635,39 +610,43 @@ /// } #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \ TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \ - auto em = DISJ1##Expr(t0, t1); \ - auto e = DISJ2##Expr(em, t2); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ - auto p2 = tensorPattern(t2); \ + const auto em = DISJ1##Expr(tensor(0), tensor(1)); \ + const auto e = DISJ2##Expr(em, tensor(2)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const auto t2 = tid(2); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ + const PatternRef p2 = tensorPattern(t2); \ auto s = merger.buildLattices(e, l0); \ \ expectNumLatPoints(s, 7); \ - expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \ + expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \ loopsToBits({{l0, t1}, {l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \ + expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \ loopsToBits({{l0, t0}, {l0, t2}})); \ - 
expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \ + expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \ + expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \ + expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \ + expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \ \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 7); \ - expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \ + expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \ loopsToBits({{l0, t1}, {l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \ + expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \ loopsToBits({{l0, t0}, {l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \ + expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \ - expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \ + expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \ + expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \ + expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \ } FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ) @@ -682,18 +661,22 @@ /// } #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \ TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \ - auto em = 
CONJ1##Expr(t0, t1); \ - auto e = CONJ2##Expr(em, t2); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ - auto p2 = tensorPattern(t2); \ + const auto em = CONJ1##Expr(tensor(0), tensor(1)); \ + const auto e = CONJ2##Expr(em, tensor(2)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const auto t2 = tid(2); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ + const PatternRef p2 = tensorPattern(t2); \ auto s = merger.buildLattices(e, l0); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \ } @@ -719,22 +702,25 @@ /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ). 
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \ TEST_F(MergerTest3T1LD, vector_opted_##OP) { \ - auto e = OP##Expr(tensor(t0), tensor(t1)); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ + const auto e = OP##Expr(tensor(0), tensor(1)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ auto s = merger.buildLattices(e, l0); \ \ expectNumLatPoints(s, 3); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ - expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \ - expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \ + expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \ + expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \ \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 2); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}}), true); \ - expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \ + expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \ } FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ) @@ -754,19 +740,21 @@ /// since i_01 is a dense dimension. 
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP) \ TEST_F(MergerTest3T1LD, vector_opted_##OP) { \ - auto e = OP##Expr(t0, t1); \ - auto p0 = tensorPattern(t0); \ - auto p1 = tensorPattern(t1); \ + const auto e = OP##Expr(tensor(0), tensor(1)); \ + const auto l0 = lid(0); \ + const auto t0 = tid(0); \ + const auto t1 = tid(1); \ + const PatternRef p0 = tensorPattern(t0); \ + const PatternRef p1 = tensorPattern(t1); \ auto s = merger.buildLattices(e, l0); \ \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), \ loopsToBits({{l0, t0}, {l0, t1}})); \ \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), \ - true); \ + expectLatPoint(s, 0, OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), true); \ } FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)