diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp @@ -252,6 +252,8 @@ ArrayRef lvlSpecs) : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs) { // First, check integrity of the variable-binding structure. + // NOTE: This establishes the invariant that calls to `VarSet::add` + // below cannot cause OOB errors. assert(isWF()); // TODO: Second, we need to infer/validate the `lvlToDim` mapping. @@ -260,14 +262,19 @@ // needs to happen before the code for setting every `LvlSpec::elideVar`, // since if the LvlVar is only used in elided DimExpr, then the // LvlVar should also be elided. + // NOTE: Whenever we set a new DimExpr, we must make sure to validate it + // against our ranks, to restore the invariant established by `isWF` above. + // TODO(wrengr): We should perhaps adjust the `DimLvlExpr` ctor to take a + // `Ranks` argument and perform the validation then. // Third, we set every `LvlSpec::elideVar` according to whether that // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr). + // NOTE: The invariant established by `isWF` ensures that the following + // calls to `VarSet::add` cannot raise OOB errors. 
VarSet usedVars(getRanks()); - // NOTE TO Wren: bypassed for now - // for (const auto &dimSpec : dimSpecs) - // if (!dimSpec.canElideExpr()) - // usedVars.add(dimSpec.getExpr()); + for (const auto &dimSpec : dimSpecs) + if (!dimSpec.canElideExpr()) + usedVars.add(dimSpec.getExpr()); for (auto &lvlSpec : this->lvlSpecs) lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar())); } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -56,8 +56,12 @@ VarKind::Dimension, VarKind::Symbol, VarKind::Level}; VarSet::VarSet(Ranks const &ranks) { + // NOTE: We must not use `reserve` here, since that doesn't change + // the `size` of the bitvectors and therefore will result in unexpected + // OOB errors. Either `resize` or copy/move-assignment works; we opt for + // move-assignment since it should be (marginally) more efficient. for (const auto vk : everyVarKind) - impl[vk].reserve(ranks.getRank(vk)); + impl[vk] = llvm::SmallBitVector(ranks.getRank(vk)); } bool VarSet::contains(Var var) const {