diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -437,11 +437,11 @@
   /// for each `TensorLoopId` and passing it the corresponding tensor
   /// identifier, level, and level-type.
   void
-  foreachTensorLoopId(const BitVector &bits,
+  foreachTensorLoopId(LatPointId p,
                       function_ref<void(TensorLoopId, TensorId,
                                         std::optional<Level>, DimLevelType)>
                           callback) const {
-    for (const TensorLoopId b : bits.set_bits())
+    for (const TensorLoopId b : latPoints[p].bits.set_bits())
       callback(b, tensor(b), getLvl(b), getDimLevelType(b));
   }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1274,8 +1274,8 @@
   SmallVector<TensorId> tids;
   SmallVector<Level> lvls;
   env.merger().foreachTensorLoopId(
-      env.lat(l0).bits, [&](TensorLoopId b, TensorId tid,
-                            std::optional<Level> lvl, DimLevelType dlt) {
+      l0, [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+              DimLevelType dlt) {
         assert(env.merger().loop(b) == idx);
         if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
           needsUniv = true;
@@ -1342,7 +1342,6 @@
     CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl<TensorId> &tids,
     SmallVectorImpl<Level> &lvls, SmallVectorImpl<TensorId> &affineTids,
     SmallVectorImpl<Level> &affineLvls, SmallVectorImpl<AffineExpr> &exps) {
-  const BitVector &all = env.lat(li).bits;
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
@@ -1350,8 +1349,8 @@
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
   env.merger().foreachTensorLoopId(
-      all, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
-                    DimLevelType dlt) {
+      li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+                   DimLevelType dlt) {
         if (simple.test(b)) {
           if (isUndefDLT(dlt)) {
             // An undefined dlt in the lattices, we probably mean to