diff --git a/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRSparseTensorTests + MergerTest.cpp +) +target_link_libraries(MLIRSparseTensorTests + PRIVATE + MLIRSparseTensorUtils +) diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -0,0 +1,312 @@ +#include "mlir/Dialect/SparseTensor/Utils/Merger.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::sparse_tensor; + +namespace { + +/// Compares e0 and e1 for equality. Equality is defined recursively as: +/// - Two expressions can only be equal if they have the same Kind. +/// - Two binary expressions are equal if they have the same Kind and their +/// children are equal. +/// - Invariant expressions are equal if they hold the same value. +/// - Tensor expressions are equal if they have the same expression id.
+bool compareExpression(Merger &merger, unsigned e0, unsigned e1) {
+  // Bind by const reference instead of copying each expression record;
+  // nothing in this function adds expressions to the merger.
+  const auto &tensorExp0 = merger.exp(e0);
+  const auto &tensorExp1 = merger.exp(e1);
+  if (tensorExp0.kind != tensorExp1.kind)
+    return false;
+  switch (tensorExp0.kind) {
+  case Kind::kAddF:
+  case Kind::kMulF:
+  case Kind::kAddI:
+  case Kind::kMulI:
+    return compareExpression(merger, tensorExp0.e0, tensorExp1.e0) &&
+           compareExpression(merger, tensorExp0.e1, tensorExp1.e1);
+  case Kind::kInvariant:
+    // Invariants are compared by the value they hold, not by expression id.
+    return tensorExp0.val == tensorExp1.val;
+  case Kind::kTensor:
+    return e0 == e1;
+  }
+  // Kinds not handled above are conservatively treated as unequal. This also
+  // avoids falling off the end of a non-void function.
+  return false;
+}
+
+/// Returns true if lattice set `s` contains a point whose bits equal
+/// `latPoint.bits` and whose expression is structurally equal to
+/// `latPoint.exp` according to compareExpression.
+bool similarLatPointIsInSet(Merger &merger, const LatPoint &latPoint,
+                            unsigned s) {
+  for (unsigned p : merger.set(s)) {
+    const LatPoint &candidate = merger.lat(p);
+    // Do the cheap bit-vector comparison before the recursive walk.
+    if (candidate.bits == latPoint.bits &&
+        compareExpression(merger, candidate.exp, latPoint.exp))
+      return true;
+  }
+  return false;
+}
+
+/// Returns the bits of a lattice point over loop `loop` in which exactly the
+/// tensors listed in `tensors` participate, for a merger constructed as
+/// Merger(numTensors, numLoops).
+///
+/// NOTE(review): assumes the merger lays out bits as
+/// `(numTensors + 1) * loop + tensor` over `(numTensors + 1) * numLoops` bits,
+/// i.e. one extra tensor slot per loop beyond the `numTensors` passed to the
+/// constructor -- confirm against Merger.h. For a single loop this yields a
+/// `numTensors + 1`-bit vector indexed directly by tensor number.
+llvm::BitVector latPointBits(unsigned numTensors, unsigned numLoops,
+                             unsigned loop, llvm::ArrayRef<unsigned> tensors) {
+  const unsigned bitsPerLoop = numTensors + 1;
+  llvm::BitVector bits(bitsPerLoop * numLoops, false);
+  for (unsigned t : tensors)
+    bits.set(bitsPerLoop * loop + t);
+  return bits;
+}
+
+} // anonymous namespace
+
+/// Vector addition of 4 vectors, i.e.:
+///   a(i) = (b(i) + c(i)) + (d(i) + e(i))
+/// which should form the 15 lattice points
+/// {
+///   lat( i_00 i_01 i_02 i_03 / ((tensor_0 + tensor_1)
+///                                + (tensor_2 + tensor_3)))
+///   lat( i_00 i_01 i_02 / ((tensor_0 + tensor_1) + tensor_2))
+///   lat( i_00 i_01 i_03 / ((tensor_0 + tensor_1) + tensor_3))
+///   lat( i_00 i_02 i_03 / (tensor_0 + (tensor_2 + tensor_3)))
+///   lat( i_01 i_02 i_03 / (tensor_1 + (tensor_2 + tensor_3)))
+///   lat( i_00 i_02 / (tensor_0 + tensor_2))
+///   lat( i_00 i_03 / (tensor_0 + tensor_3))
+///   lat( i_01 i_02 / (tensor_1 + tensor_2))
+///   lat( i_01 i_03 / (tensor_1 + tensor_3))
+///   lat( i_00 i_01 / (tensor_0 + tensor_1))
+///   lat( i_02 i_03 / (tensor_2 + tensor_3))
+///   lat( i_00 / tensor_0)
+///   lat( i_01 / tensor_1)
+///   lat( i_02 / tensor_2)
+///   lat( i_03 / tensor_3)
+/// }
+TEST(MergerTest, VectorAdd4) {
+  const unsigned NUM_TENSORS = 5;
+  const unsigned NUM_LOOPS = 1;
+
+  Merger merger(NUM_TENSORS, NUM_LOOPS);
+  auto b = merger.addExp(Kind::kTensor, 0);
+  auto c = merger.addExp(Kind::kTensor, 1);
+  auto d = merger.addExp(Kind::kTensor, 2);
+  auto e = merger.addExp(Kind::kTensor, 3);
+  // The inputs b, c, d, e are tensors 0..3. setDim takes the tensor number,
+  // not the expression ids returned by addExp (which merely coincide here).
+  for (unsigned t = 0; t < 4; ++t)
+    merger.setDim(t, 0, Dim::kSparse);
+  auto add = [&](unsigned lhs, unsigned rhs) {
+    return merger.addExp(Kind::kAddF, lhs, rhs);
+  };
+  auto add0 = add(b, c);
+  auto add1 = add(d, e);
+  auto add2 = add(add0, add1);
+
+  auto s = merger.optimizeSet(merger.buildLattices(add2, 0));
+
+  // Whether `s` holds a loop-0 lattice point over exactly `tensors` whose
+  // expression is structurally equal to `exp`. Returning a bool keeps each
+  // EXPECT_TRUE (and thus the reported failure line) at the call site.
+  auto hasLatPoint = [&](llvm::ArrayRef<unsigned> tensors, unsigned exp) {
+    LatPoint expected(latPointBits(NUM_TENSORS, NUM_LOOPS, 0, tensors), exp);
+    return similarLatPointIsInSet(merger, expected, s);
+  };
+
+  EXPECT_EQ(merger.set(s).size(), 15u);
+
+  // Fresh binary expressions are built for each check so that the comparison
+  // is structural rather than by the expression ids used to build the kernel.
+  EXPECT_TRUE(hasLatPoint({0, 1, 2, 3}, add(add(b, c), add(d, e))));
+  EXPECT_TRUE(hasLatPoint({0, 1, 2}, add(add(b, c), d)));
+  EXPECT_TRUE(hasLatPoint({0, 1, 3}, add(add(b, c), e)));
+  EXPECT_TRUE(hasLatPoint({0, 2, 3}, add(b, add(d, e))));
+  EXPECT_TRUE(hasLatPoint({1, 2, 3}, add(c, add(d, e))));
+  EXPECT_TRUE(hasLatPoint({0, 2}, add(b, d)));
+  EXPECT_TRUE(hasLatPoint({0, 3}, add(b, e)));
+  EXPECT_TRUE(hasLatPoint({1, 2}, add(c, d)));
+  EXPECT_TRUE(hasLatPoint({1, 3}, add(c, e)));
+  EXPECT_TRUE(hasLatPoint({0, 1}, add(b, c)));
+  EXPECT_TRUE(hasLatPoint({2, 3}, add(d, e)));
+  EXPECT_TRUE(hasLatPoint({0}, b));
+  EXPECT_TRUE(hasLatPoint({1}, c));
+  EXPECT_TRUE(hasLatPoint({2}, d));
+  EXPECT_TRUE(hasLatPoint({3}, e));
+}
+
+/// Vector multiplication of 4 vectors, i.e.:
+///   a(i) = (b(i) * c(i)) * (d(i) * e(i))
+/// which should form the single lattice point
+/// {
+///   lat( i_00 i_01 i_02 i_03 / ((tensor_0 * tensor_1)
+///                                * (tensor_2 * tensor_3)))
+/// }
+TEST(MergerTest, VectorMul4) {
+  const unsigned NUM_TENSORS = 5;
+  const unsigned NUM_LOOPS = 1;
+
+  Merger merger(NUM_TENSORS, NUM_LOOPS);
+  auto b = merger.addExp(Kind::kTensor, 0);
+  auto c = merger.addExp(Kind::kTensor, 1);
+  auto d = merger.addExp(Kind::kTensor, 2);
+  auto e = merger.addExp(Kind::kTensor, 3);
+  // The inputs b, c, d, e are tensors 0..3. setDim takes the tensor number,
+  // not the expression ids returned by addExp (which merely coincide here).
+  for (unsigned t = 0; t < 4; ++t)
+    merger.setDim(t, 0, Dim::kSparse);
+  auto mul = [&](unsigned lhs, unsigned rhs) {
+    return merger.addExp(Kind::kMulF, lhs, rhs);
+  };
+  auto mul0 = mul(b, c);
+  auto mul1 = mul(d, e);
+  auto mul2 = mul(mul0, mul1);
+
+  auto s = merger.optimizeSet(merger.buildLattices(mul2, 0));
+
+  EXPECT_EQ(merger.set(s).size(), 1u);
+
+  // A fresh expression is built so that the comparison is structural.
+  LatPoint expected(latPointBits(NUM_TENSORS, NUM_LOOPS, 0, {0, 1, 2, 3}),
+                    mul(mul(b, c), mul(d, e)));
+  EXPECT_TRUE(similarLatPointIsInSet(merger, expected, s));
+}