diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -437,11 +437,11 @@ /// for each `TensorLoopId` and passing it the corresponding tensor /// identifier, level, and level-type. void - foreachTensorLoopId(const BitVector &bits, + foreachTensorLoopId(LatPointId p, function_ref, DimLevelType)> callback) const { - for (const TensorLoopId b : bits.set_bits()) + for (const TensorLoopId b : latPoints[p].bits.set_bits()) callback(b, tensor(b), getLvl(b), getDimLevelType(b)); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1273,18 +1273,18 @@ SmallVector tids; SmallVector lvls; - env.merger().foreachTensorLoopId( - env.lat(l0).bits, [&](TensorLoopId b, TensorId tid, - std::optional lvl, DimLevelType dlt) { - assert(env.merger().loop(b) == idx); - if (isDenseDLT(dlt) || isUndefDLT(dlt)) { - needsUniv = true; - } else { - // sparse/singleton levels. - tids.push_back(tid); - lvls.push_back(*lvl); - } - }); + env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid, + std::optional lvl, + DimLevelType dlt) { + assert(env.merger().loop(b) == idx); + if (isDenseDLT(dlt) || isUndefDLT(dlt)) { + needsUniv = true; + } else { + // sparse/singleton levels. + tids.push_back(tid); + lvls.push_back(*lvl); + } + }); env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls); @@ -1342,7 +1342,6 @@ CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl &tids, SmallVectorImpl &lvls, SmallVectorImpl &affineTids, SmallVectorImpl &affineLvls, SmallVectorImpl &exps) { - const BitVector &all = env.lat(li).bits; const BitVector &simple = env.lat(li).simple; const TensorId outTid = env.merger().getOutTensorID(); const std::optional outLvl = env.merger().getLvl(outTid, ldx); @@ -1350,8 +1349,8 @@ unsigned numloopCond = 0; bool hasNonUnique = false; env.merger().foreachTensorLoopId( - all, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, - DimLevelType dlt) { + li, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, + DimLevelType dlt) { if (simple.test(b)) { if (isUndefDLT(dlt)) { // An undefined dlt in the lattices, we probably mean to