diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h @@ -292,6 +292,9 @@ : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension], ranks[VarKind::Level]) {} + bool operator==(Ranks const &other) const; + bool operator!=(Ranks const &other) const { return !(*this == other); } + constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; } constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); } constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); } @@ -324,6 +327,14 @@ public: explicit VarSet(Ranks const &ranks); + unsigned getRank(VarKind vk) const { return impl[vk].size(); } + unsigned getSymRank() const { return getRank(VarKind::Symbol); } + unsigned getDimRank() const { return getRank(VarKind::Dimension); } + unsigned getLvlRank() const { return getRank(VarKind::Level); } + Ranks getRanks() const { + return Ranks(getSymRank(), getDimRank(), getLvlRank()); + } + bool contains(Var var) const; bool occursIn(VarSet const &vars) const; bool occursIn(DimLvlExpr expr) const; diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -13,6 +13,14 @@ using namespace mlir::sparse_tensor; using namespace mlir::sparse_tensor::ir_detail; +//===----------------------------------------------------------------------===// +// `VarKind` helpers. +//===----------------------------------------------------------------------===// + +/// For use in foreach loops. +static constexpr const VarKind everyVarKind[] = { + VarKind::Dimension, VarKind::Symbol, VarKind::Level}; + //===----------------------------------------------------------------------===// // `Var` implementation. //===----------------------------------------------------------------------===// @@ -32,6 +40,12 @@ // `Ranks` implementation. //===----------------------------------------------------------------------===// +bool Ranks::operator==(Ranks const &other) const { + for (const auto vk : everyVarKind) + if (getRank(vk) != other.getRank(vk)) + return false; + return true; +} bool Ranks::isValid(DimLvlExpr expr) const { assert(expr); // Compute the maximum identifiers for symbol-vars and dim/lvl-vars @@ -49,9 +63,6 @@ // `VarSet` implementation. //===----------------------------------------------------------------------===// -static constexpr const VarKind everyVarKind[] = { - VarKind::Dimension, VarKind::Symbol, VarKind::Level}; - VarSet::VarSet(Ranks const &ranks) { // NOTE: We must not use `reserve` here, since that doesn't change // the `size` of the bitvectors and therefore will result in unexpected @@ -59,6 +70,7 @@ // move-ctor since it should be (marginally) more efficient. for (const auto vk : everyVarKind) impl[vk] = llvm::SmallBitVector(ranks.getRank(vk)); + assert(getRanks() == ranks); } bool VarSet::contains(Var var) const {