diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,4 +7,5 @@
   MLIRDialect)
 
 add_subdirectory(Quant)
+add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRSparseTensorTests
+  MergerTest.cpp
+)
+target_link_libraries(MLIRSparseTensorTests
+  PRIVATE
+  MLIRSparseTensorUtils
+)
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -0,0 +1,243 @@
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <cassert>
+#include <memory>
+#include <utility>
+
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+struct Pattern;
+
+/// Owning links from a binary Pattern to its two operand patterns.
+struct PatternChildren {
+  std::unique_ptr<Pattern> e0;
+  std::unique_ptr<Pattern> e1;
+};
+
+/// Simple recursive data structure used to match expressions in Mergers.
+struct Pattern {
+  Kind kind;
+
+  /// Expressions representing tensors simply have a tensor number.
+  unsigned tensorNum = 0;
+
+  /// Tensor operations point to their children.
+  PatternChildren children;
+
+  ///
+  /// Constructors.
+  /// Rather than using these, please use the readable helper constructor
+  /// functions below to make tests more readable.
+  ///
+
+  /// Builds a leaf pattern that matches the tensor numbered `tensorNum`.
+  explicit Pattern(unsigned tensorNum)
+      : kind(Kind::kTensor), tensorNum(tensorNum) {}
+
+  /// Builds a pattern of the given `kind` (asserted to be >= Kind::kMulF)
+  /// over the operands `e0` and `e1`, both of which must be non-null.
+  Pattern(Kind kind, std::unique_ptr<Pattern> e0, std::unique_ptr<Pattern> e1)
+      : kind(kind) {
+    assert(kind >= Kind::kMulF);
+    assert(e0 && e1);
+    children.e0 = std::move(e0);
+    children.e1 = std::move(e1);
+  }
+};
+
+///
+/// Readable Pattern builder functions.
+/// These should be preferred over the actual constructors.
+///
+
+static std::unique_ptr<Pattern> tensorPattern(unsigned tensorNum) {
+  return std::make_unique<Pattern>(tensorNum);
+}
+
+static std::unique_ptr<Pattern> addfPattern(std::unique_ptr<Pattern> e0,
+                                            std::unique_ptr<Pattern> e1) {
+  return std::make_unique<Pattern>(Kind::kAddF, std::move(e0), std::move(e1));
+}
+
+static std::unique_ptr<Pattern> mulfPattern(std::unique_ptr<Pattern> e0,
+                                            std::unique_ptr<Pattern> e1) {
+  return std::make_unique<Pattern>(Kind::kMulF, std::move(e0), std::move(e1));
+}
+
+class MergerTestBase : public ::testing::Test {
+protected:
+  MergerTestBase(unsigned numTensors, unsigned numLoops)
+      : numTensors(numTensors), numLoops(numLoops),
+        merger(numTensors, numLoops) {}
+  ///
+  /// Expression construction helpers.
+  ///
+
+  /// Identity helper; labels a tensor number to keep tests readable.
+  unsigned tensor(unsigned tensor) { return tensor; }
+
+  unsigned addf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kAddF, e0, e1);
+  }
+
+  unsigned mulf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kMulF, e0, e1);
+  }
+
+  ///
+  /// Comparison helpers.
+  ///
+
+  /// Identity helpers; label lattice point and loop indices in tests.
+  unsigned lat(unsigned lat) { return lat; }
+  unsigned loop(unsigned loop) { return loop; }
+
+  /// Check that the expression of lat point p in set s matches the given
+  /// pattern.
+  void expectLatPointExpressionEquals(unsigned s, unsigned p,
+                                      std::unique_ptr<Pattern> pattern) {
+    EXPECT_TRUE(compareExpression(merger.lat(merger.set(s)[p]).exp,
+                                  std::move(pattern)));
+  }
+
+  /// Checks that the true bits of the lat point p in set s correspond to the
+  /// bits indicated by the (loop, tensor) pairs in loops.
+  void
+  expectLatPointBitsEqual(unsigned s, unsigned p,
+                          std::vector<std::pair<unsigned, unsigned>> loops) {
+    llvm::BitVector testBits(numTensors + 1, false);
+    for (const auto &l : loops) {
+      auto loop = std::get<0>(l);
+      auto tensor = std::get<1>(l);
+      testBits.set(numTensors * loop + tensor);
+    }
+    EXPECT_THAT(merger.lat(merger.set(s)[p]).bits, testBits);
+  }
+
+  /// Check that there are n lattice points in set s.
+  void expectNumLatPoints(unsigned s, unsigned n) {
+    EXPECT_THAT(merger.set(s).size(), n);
+  }
+
+  /// Compares expressions for equality. Equality is defined recursively as:
+  /// - Two expressions can only be equal if they have the same Kind.
+  /// - Two binary expressions are equal if they have the same Kind and their
+  ///   children are equal.
+  /// - Tensor expressions are equal if they have the same tensor number;
+  ///   comparing invariants is not yet supported.
+  bool compareExpression(unsigned e, std::unique_ptr<Pattern> pattern) {
+    const auto &tensorExp = merger.exp(e);
+    if (tensorExp.kind != pattern->kind) {
+      return false;
+    }
+    assert(tensorExp.kind != Kind::kInvariant &&
+           "Invariant comparison not yet supported");
+    switch (tensorExp.kind) {
+    case Kind::kTensor:
+      return tensorExp.tensor == pattern->tensorNum;
+    case Kind::kZero:
+      return true;
+    case Kind::kMulF:
+    case Kind::kMulI:
+    case Kind::kAddF:
+    case Kind::kAddI:
+    case Kind::kSubF:
+    case Kind::kSubI:
+      return compareExpression(tensorExp.children.e0,
+                               std::move(pattern->children.e0)) &&
+             compareExpression(tensorExp.children.e1,
+                               std::move(pattern->children.e1));
+    default:
+      llvm_unreachable("Unhandled Kind");
+    }
+  }
+
+  unsigned numTensors;
+  unsigned numLoops;
+  Merger merger;
+};
+
+class MergerTest2T1L : public MergerTestBase {
+protected:
+  MergerTest2T1L() : MergerTestBase(2, 1) {
+    // Tensor 0: sparse vector.
+    merger.addExp(Kind::kTensor, 0, -1u);
+    merger.setDim(0, 0, Dim::kSparse);
+
+    // Tensor 1: sparse vector.
+    merger.addExp(Kind::kTensor, 1, -1u);
+    merger.setDim(1, 0, Dim::kSparse);
+  }
+};
+
+} // anonymous namespace
+
+/// Vector addition of 2 vectors, i.e.:
+///   a(i) = b(i) + c(i)
+/// which should form the 3 lattice points
+/// {
+///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
+///   lat( i_00 / tensor_0 )
+///   lat( i_01 / tensor_1 )
+/// }
+/// and after optimization, will reduce to the 2 lattice points
+/// {
+///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
+///   lat( i_00 / tensor_0 )
+/// }
+TEST_F(MergerTest2T1L, VectorAdd2) {
+  // Construct expression.
+  auto e = addf(tensor(0), tensor(1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, loop(0));
+  expectNumLatPoints(s, 3);
+  expectLatPointExpressionEquals(
+      s, lat(0), addfPattern(tensorPattern(0), tensorPattern(1)));
+  expectLatPointBitsEqual(s, lat(0),
+                          {{loop(0), tensor(0)}, {loop(0), tensor(1)}});
+  expectLatPointExpressionEquals(s, lat(1), tensorPattern(0));
+  expectLatPointBitsEqual(s, lat(1), {{loop(0), tensor(0)}});
+  expectLatPointExpressionEquals(s, lat(2), tensorPattern(1));
+  expectLatPointBitsEqual(s, lat(2), {{loop(0), tensor(1)}});
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 2);
+  expectLatPointExpressionEquals(
+      s, lat(0), addfPattern(tensorPattern(0), tensorPattern(1)));
+  expectLatPointBitsEqual(s, lat(0),
+                          {{loop(0), tensor(0)}, {loop(0), tensor(1)}});
+  expectLatPointExpressionEquals(s, lat(1), tensorPattern(0));
+  expectLatPointBitsEqual(s, lat(1), {{loop(0), tensor(0)}});
+}
+
+/// Vector multiplication of 2 vectors, i.e.:
+///   a(i) = b(i) * c(i)
+/// which should form the single lattice point
+/// {
+///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
+/// }
+TEST_F(MergerTest2T1L, VectorMul2) {
+  // Construct expression.
+  auto e = mulf(tensor(0), tensor(1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, loop(0));
+  expectNumLatPoints(s, 1);
+  expectLatPointExpressionEquals(
+      s, lat(0), mulfPattern(tensorPattern(0), tensorPattern(1)));
+  expectLatPointBitsEqual(s, lat(0),
+                          {{loop(0), tensor(0)}, {loop(0), tensor(1)}});
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 1);
+  expectLatPointExpressionEquals(
+      s, lat(0), mulfPattern(tensorPattern(0), tensorPattern(1)));
+  expectLatPointBitsEqual(s, lat(0),
+                          {{loop(0), tensor(0)}, {loop(0), tensor(1)}});
+}