diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,4 +7,5 @@
   MLIRDialect)
 
 add_subdirectory(Quant)
+add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRSparseTensorTests
+  MergerTest.cpp
+)
+target_link_libraries(MLIRSparseTensorTests
+  PRIVATE
+  MLIRSparseTensorUtils
+)
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -0,0 +1,271 @@
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <memory>
+#include <utility>
+#include <vector>
+
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+struct Pattern;
+
+/// Children of a binary Pattern; both are null for tensor patterns.
+struct PatternChildren {
+  std::shared_ptr<Pattern> e0;
+  std::shared_ptr<Pattern> e1;
+};
+
+/// Simple recursive data structure used to match expressions in Mergers.
+struct Pattern {
+  Kind kind;
+
+  /// Expressions representing tensors simply have a tensor number. Binary
+  /// patterns leave it at 0 and never read it.
+  unsigned tensorNum = 0;
+
+  /// Tensor operations point to their children.
+  PatternChildren children;
+
+  /// Constructors.
+  /// Rather than using these, please use the readable helper constructor
+  /// functions below to make tests more readable.
+  explicit Pattern(unsigned tensorNum)
+      : kind(Kind::kTensor), tensorNum(tensorNum) {}
+  Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1)
+      : kind(kind) {
+    // NOTE(review): assumes every binary Kind is ordered at or after kMulF in
+    // the Kind enum declared in Merger.h -- revisit if that enum changes.
+    assert(kind >= Kind::kMulF);
+    assert(e0 && e1);
+    children.e0 = std::move(e0);
+    children.e1 = std::move(e1);
+  }
+};
+
+///
+/// Readable Pattern builder functions.
+/// These should be preferred over the actual constructors.
+///
+
+static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
+  return std::make_shared<Pattern>(tensorNum);
+}
+
+static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0,
+                                            std::shared_ptr<Pattern> e1) {
+  return std::make_shared<Pattern>(Kind::kAddF, std::move(e0), std::move(e1));
+}
+
+static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0,
+                                            std::shared_ptr<Pattern> e1) {
+  return std::make_shared<Pattern>(Kind::kMulF, std::move(e0), std::move(e1));
+}
+
+class MergerTestBase : public ::testing::Test {
+protected:
+  MergerTestBase(unsigned numTensors, unsigned numLoops)
+      : numTensors(numTensors), numLoops(numLoops),
+        merger(numTensors, numLoops) {}
+
+  ///
+  /// Expression construction helpers.
+  ///
+
+  unsigned tensor(unsigned tensor) {
+    return merger.addExp(Kind::kTensor, tensor);
+  }
+
+  unsigned addf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kAddF, e0, e1);
+  }
+
+  unsigned mulf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kMulF, e0, e1);
+  }
+
+  ///
+  /// Comparison helpers.
+  ///
+
+  /// For readability of tests.
+  unsigned lat(unsigned lat) { return lat; }
+
+  /// Returns true if a lattice point with an expression matching the given
+  /// pattern and bits matching the given bits is present in lattice points
+  /// [p, p+n) of lattice set s. This is useful for testing partial ordering
+  /// constraints between lattice points. We generally know how contiguous
+  /// groups of lattice points should be ordered with respect to other groups,
+  /// but there is no required ordering within groups. A range that extends
+  /// past the end of set s yields false instead of reading out of bounds.
+  bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
+                           const std::shared_ptr<Pattern> &pattern,
+                           const llvm::BitVector &bits) {
+    if (p + n > merger.set(s).size())
+      return false;
+    for (unsigned i = p; i < p + n; ++i) {
+      if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
+          compareBits(s, i, bits))
+        return true;
+    }
+    return false;
+  }
+
+  /// Wrapper over latPointWithinRange for readability of tests.
+  void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
+                                 const std::shared_ptr<Pattern> &pattern,
+                                 const llvm::BitVector &bits) {
+    EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
+  }
+
+  /// Wrapper over expectLatPointWithinRange for a single lat point.
+  void expectLatPoint(unsigned s, unsigned p,
+                      const std::shared_ptr<Pattern> &pattern,
+                      const llvm::BitVector &bits) {
+    expectLatPointWithinRange(s, p, 1, pattern, bits);
+  }
+
+  /// Converts a vector of (loop, tensor) pairs to a bitvector with the
+  /// corresponding bits set. Each loop owns a group of (numTensors + 1)
+  /// consecutive bits, the same per-loop width the single-loop fixtures
+  /// have always used, and bit (numTensors + 1) * loop + tensor stands for
+  /// `tensor` in `loop`. NOTE(review): the extra bit per loop is assumed to
+  /// match the Merger's own lattice-bit layout -- confirm against Merger.h.
+  llvm::BitVector
+  loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
+    const unsigned bitsPerLoop = numTensors + 1;
+    llvm::BitVector testBits(bitsPerLoop * numLoops, false);
+    for (const auto &l : loops)
+      testBits.set(bitsPerLoop * l.first + l.second);
+    return testBits;
+  }
+
+  /// Returns true if the bits of lattice point p in set s match the given bits.
+  bool compareBits(unsigned s, unsigned p, const llvm::BitVector &bits) {
+    return merger.lat(merger.set(s)[p]).bits == bits;
+  }
+
+  /// Check that there are n lattice points in set s.
+  void expectNumLatPoints(unsigned s, unsigned n) {
+    EXPECT_EQ(merger.set(s).size(), n);
+  }
+
+  /// Compares expressions for equality. Equality is defined recursively as:
+  /// - Two expressions can only be equal if they have the same Kind.
+  /// - Two binary expressions are equal if they have the same Kind and their
+  ///   children are equal.
+  /// - Tensor expressions are equal if they refer to the same tensor number
+  ///   (not the same expression id); zero expressions are always equal.
+  /// Invariant expressions are not yet supported and trip an assertion.
+  bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
+    const auto &tensorExp = merger.exp(e);
+    if (tensorExp.kind != pattern->kind)
+      return false;
+    assert(tensorExp.kind != Kind::kInvariant &&
+           "Invariant comparison not yet supported");
+    switch (tensorExp.kind) {
+    case Kind::kTensor:
+      return tensorExp.tensor == pattern->tensorNum;
+    case Kind::kZero:
+      return true;
+    case Kind::kMulF:
+    case Kind::kMulI:
+    case Kind::kAddF:
+    case Kind::kAddI:
+    case Kind::kSubF:
+    case Kind::kSubI:
+      return compareExpression(tensorExp.children.e0, pattern->children.e0) &&
+             compareExpression(tensorExp.children.e1, pattern->children.e1);
+    default:
+      llvm_unreachable("Unhandled Kind");
+    }
+  }
+
+  unsigned numTensors;
+  unsigned numLoops;
+  Merger merger;
+};
+
+class MergerTest3T1L : public MergerTestBase {
+protected:
+  // Our three tensors (two inputs, one output).
+  const unsigned t0 = 0, t1 = 1, t2 = 2;
+
+  // Our single loop.
+  const unsigned l0 = 0;
+
+  MergerTest3T1L() : MergerTestBase(3, 1) {
+    // Tensor 0: sparse input vector.
+    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.setDim(t0, l0, Dim::kSparse);
+
+    // Tensor 1: sparse input vector.
+    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.setDim(t1, l0, Dim::kSparse);
+
+    // Tensor 2: dense output vector.
+    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.setDim(t2, l0, Dim::kDense);
+  }
+};
+
+} // anonymous namespace
+
+/// Vector addition of 2 vectors, i.e.:
+///   a(i) = b(i) + c(i)
+/// which should form the 3 lattice points
+/// {
+///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
+///   lat( i_00 / tensor_0 )
+///   lat( i_01 / tensor_1 )
+/// }
+/// Both inputs are sparse, so optimization is expected to keep all 3 of
+/// these lattice points (the checks after optimizeSet mirror the ones above).
+TEST_F(MergerTest3T1L, VectorAdd2) {
+  // Construct expression.
+  auto e = addf(tensor(t0), tensor(t1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, l0);
+  expectNumLatPoints(s, 3);
+  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
+                            loopsToBits({{l0, t0}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
+                            loopsToBits({{l0, t1}}));
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 3);
+  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
+                            loopsToBits({{l0, t0}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
+                            loopsToBits({{l0, t1}}));
+}
+
+/// Vector multiplication of 2 vectors, i.e.:
+///   a(i) = b(i) * c(i)
+/// which should form the single lattice point
+/// {
+///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
+/// }
+TEST_F(MergerTest3T1L, VectorMul2) {
+  // Construct expression.
+  auto e = mulf(tensor(t0), tensor(t1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, l0);
+  expectNumLatPoints(s, 1);
+  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 1);
+  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+}