diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -230,6 +230,10 @@ return tensor(b) == outTensor && index(b) == i; } + /// Gets the tensor IDs of the special (output and synthetic) tensors. + unsigned getOutTensorID() const { return outTensor; } + unsigned getSynTensorID() const { return syntheticTensor; } + /// Returns true if given tensor iterates *only* in the given tensor /// expression. For the output tensor, this defines a "simply dynamic" /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -309,6 +309,8 @@ const unsigned l0 = 0; MergerTest3T1L() : MergerTestBase(3, 1) { + EXPECT_EQ(merger.getOutTensorID(), t2); + // Tensor 0: sparse input vector. merger.addExp(Kind::kTensor, t0, -1u); merger.setDimLevelType(t0, l0, DimLevelType::Compressed); @@ -332,6 +334,8 @@ const unsigned l0 = 0; MergerTest4T1L() : MergerTestBase(4, 1) { + EXPECT_EQ(merger.getOutTensorID(), t3); + // Tensor 0: sparse input vector. merger.addExp(Kind::kTensor, t0, -1u); merger.setDimLevelType(t0, l0, DimLevelType::Compressed); @@ -363,6 +367,8 @@ const unsigned l0 = 0; MergerTest3T1LD() : MergerTestBase(3, 1) { + EXPECT_EQ(merger.getOutTensorID(), t2); + // Tensor 0: sparse input vector. merger.addExp(Kind::kTensor, t0, -1u); merger.setDimLevelType(t0, l0, DimLevelType::Compressed); @@ -383,13 +389,15 @@ class MergerTest4T1LU : public MergerTestBase { protected: - // Our three tensors (two inputs, one output). + // Our four tensors (three inputs, one output). const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; // Our single loop.
const unsigned l0 = 0; MergerTest4T1LU() : MergerTestBase(4, 1) { + EXPECT_EQ(merger.getOutTensorID(), t3); + // Tensor 0: undef input vector. merger.addExp(Kind::kTensor, t0, -1u); merger.setDimLevelType(t0, l0, DimLevelType::Undef); @@ -421,6 +429,9 @@ const unsigned l0 = 0; MergerTest3T1L_SO() : MergerTestBase(3, 1) { + EXPECT_EQ(merger.getOutTensorID(), t2); + EXPECT_EQ(merger.getSynTensorID(), t3); + merger.setHasSparseOut(true); // Tensor 0: undef input vector.