diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -261,8 +261,8 @@ return getDimLevelFormat(t, i).levelType == tp; } - /// Returns true if any set bit corresponds to given dimension level type. - bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const; + /// Returns true if any set bit corresponds to sparse dimension level type. + bool hasAnySparse(const BitVector &bits) const; /// Dimension level format getter. DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1641,8 +1641,7 @@ unsigned lsize = merger.set(lts).size(); for (unsigned i = 1; i < lsize; i++) { unsigned li = merger.set(lts)[i]; - if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) && - !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton)) + if (!merger.hasAnySparse(merger.lat(li).simple)) return true; } } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -262,13 +262,8 @@ } } // Now apply the two basic rules. - // - // TODO: improve for singleton and properties - // BitVector simple = latPoints[p0].bits; - bool reset = isSingleton && - (hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) || - hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton)); + bool reset = isSingleton && hasAnySparse(simple); for (unsigned b = 0, be = simple.size(); b < be; b++) { if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) && @@ -297,8 +292,7 @@ bool Merger::onlyDenseDiff(unsigned i, unsigned j) { BitVector tmp = latPoints[j].bits; tmp ^= latPoints[i].bits; - return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) && - !hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton); + return !hasAnySparse(tmp); } bool Merger::isSingleCondition(unsigned t, unsigned e) const { @@ -384,9 +378,10 @@ llvm_unreachable("unexpected kind"); } -bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const { +bool Merger::hasAnySparse(const BitVector &bits) const { for (unsigned b = 0, be = bits.size(); b < be; b++) - if (bits[b] && isDimLevelType(b, tp)) + if (bits[b] && (isDimLevelType(b, DimLvlType::kCompressed) || + isDimLevelType(b, DimLvlType::kSingleton))) return true; return false; }