NOTE(review): this patch was reconstructed from a whitespace-collapsed copy;
re-check the indentation of context lines against the tree before applying.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -184,24 +184,34 @@
   void exitCurrentLoop(RewriterBase &rewriter, Location loc,
                        MutableArrayRef<Value> reduc = {});
 
+  /// Gets the induction variables of all loops in the current loop-stack,
+  /// in loop-stack order.  The result is a lazy view over `loopStack`, so
+  /// it must not be held across operations that modify the loop-stack.
+  auto getLoopIVsRange() const {
+    return llvm::map_range(loopStack, [](const LoopInfo &li) { return li.iv; });
+  }
+
-  /// Fills the out-parameter with the loop induction variables for all
-  /// loops in the current loop-stack.  The variables are given in the
-  /// same order as the loop-stack, hence `ivs` should be indexed into
-  /// by `LoopOrd` (not `LoopId`).
-  void getLoopIVs(SmallVectorImpl<Value> &ivs) const {
-    ivs.clear();
-    ivs.reserve(getCurrentDepth());
-    for (auto &l : loopStack)
-      ivs.push_back(l.iv);
+  /// Gets the loop induction variables for all loops in the current
+  /// loop-stack.  The variables are given in the same order as the
+  /// loop-stack, hence the result should be indexed into by `LoopOrd`
+  /// (not `LoopId`).
+  SmallVector<Value> getLoopIVs() const {
+    return llvm::to_vector(getLoopIVsRange());
   }
 
   /// Gets the current depth of the loop-stack.  The result is given
   /// the type `LoopOrd` for the same reason as one-past-the-end iterators.
-  LoopOrd getCurrentDepth() const { return loopStack.size(); }
+  LoopOrd getCurrentDepth() const {
+    return llvm::range_size(getLoopIVsRange());
+  }
 
   /// Gets loop induction variable for the given `LoopOrd`.
   Value getLoopIV(LoopOrd n) const {
-    return n < getCurrentDepth() ? loopStack[n].iv : Value();
+    if (n >= getCurrentDepth())
+      return Value();
+    auto it = getLoopIVsRange().begin();
+    std::advance(it, n);
+    return *it;
   }
 
   /// Gets the total number of manifest tensors (excluding the synthetic
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1318,10 +1318,7 @@
                                                     reduc);
     }
 
-    SmallVector<Value> lcvs;
-    lcvs.reserve(lvlRank);
-    loopEmitter.getLoopIVs(lcvs);
-
+    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
     if (op.getOrder()) {
       // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
       SmallVector<Value> dcvs = lcvs; // keep a copy
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -977,17 +977,13 @@
   // Direct insertion in lexicographic coordinate order.
   if (!env.isExpand()) {
     const LoopOrd numLoops = op.getRank(t);
-    // TODO: rewrite this to use `env.emitter().getLoopIVs(ivs)`
-    // instead.  We just need to either assert that `numLoops ==
-    // env.emitter().getCurrentDepth()`, or else update the `getLoopIVs`
-    // method to take an optional parameter to restrict to a smaller depth.
-    SmallVector<Value> ivs;
-    ivs.reserve(numLoops);
-    for (LoopOrd n = 0; n < numLoops; n++) {
-      const auto iv = env.emitter().getLoopIV(n);
-      assert(iv);
-      ivs.push_back(iv);
-    }
+    // Retrieves the first `numLoops` induction variables.  The loop-stack
+    // must be at least that deep, or the `drop_end` count below is invalid.
+    assert(numLoops <= env.emitter().getCurrentDepth() &&
+           "insufficient loop depth for insertion");
+    SmallVector<Value> ivs = llvm::to_vector(
+        llvm::drop_end(env.emitter().getLoopIVsRange(),
+                       env.emitter().getCurrentDepth() - numLoops));
     Value chain = env.getInsertionChain();
     if (!env.getValidLexInsert()) {
       env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
@@ -1438,7 +1434,7 @@
 
 /// Generates the induction structure for a while-loop.
 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                            bool needsUniv, scf::WhileOp whileOp) {
+                            bool needsUniv) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
@@ -1472,7 +1468,8 @@
       builder.setInsertionPointAfter(ifOp);
     }
   }
-  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+  // No need to set the insertion point here as LoopEmitter keeps track of the
+  // basic block where scf::Yield should be inserted.
 }
 
 /// Generates a single if-statement within a while-loop.
@@ -1525,8 +1522,8 @@
 
 /// Generates end of true branch of if-statement within a while-loop.
 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
-                  Operation *loop, Value redInput, Value cntInput,
-                  Value insInput, Value validIns) {
+                  Value redInput, Value cntInput, Value insInput,
+                  Value validIns) {
   SmallVector<Value> operands;
   if (env.isReduc()) {
     operands.push_back(env.getReduc());
@@ -1800,7 +1797,7 @@
     env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
-    finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
+    finalizeWhileOp(env, rewriter, idx, needsUniv);
   } else {
     needsUniv = false;
   }
@@ -1875,8 +1872,7 @@
         if (!isSingleCond) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
           genStmt(env, rewriter, ej, at + 1);
-          endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput,
-                validIns);
+          endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
         } else {
           genStmt(env, rewriter, ej, at + 1);
         }