diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -40,17 +40,30 @@ return std::make_shared(tensorNum); } -static std::shared_ptr -addfPattern(const std::shared_ptr &e0, - const std::shared_ptr &e1) { - return std::make_shared(Kind::kAddF, e0, e1); -} +#define FOREVERY_BINOP(DO) \ + DO(mulf, Kind::kMulF) \ + DO(mulc, Kind::kMulC) \ + DO(muli, Kind::kMulI) \ + DO(andi, Kind::kAndI) \ + DO(addf, Kind::kAddF) \ + DO(addc, Kind::kAddC) \ + DO(addi, Kind::kAddI) \ + DO(subf, Kind::kSubF) \ + DO(subc, Kind::kSubC) \ + DO(subi, Kind::kSubI) \ + DO(ori, Kind::kOrI) \ + DO(xori, Kind::kXorI) + +#define DECLARE_PATTERN(OP, KIND) \ + static std::shared_ptr OP##Pattern( \ + const std::shared_ptr &e0, \ + const std::shared_ptr &e1) { \ + return std::make_shared(KIND, e0, e1); \ + } -static std::shared_ptr -mulfPattern(const std::shared_ptr &e0, - const std::shared_ptr &e1) { - return std::make_shared(Kind::kMulF, e0, e1); -} +FOREVERY_BINOP(DECLARE_PATTERN) + +#undef DECLARE_PATTERN class MergerTestBase : public ::testing::Test { protected: @@ -66,13 +79,12 @@ return merger.addExp(Kind::kTensor, tensor); } - unsigned addf(unsigned e0, unsigned e1) { - return merger.addExp(Kind::kAddF, e0, e1); - } +#define DECLARE_EXP(OP, KIND) \ + unsigned OP(unsigned e0, unsigned e1) { return merger.addExp(KIND, e0, e1); } - unsigned mulf(unsigned e0, unsigned e1) { - return merger.addExp(Kind::kMulF, e0, e1); - } + FOREVERY_BINOP(DECLARE_EXP) + +#undef DECLARE_EXP /// /// Comparison helpers. @@ -87,12 +99,14 @@ /// constraints between lattice points. 
We generally know how contiguous /// groups of lattice points should be ordered with respect to other groups, /// but there is no required ordering within groups.
+ /// If simple is true, then compare the lat.simple field instead, to test the + /// result after optimization. bool latPointWithinRange(unsigned s, unsigned p, unsigned n, const std::shared_ptr &pattern, - const BitVector &bits) { + const BitVector &bits, bool simple) { for (unsigned i = p; i < p + n; ++i) { if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && - compareBits(s, i, bits)) + compareBits(s, i, bits, simple)) return true; } return false; @@ -101,15 +115,15 @@ /// Wrapper over latPointWithinRange for readability of tests. void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, const std::shared_ptr &pattern, - const BitVector &bits) { - EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits)); + const BitVector &bits, bool simple = false) { + EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple)); } /// Wrapper over expectLatPointWithinRange for a single lat point. void expectLatPoint(unsigned s, unsigned p, const std::shared_ptr &pattern, - const BitVector &bits) { - EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits)); + const BitVector &bits, bool simple = false) { + EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple)); } /// Converts a vector of (loop, tensor) pairs to a bitvector with the @@ -126,7 +140,9 @@ } /// Returns true if the bits of lattice point p in set s match the given bits. - bool compareBits(unsigned s, unsigned p, const BitVector &bits) { + bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) { + if (simple) + return merger.lat(merger.set(s)[p]).simple == bits; return merger.lat(merger.set(s)[p]).bits == bits; } @@ -215,6 +231,10 @@ Merger merger; }; +/// +/// Tests with all sparse inputs. +/// + class MergerTest3T1L : public MergerTestBase { protected: // Our three tensors (two inputs, one output). @@ -238,9 +258,63 @@ } }; +class MergerTest4T1L : public MergerTestBase { +protected: + // Our four tensors (three inputs, one output).
+ const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; + + // Our single loop. + const unsigned l0 = 0; + + MergerTest4T1L() : MergerTestBase(4, 1) { + // Tensor 0: sparse input vector. + merger.addExp(Kind::kTensor, t0, -1u); + merger.setDim(t0, l0, Dim::kSparse); + + // Tensor 1: sparse input vector. + merger.addExp(Kind::kTensor, t1, -1u); + merger.setDim(t1, l0, Dim::kSparse); + + // Tensor 2: sparse input vector. + merger.addExp(Kind::kTensor, t2, -1u); + merger.setDim(t2, l0, Dim::kSparse); + + // Tensor 3: dense output vector. + merger.addExp(Kind::kTensor, t3, -1u); + merger.setDim(t3, l0, Dim::kDense); + } +}; + +/// +/// Tests with both sparse and dense input. +/// + +class MergerTest3T1LD : public MergerTestBase { +protected: + // Our three tensors (two inputs, one output). + const unsigned t0 = 0, t1 = 1, t2 = 2; + + // Our single loop. + const unsigned l0 = 0; + + MergerTest3T1LD() : MergerTestBase(3, 1) { + // Tensor 0: sparse input vector. + merger.addExp(Kind::kTensor, t0, -1u); + merger.setDim(t0, l0, Dim::kSparse); + + // Tensor 1: dense input vector. + merger.addExp(Kind::kTensor, t1, -1u); + merger.setDim(t1, l0, Dim::kDense); + + // Tensor 2: dense output vector. + merger.addExp(Kind::kTensor, t2, -1u); + merger.setDim(t2, l0, Dim::kDense); + } +}; + } // namespace -/// Vector addition of 2 vectors, i.e.: +/// Vector addition (disjunction) of 2 vectors, i.e.: /// a(i) = b(i) + c(i) /// which should form the 3 lattice points /// { @@ -248,55 +322,316 @@ /// lat( i_00 / tensor_0 ) /// lat( i_01 / tensor_1 ) /// } -/// and after optimization, will reduce to the 2 lattice points +/// and after optimization, the lattice points do not change (as there is no +/// duplicated point and all input vectors are sparse vectors). /// { /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) /// lat( i_00 / tensor_0 ) +/// lat( i_01 / tensor_1 ) /// } -TEST_F(MergerTest3T1L, VectorAdd2) { - // Construct expression.
- auto e = addf(tensor(t0), tensor(t1)); - - // Build lattices and check. - auto s = merger.buildLattices(e, l0); - expectNumLatPoints(s, 3); - expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), - loopsToBits({{l0, t0}, {l0, t1}})); - expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), - loopsToBits({{l0, t0}})); - expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), - loopsToBits({{l0, t1}})); - - // Optimize lattices and check. - s = merger.optimizeSet(s); - expectNumLatPoints(s, 3); - expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), - loopsToBits({{l0, t0}, {l0, t1}})); - expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), - loopsToBits({{l0, t0}})); - expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), - loopsToBits({{l0, t1}})); -} +#define MERGER_TEST_DISJ(op) \ + TEST_F(MergerTest3T1L, vector_##op) { \ + auto e = op(tensor(t0), tensor(t1)); \ + \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 3); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}})); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), \ + loopsToBits({{l0, t0}})); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), \ + loopsToBits({{l0, t1}})); \ + \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 3); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}}), true); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), \ + loopsToBits({{l0, t0}}), true); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), \ + loopsToBits({{l0, t1}}), true); \ + } + +MERGER_TEST_DISJ(addf) +MERGER_TEST_DISJ(addc) +MERGER_TEST_DISJ(addi) +MERGER_TEST_DISJ(ori) +MERGER_TEST_DISJ(xori) + +// TODO: the pattern for subtraction is different, as it is mapped to a +// negate operation.
+// MERGER_TEST_DISJ(subf) +// MERGER_TEST_DISJ(subc) +// MERGER_TEST_DISJ(subi) +#undef MERGER_TEST_DISJ -/// Vector multiplication of 2 vectors, i.e.: +/// Vector multiplication (conjunction) of 2 vectors, i.e.: /// a(i) = b(i) * c(i) /// which should form the single lattice point /// { /// lat( i_00 i_01 / (tensor_0 * tensor_1) ) /// } -TEST_F(MergerTest3T1L, VectorMul2) { - // Construct expression. - auto e = mulf(t0, t1); - - // Build lattices and check. - auto s = merger.buildLattices(e, l0); - expectNumLatPoints(s, 1); - expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), - loopsToBits({{l0, t0}, {l0, t1}})); - - // Optimize lattices and check. - s = merger.optimizeSet(s); - expectNumLatPoints(s, 1); - expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), - loopsToBits({{l0, t0}, {l0, t1}})); -} +#define MERGER_TEST_CONJ(op) \ + TEST_F(MergerTest3T1L, vector_##op) { \ + auto e = op(t0, t1); \ + \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}})); \ + \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}}), true); \ + } + +// TODO: divisions are not tested for now, as they need a constant non-zero divisor.
+MERGER_TEST_CONJ(mulf) +MERGER_TEST_CONJ(mulc) +MERGER_TEST_CONJ(muli) +MERGER_TEST_CONJ(andi) + +#undef MERGER_TEST_CONJ + +/// Vector multiplication (conjunction) then addition (disjunction), i.e.: +/// a(i) = b(i) * c(i) + d(i); +/// which should form +/// { +/// lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 ) +/// lat( i_00 i_01 / (tensor_0 * tensor_1) ) +/// lat( i_02 / tensor_2 ) +/// } + +#define MERGER_TEST_CONJ_DISJ(conj, disj) \ + TEST_F(MergerTest4T1L, vector_##conj##_##disj) { \ + auto em = conj(t0, t1); \ + auto e = disj(em, t2); \ + \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 3); \ + expectLatPoint( \ + s, lat(0), \ + disj##Pattern(conj##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ + expectLatPointWithinRange( \ + s, lat(1), 2, conj##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}})); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t2), \ + loopsToBits({{l0, t2}})); \ + \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 3); \ + expectLatPoint( \ + s, lat(0), \ + disj##Pattern(conj##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \ + expectLatPointWithinRange( \ + s, lat(1), 2, conj##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}}), true); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t2), \ + loopsToBits({{l0, t2}}), true); \ + } + +MERGER_TEST_CONJ_DISJ(mulf, addf) +MERGER_TEST_CONJ_DISJ(andi, addf) +MERGER_TEST_CONJ_DISJ(mulf, ori) +MERGER_TEST_CONJ_DISJ(andi, ori) +MERGER_TEST_CONJ_DISJ(mulf, xori) +MERGER_TEST_CONJ_DISJ(andi, xori) + +#undef MERGER_TEST_CONJ_DISJ + +/// Vector addition (disjunction) then addition (disjunction), i.e.: +/// a(i) = b(i) + c(i) + d(i) +/// which should form +/// { +/// lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 ) +/// lat( i_02 i_01 / tensor_2
 + tensor_1 ) +/// lat( i_02 i_00 / tensor_2 + tensor_0 ) +/// lat( i_01 i_00 / tensor_1 + tensor_0 ) +/// lat( i_02 / tensor_2 ) +/// lat( i_01 / tensor_1 ) +/// lat( i_00 / tensor_0 ) +/// } +#define MERGER_TEST_DISJ_DISJ(op1, op2) \ + TEST_F(MergerTest4T1L, vector_##op1##_##op2) { \ + auto em = op1(t0, t1); \ + auto e = op2(em, t2); \ + \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 7); \ + expectLatPoint( \ + s, lat(0), \ + op2##Pattern(op1##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ + expectLatPointWithinRange( \ + s, lat(1), 6, op2##Pattern(tensorPattern(t1), tensorPattern(t2)), \ + loopsToBits({{l0, t1}, {l0, t2}})); \ + expectLatPointWithinRange( \ + s, lat(1), 6, op2##Pattern(tensorPattern(t0), tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t2}})); \ + expectLatPointWithinRange( \ + s, lat(1), 6, op1##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}})); \ + expectLatPointWithinRange(s, lat(1), 6, tensorPattern(t2), \ + loopsToBits({{l0, t2}})); \ + expectLatPointWithinRange(s, lat(1), 6, tensorPattern(t1), \ + loopsToBits({{l0, t1}})); \ + expectLatPointWithinRange(s, lat(1), 6, tensorPattern(t0), \ + loopsToBits({{l0, t0}})); \ + \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 7); \ + expectLatPoint( \ + s, lat(0), \ + op2##Pattern(op1##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \ + expectLatPointWithinRange( \ + s, lat(1), 6, op2##Pattern(tensorPattern(t1), tensorPattern(t2)), \ + loopsToBits({{l0, t1}, {l0, t2}}), true); \ + expectLatPointWithinRange( \ + s, lat(1), 6, op2##Pattern(tensorPattern(t0), tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t2}}), true); \ + expectLatPointWithinRange( \ + s, lat(1), 6, op1##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}}), true); \ + expectLatPointWithinRange(s, lat(1),
 6, tensorPattern(t2), \ + loopsToBits({{l0, t2}}), true); \ + expectLatPointWithinRange(s, lat(1), 6, tensorPattern(t1), \ + loopsToBits({{l0, t1}}), true); \ + expectLatPointWithinRange(s, lat(1), 6, tensorPattern(t0), \ + loopsToBits({{l0, t0}}), true); \ + } + +MERGER_TEST_DISJ_DISJ(addf, addf) +MERGER_TEST_DISJ_DISJ(ori, ori) +MERGER_TEST_DISJ_DISJ(xori, xori) +MERGER_TEST_DISJ_DISJ(ori, xori) + +#undef MERGER_TEST_DISJ_DISJ + +/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.: +/// a(i) = b(i) * c(i) * d(i); +/// which should form +/// { +/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 ) +/// } +#define MERGER_TEST_CONJ_CONJ(op1, op2) \ + TEST_F(MergerTest4T1L, vector_##op1##_##op2) { \ + auto em = op1(t0, t1); \ + auto e = op2(em, t2); \ + \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 1); \ + expectLatPoint( \ + s, lat(0), \ + op2##Pattern(op1##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 1); \ + expectLatPoint( \ + s, lat(0), \ + op2##Pattern(op1##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + tensorPattern(t2)), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \ + } + +MERGER_TEST_CONJ_CONJ(mulf, mulf) +MERGER_TEST_CONJ_CONJ(mulf, andi) +MERGER_TEST_CONJ_CONJ(andi, andi) + +#undef MERGER_TEST_CONJ_CONJ + +/// Vector addition (disjunction) of 2 vectors, i.e.: +/// a(i) = b(i) + c(i) +/// which should form the 3 lattice points +/// { +/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) +/// lat( i_00 / sparse_tensor_0 ) +/// lat( i_01 / dense_tensor_1 ) +/// } +/// which should be optimized to +/// { +/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton) +/// lat( i_01 / dense_tensor_1 ) (no sparse dimension) +/// } +/// +/// lat( i_00 / sparse_tensor_0 ) should be opted out, as it only has a dense diff +/// with lat( i_00 i_01 / (sparse_tensor_0 +
 dense_tensor_1) ). +#define MERGER_TEST_OPTED_DISJ(op) \ + TEST_F(MergerTest3T1LD, vector_opted_##op) { \ + auto e = op(tensor(t0), tensor(t1)); \ + \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 3); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}})); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), \ + loopsToBits({{l0, t0}})); \ + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), \ + loopsToBits({{l0, t1}})); \ + \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 2); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}}), true); \ + expectLatPoint(s, lat(1), tensorPattern(t1), loopsToBits({{l0, t1}}), \ + true); \ + } + +MERGER_TEST_OPTED_DISJ(addf) +MERGER_TEST_OPTED_DISJ(addc) +MERGER_TEST_OPTED_DISJ(addi) +MERGER_TEST_OPTED_DISJ(ori) +MERGER_TEST_OPTED_DISJ(xori) + +#undef MERGER_TEST_OPTED_DISJ + +/// Vector multiplication (conjunction) of 2 vectors, i.e.: +/// a(i) = b(i) * c(i) +/// which should form the single lattice point +/// { +/// lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) ) +/// } +/// which should be optimized to +/// { +/// lat( i_00 / (sparse_tensor_0 * dense_tensor_1) ) +/// } +/// since i_01 is a dense dimension.
+#define MERGER_TEST_OPTED_CONJ(op) \ + TEST_F(MergerTest3T1LD, vector_opted_##op) { \ + auto e = op(t0, t1); \ + \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}, {l0, t1}})); \ + \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), \ + op##Pattern(tensorPattern(t0), tensorPattern(t1)), \ + loopsToBits({{l0, t0}}), true); \ + } + +MERGER_TEST_OPTED_CONJ(mulf) +MERGER_TEST_OPTED_CONJ(mulc) +MERGER_TEST_OPTED_CONJ(muli) +MERGER_TEST_OPTED_CONJ(andi) + +#undef MERGER_TEST_OPTED_CONJ + +// TODO: multi-dimensional tests