diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -271,7 +271,6 @@
 //===----------------------------------------------------------------------===//
 class Ranks final {
   // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
-  // TODO(wrengr): to what extent do we actually care about constexpr here?
   unsigned impl[3];
 
   static constexpr unsigned to_index(VarKind vk) {
@@ -303,6 +302,14 @@
 static_assert(IsZeroCostAbstraction<Ranks>);
 
 //===----------------------------------------------------------------------===//
+/// Efficient representation of a set of `Var`.
+///
+/// NOTE: For the `contains`/`occursIn` methods: if variables occurring in
+/// the method parameter are OOB for the `VarSet`, then these methods will
+/// always return false. However, for the `add` methods: OOB parameters
+/// cause undefined behavior. Currently the `add` methods will raise an
+/// assertion error; though we may change that behavior in the future
+/// (e.g., to resize the underlying bitvectors).
 class VarSet final {
   // If we're willing to give up the possibility of resizing the
   // individual bitvectors, then we could flatten this into a single
@@ -314,14 +321,12 @@
 public:
   explicit VarSet(Ranks const &ranks);
 
-  // TODO(wrengr): can we come up with a single name that works for all three of
-  // these?
   bool contains(Var var) const;
   bool occursIn(VarSet const &vars) const;
   bool occursIn(DimLvlExpr expr) const;
 
   void add(Var var);
-  // TODO(wrengr): void add(VarSet const& vars);
+  void add(VarSet const &vars);
   void add(DimLvlExpr expr);
 };
 
@@ -397,7 +402,6 @@
   VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
 
 public:
-  // NOTE TO Wren: initializer needed!
   VarEnv() : nextNum(0) {}
 
   /// Gets the underlying storage for the `VarInfo` identified by
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -56,22 +56,27 @@
     VarKind::Dimension, VarKind::Symbol, VarKind::Level};
 
 VarSet::VarSet(Ranks const &ranks) {
-  // FIXME(wrengr): will this DWIM, or do we need to worry about
-  // `reserve` causing resizing/dangling issues?
+  // NOTE: This must be `resize` rather than `reserve`: the latter only
+  // preallocates storage and leaves `size() == 0`, which would make every
+  // `Var` OOB (so `contains` would always return false, and `add(Var)`
+  // would always trip the `SmallBitVector::operator[]` assertion).
   for (const auto vk : everyVarKind)
-    impl[vk].reserve(ranks.getRank(vk));
+    impl[vk].resize(ranks.getRank(vk));
 }
 
 bool VarSet::contains(Var var) const {
-  // FIXME(wrengr): this implementation will raise assertion failure on OOB;
-  // but perhaps we'd rather have this return false on OOB? That's
-  // necessary for consistency with the `anyCommon` implementation of
-  // `occursIn(VarSet)`.
+  // NOTE: We make sure to return false on OOB, for consistency with
+  // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
+  // However beware that, as always with silencing OOB, this can hide
+  // bugs in client code.
   const llvm::SmallBitVector &bits = impl[var.getKind()];
-  // NOTE TO Wren: did this to avoid OOB but perhaps it is result of bug
-  if (var.getNum() >= bits.size())
-    return false;
-  return bits[var.getNum()];
+  const auto num = var.getNum();
+  // FIXME(wrengr): If we `assert(num < bits.size())` then
+  // "roundtrip_encoding.mlir" will fail. That failure may have been due
+  // to the constructor previously calling `reserve` (which left every
+  // bitvector empty) rather than `resize`; re-check whether the assertion
+  // can be enabled now, and otherwise find where the OOB `var` comes from.
+  return num < bits.size() && bits[num];
 }
 
 bool VarSet::occursIn(VarSet const &other) const {
@@ -105,13 +110,20 @@
 }
 
 void VarSet::add(Var var) {
-  // FIXME(wrengr): this implementation will raise assertion failure on OOB;
-  // but perhaps we'd rather have this be a noop on OOB? or to grow
-  // the underlying bitvectors on OOB?
+  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
   impl[var.getKind()][var.getNum()] = true;
 }
 
-// TODO(wrengr): void VarSet::add(VarSet const& other);
+void VarSet::add(VarSet const &other) {
+  // NOTE: Adding a set is set-union, hence `|=` (not `&=`, which would
+  // instead drop every variable of `this` that is absent from `other`).
+  // `SmallBitVector::operator|=` will implicitly resize the bitvector, so
+  // we add an assertion against OOB for consistency with `add(Var)`.
+  for (const auto vk : everyVarKind) {
+    assert(impl[vk].size() >= other.impl[vk].size());
+    impl[vk] |= other.impl[vk];
+  }
+}
 
 void VarSet::add(DimLvlExpr expr) {
   if (!expr)