diff --git a/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRSparseTensorTests
+  MergerTest.cpp
+)
+target_link_libraries(MLIRSparseTensorTests
+  PRIVATE
+  MLIRSparseTensorUtils
+)
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -0,0 +1,192 @@
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <utility>
+#include <vector>
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+const unsigned NUM_TENSORS = 2, NUM_LOOPS = 1;
+
+/// Test fixture owning a Merger over NUM_TENSORS tensors and NUM_LOOPS loops.
+/// The constructor registers one tensor expression per tensor and declares
+/// both tensors as sparse vectors.
+class MergerTest : public ::testing::Test {
+protected:
+  MergerTest() : merger(NUM_TENSORS, NUM_LOOPS) {
+
+    // Tensor 0: sparse vector.
+    merger.addExp(Kind::kTensor, 0, -1u);
+    merger.setDim(0, 0, Dim::kSparse);
+
+    // Tensor 1: sparse vector.
+    merger.addExp(Kind::kTensor, 1, -1u);
+    merger.setDim(1, 0, Dim::kSparse);
+  }
+
+  ///
+  /// Expression construction helpers.
+  ///
+
+  /// Returns the expression id for tensor t. The constructor adds the tensor
+  /// expressions first and in tensor order, so the id equals t (this assumes
+  /// addExp hands out ids sequentially from 0 -- confirm against Merger).
+  unsigned tensor(unsigned t) { return t; }
+
+  /// Adds a kAddF expression with operands e0 and e1; returns its id.
+  unsigned addf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kAddF, e0, e1);
+  }
+
+  /// Adds a kMulF expression with operands e0 and e1; returns its id.
+  unsigned mulf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kMulF, e0, e1);
+  }
+
+  ///
+  /// Comparison helpers.
+  ///
+
+  /// Checks that the expression of lat point p in set s is structurally
+  /// equal (see compareExpression) to the expression e.
+  void expectLatPointExpressionEquals(unsigned s, unsigned p, unsigned e) {
+    EXPECT_TRUE(compareExpression(merger.lat(merger.set(s)[p]).exp, e));
+  }
+
+  /// Checks that the bits of lat point p in set s are exactly the bits
+  /// indicated by the (loop, tensor) pairs in loops.
+  ///
+  /// NOTE(review): the layout assumed here (NUM_TENSORS + 1 bits, bit index
+  /// NUM_TENSORS * loop + tensor) must mirror Merger's internal bit layout,
+  /// and has only been exercised with NUM_LOOPS == 1; confirm it against
+  /// Merger before adding multi-loop tests.
+  void expectLatPointBitsEqual(
+      unsigned s, unsigned p,
+      const std::vector<std::pair<unsigned, unsigned>> &loops) {
+    llvm::BitVector testBits(NUM_TENSORS + 1, false);
+    for (const auto &l : loops) {
+      unsigned loop = l.first;
+      unsigned tensorId = l.second;
+      testBits.set(NUM_TENSORS * loop + tensorId);
+    }
+    EXPECT_EQ(merger.lat(merger.set(s)[p]).bits, testBits);
+  }
+
+  /// Checks that there are n lattice points in set s.
+  void expectNumLatPoints(unsigned s, unsigned n) {
+    EXPECT_EQ(merger.set(s).size(), n);
+  }
+
+  Merger merger;
+
+private:
+  /// Compares e0 and e1 for structural equality, defined recursively as:
+  /// - Two expressions can only be equal if they have the same Kind.
+  /// - Two binary expressions (kAddF, kMulF, kAddI, kMulI) are equal if
+  ///   their children are pairwise equal.
+  /// - Two kInvariant expressions are equal if they hold the same value.
+  /// - Two kTensor expressions are equal if they name the same tensor.
+  bool compareExpression(unsigned e0, unsigned e1) {
+    const auto &tensorExp0 = merger.exp(e0);
+    const auto &tensorExp1 = merger.exp(e1);
+    if (tensorExp0.kind != tensorExp1.kind)
+      return false;
+    switch (tensorExp0.kind) {
+    case Kind::kAddF:
+    case Kind::kMulF:
+    case Kind::kAddI:
+    case Kind::kMulI:
+      return compareExpression(tensorExp0.e0, tensorExp1.e0) &&
+             compareExpression(tensorExp0.e1, tensorExp1.e1);
+    case Kind::kInvariant:
+      return tensorExp0.val == tensorExp1.val;
+    case Kind::kTensor:
+      return tensorExp0.e0 == tensorExp1.e0;
+    }
+    // Every case above returns; without this the function could fall off the
+    // end of a non-void function (-Wreturn-type, UB if ever reached).
+    llvm_unreachable("unexpected expression kind");
+  }
+};
+
+} // anonymous namespace
+
+/// Vector addition of 2 vectors, i.e.:
+///   a(i) = b(i) + c(i)
+/// which should form the 3 lattice points
+/// {
+///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
+///   lat( i_00 / tensor_0 )
+///   lat( i_01 / tensor_1 )
+/// }
+/// and after optimization, will reduce to the 2 lattice points
+/// {
+///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
+///   lat( i_00 / tensor_0 )
+/// }
+///
+/// NOTE(review): nothing in this test explains why lat( i_01 / tensor_1 ) is
+/// dropped by optimizeSet while lat( i_00 / tensor_0 ) is kept; if Merger
+/// treats tensor NUM_TENSORS - 1 as the output tensor, that would account for
+/// it -- confirm, since tensor 1 is meant to be an input here.
+TEST_F(MergerTest, VectorAdd2) {
+  // Construct expression.
+  auto e = addf(tensor(0), tensor(1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, 0);
+  expectNumLatPoints(s, 3);
+  expectLatPointExpressionEquals(s, /*lat point*/ 0,
+                                 addf(tensor(0), tensor(1)));
+  expectLatPointBitsEqual(
+      s, /*lat point*/ 0,
+      {{/*loop*/ 0, /*tensor*/ 0}, {/*loop*/ 0, /*tensor*/ 1}});
+  expectLatPointExpressionEquals(s, /*lat point*/ 1, tensor(0));
+  expectLatPointBitsEqual(s, /*lat point*/ 1, {{/*loop*/ 0, /*tensor*/ 0}});
+  expectLatPointExpressionEquals(s, /*lat point*/ 2, tensor(1));
+  expectLatPointBitsEqual(s, /*lat point*/ 2, {{/*loop*/ 0, /*tensor*/ 1}});
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 2);
+  expectLatPointExpressionEquals(s, /*lat point*/ 0,
+                                 addf(tensor(0), tensor(1)));
+  expectLatPointBitsEqual(
+      s, /*lat point*/ 0,
+      {{/*loop*/ 0, /*tensor*/ 0}, {/*loop*/ 0, /*tensor*/ 1}});
+  expectLatPointExpressionEquals(s, /*lat point*/ 1, tensor(0));
+  expectLatPointBitsEqual(s, /*lat point*/ 1, {{/*loop*/ 0, /*tensor*/ 0}});
+}
+
+/// Vector multiplication of 2 vectors, i.e.:
+///   a(i) = b(i) * c(i)
+/// which should form the single lattice point
+/// {
+///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
+/// }
+TEST_F(MergerTest, VectorMul2) {
+  // Construct expression.
+  auto e = mulf(tensor(0), tensor(1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, 0);
+  expectNumLatPoints(s, 1);
+  expectLatPointExpressionEquals(s, /*lat point*/ 0,
+                                 mulf(tensor(0), tensor(1)));
+  expectLatPointBitsEqual(
+      s, /*lat point*/ 0,
+      {{/*loop*/ 0, /*tensor*/ 0}, {/*loop*/ 0, /*tensor*/ 1}});
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 1);
+  expectLatPointExpressionEquals(s, /*lat point*/ 0,
+                                 mulf(tensor(0), tensor(1)));
+  expectLatPointBitsEqual(
+      s, /*lat point*/ 0,
+      {{/*loop*/ 0, /*tensor*/ 0}, {/*loop*/ 0, /*tensor*/ 1}});
+}