diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -265,11 +265,19 @@ // Now apply the two basic rules. BitVector simple = latPoints[p0].bits; bool reset = isSingleton && hasAnySparse(simple); + // Start resetting from a dense dimension, so that the first bit (if kept) + // is not of an undefined dimension type. + unsigned offset = 0; + for (unsigned b = 0, be = simple.size(); b < be; b++) + if (simple[b] && isDimLevelType(b, DimLvlType::kDense)) + offset = b; + for (unsigned b = 0, be = simple.size(); b < be; b++) { - if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) && - !isDimLevelType(b, DimLvlType::kSingleton))) { + unsigned i = (offset + b) % simple.size(); + if (simple[i] && (!isDimLevelType(i, DimLvlType::kCompressed) && + !isDimLevelType(i, DimLvlType::kSingleton))) { if (reset) - simple.reset(b); + simple.reset(i); reset = true; } } diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -377,8 +377,64 @@ } }; +/// +/// Tests with both undef and dense input. +/// +class MergerTest3T1LU : public MergerTestBase { +protected: + // Our three tensors (two inputs, one output). + const unsigned t0 = 0, t1 = 1, t2 = 2; + + // Our single loop. + const unsigned l0 = 0; + + MergerTest3T1LU() : MergerTestBase(3, 1) { + // Tensor 0: undef input vector. + merger.addExp(Kind::kTensor, t0, -1u); + merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef)); + + // Tensor 1: dense input vector. + merger.addExp(Kind::kTensor, t1, -1u); + merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense)); + + // Tensor 2: dense output vector. 
+ merger.addExp(Kind::kTensor, t2, -1u); + merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense)); + } +}; } // namespace +/// Vector multiplication (conjunction) of 2 vectors, i.e.; +/// a(i) = b(i) * c(i) +/// which should form the single lattice point +/// { +/// lat( i_00_U i_01_D / (tensor_0 * tensor_1) ) +/// } +/// after optimization, the dense dimension should be kept, even though it +/// appears after the undef dimension +/// { +/// lat( i_01_D / (tensor_0 * tensor_1) ) +/// } +#define IMPL_MERGER_TEST_CONJ(OP) \ + TEST_F(MergerTest3T1LU, vector_##OP) { \ + auto e = OP##Expr(t0, t1); \ + auto p0 = tensorPattern(t0); \ + auto p1 = tensorPattern(t1); \ + auto s = merger.buildLattices(e, l0); \ + \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ + loopsToBits({{l0, t0}, {l0, t1}})); \ + \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t1}}), \ + true); \ + } +FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ) + +#undef IMPL_MERGER_TEST_CONJ + /// Vector addition (disjunction) of 2 vectors. i.e.; /// a(i) = b(i) + c(i) /// which should form the 3 lattice points