diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -85,6 +85,21 @@ return latticeMerger.getDimLevelType(b); } + // + // LoopEmitter delegates. + // + + TensorLevel makeTensorLevel(TensorId t, Level l) const { + // Make sure LoopEmitter, GenericOp, and Merger agree on the number of + // tensors. Merger has one more synthetic tensor for loop invariants. + assert(loopEmitter.getNumTensors() == linalgOp->getNumOperands() && + loopEmitter.getNumTensors() == latticeMerger.getNumTensors() - 1); + return loopEmitter.makeTensorLevel(t, l); + } + std::pair toTidLvlPair(TensorLevel tl) const { + return loopEmitter.toTidLvlPair(tl); + } + // // Code generation environment verify functions. // diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -42,6 +42,8 @@ // typecheck this to avoid mixups in the code. using LoopOrd = unsigned; +// A TensorId and a Level compressed into a single unsigned value. +using TensorLevel = unsigned; //===----------------------------------------------------------------------===// // SparseTensorLoopEmiter class, manages sparse tensors and helps to // generate loop structure to (co)-iterate sparse tensors. @@ -134,7 +136,7 @@ /// // loop sequence end. /// } void enterNewLoopSeq(OpBuilder &builder, Location loc, - ArrayRef tids, ArrayRef lvls); + ArrayRef tidLvls); /// Exits the current loop sequence, this will reset universal index to 0. void exitCurrentLoopSeq(OpBuilder &builder, Location loc); @@ -149,8 +151,7 @@ /// The function will also perform in-place update on the `reduc` vector to /// return the reduction variable used inside the generated loop.
Operation *enterLoopOverTensorAtLvl(OpBuilder &builder, Location loc, - ArrayRef tids, - ArrayRef lvls, + ArrayRef tidLvls, MutableArrayRef reduc = {}, bool isParallel = false); @@ -159,13 +160,13 @@ AffineExpr affine, MutableArrayRef reduc = {}); - void genDenseAffineAddress(OpBuilder &builder, Location loc, TensorId tid, - Level lvl, AffineExpr lvlExpr); + void genDenseAffineAddress(OpBuilder &builder, Location loc, + TensorLevel tidLvl, AffineExpr lvlExpr); /// Emits a co-iteration loop over a set of tensors. Operation *enterCoIterationOverTensorsAtLvls( - OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef lvls, bool needsUniv, MutableArrayRef reduc = {}); + OpBuilder &builder, Location loc, ArrayRef tidLvls, + bool needsUniv, MutableArrayRef reduc = {}); void exitCurrentLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc = {}); @@ -190,6 +191,31 @@ return n < getCurrentDepth() ? loopStack[n].iv : Value(); } + /// Gets the total number of tensors that loopEmitter is operating on. + unsigned getNumTensors() const { return tensors.size(); } + + /// Compresses a TensorId and Level into a TensorLevel. + TensorLevel makeTensorLevel(TensorId t, Level l) const { + return l * getNumTensors() + t; + } + + /// De-compresses a TensorLevel back to a pair of TensorId and Level. + std::pair toTidLvlPair(TensorLevel tidLvl) const { + unsigned nt = getNumTensors(); + return std::make_pair(tidLvl % nt, tidLvl / nt); + } + + // Maps a TensorLevel container into a range of (TensorId, Level) pairs. + template ::value_type, + TensorLevel>, + bool> = true> + auto toTidLvlPairRange(ContainerTy &&c) const { + return llvm::map_range( + c, [this](TensorLevel tl) { return this->toTidLvlPair(tl); }); + } + /// /// Getters. /// @@ -209,32 +235,30 @@ } private: + // A tuple that stores the slice-driven loop information.
+ struct SliceLoopInfo { + TensorId tid; + Level lvl; + bool reduced; // Whether the level is fully reduced (e.g., i + j => j). + SliceLoopInfo(TensorId tid, Level lvl, bool reduced) + : tid(tid), lvl(lvl), reduced(reduced) {} + }; // LoopInfo stores information of a loop generated by LoopEmitter. E.g., // the set of tensors levels that the loop is iterating over. struct LoopInfo final { - LoopInfo(ArrayRef tids, ArrayRef lvls, - ArrayRef slicedTids, ArrayRef slicedLvls, - ArrayRef sliceReduced, Operation *loop, Block *userBlock, - Value iv, StringAttr loopTag) - : tids(tids), lvls(lvls), slicedTids(slicedTids), - slicedLvls(slicedLvls), sliceReduced(sliceReduced), loop(loop), + LoopInfo(ArrayRef tidLvls, + ArrayRef sliceDrivenInfo, Operation *loop, + Block *userBlock, Value iv, StringAttr loopTag) + : tidLvls(tidLvls), sliceDrivenInfo(sliceDrivenInfo), loop(loop), userCodeBlock(userBlock), iv(iv) { // Attached a special tag to loop emitter generated loop. if (loopTag) loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag); } - // TODO: maybe use a vector for tid and lvl? - // (Or compress them together with a `TensorLoopId`.) - // The set of tensors that the loop is operating on - const llvm::SmallVector tids; - // The corresponding levels for the tensors - const llvm::SmallVector lvls; - // The set of tensors for slice-driven loop conditions. - const llvm::SmallVector slicedTids; - // The corresponding level for slice-driven tensors. - const llvm::SmallVector slicedLvls; - // Whether the tensor is fully reduced (e.g., i + j => j). - const llvm::SmallVector sliceReduced; + // The set of (compressed) tensor levels that the loop is operating on. + const llvm::SmallVector tidLvls; + // Slice-driven loop conditions. + const llvm::SmallVector sliceDrivenInfo; const Operation *loop; // the loop operation Block *const userCodeBlock; // the block holding users' generated code.
const Value iv; // the induction variable for the loop @@ -295,8 +319,6 @@ Location loc, Value crd, TensorId tid, Level lvl); - unsigned getNumTensors() const { return tensors.size(); } - bool isOutputTensor(TensorId tid) const { return hasOutput && tid == getNumTensors() - 1; } @@ -318,8 +340,7 @@ /// point used to generate the loops, but are still required to generate /// expressions. void emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc, - ArrayRef tids, - ArrayRef lvls); + ArrayRef tidLvls); /// Emits a for loop to iterate over a tensor level with the provided lower /// bound `lo` and upper bound `hi`. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -475,13 +475,12 @@ } void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc, - ArrayRef tids, - ArrayRef lvls) { + ArrayRef tidLvls) { // TODO: sort assert(loopSeqStack.size() == loopStack.size()); // Prepares for all the tensors used in the current loop sequence. std::vector> slicedTids; - for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + for (auto [tid, lvl] : toTidLvlPairRange(tidLvls)) { if (!dependentLvlMap[tid][lvl].empty()) { bool fullyRed = genSliceBegin(builder, loc, tid, lvl); slicedTids.emplace_back(tid, lvl, fullyRed); @@ -679,17 +678,19 @@ return loop; } -Operation *LoopEmitter::enterLoopOverTensorAtLvl( - OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef lvls, MutableArrayRef reduc, bool isParallel) { +Operation *LoopEmitter::enterLoopOverTensorAtLvl(OpBuilder &builder, + Location loc, + ArrayRef tidLvls, + MutableArrayRef reduc, + bool isParallel) { // TODO: support multiple return on parallel for?
assert(!isParallel || reduc.size() <= 1); bool isSparseCond = false, isSparseSliceCond = false; - size_t tid = tids.front(), lvl = lvls.front(); + auto [tid, lvl] = toTidLvlPair(tidLvls.front()); // Finds out the tensor level that we should use to generate loops. Amongs all // the tensor levels, there is at most one sparse tensor level. - for (auto [t, l] : llvm::zip(tids, lvls)) { + for (auto [t, l] : toTidLvlPairRange(tidLvls)) { assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair assert(!coords[t][l] || // We cannot re-enter the same level !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop @@ -730,12 +731,9 @@ Operation *l = nullptr; // At most one tensor used as condition in for loop; - SmallVector condTid; - SmallVector condLvl; - // There Might be multiple dense slice driven tensor. - SmallVector sliceTids; - SmallVector sliceLvls; - SmallVector sliceReduc; + SmallVector condTidLvl; + // There might be multiple dense slice-driven tensors. + SmallVector sliceDrivenInfo; // Generates loops differently depending on whether we need a slice-driven // loop or a simple level traversal loop. @@ -752,9 +750,7 @@ lvl, reduc); } levelReducedDep[tid][lvl]++; - sliceTids.push_back(tid); - sliceLvls.push_back(lvl); - sliceReduc.push_back(fullyReduced); + sliceDrivenInfo.emplace_back(tid, lvl, fullyReduced); } else { Value lo = isSparseCond ? posits[tid][lvl] // current offset : loopSeqStack.back().first; // universal index @@ -765,21 +761,19 @@ // Adjust for loop hi for dense slice-driven loop.
if (fullyReduced) { hi = sliceSz; - condTid.push_back(tid); - condLvl.push_back(lvl); + condTidLvl.emplace_back(makeTensorLevel(tid, lvl)); } else { hi = SUBI(lvlSizes[tid][lvl], sliceSz); hi = ADDI(hi, C_IDX(1)); } } else { - condTid.push_back(tid); - condLvl.push_back(lvl); + condTidLvl.emplace_back(makeTensorLevel(tid, lvl)); } l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, reduc, isParallel); } Value iv = coords[tid][lvl]; - for (auto [t, l] : llvm::zip(tids, lvls)) { + for (auto [t, l] : toTidLvlPairRange(tidLvls)) { // We only need to handle slice-driven loops on dense level here. // If it is a slice-driven loop on sparse level, it needs a while loop to // insert break statements, and it must have been handled correctly in L692. @@ -792,9 +786,7 @@ } else { // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to // exit it. - sliceTids.push_back(t); - sliceLvls.push_back(l); - sliceReduc.push_back(fullyReduc); + sliceDrivenInfo.emplace_back(t, l, fullyReduc); // Update the slice information as we enter the new loop. assert(*info.slicedOnLvl == l); info.minCrd = info.offset = iv; @@ -805,10 +797,10 @@ } // NOTE: we can also prepare for next dim here in advance // Pushes the loop into stack. - loopStack.emplace_back(condTid, condLvl, sliceTids, sliceLvls, sliceReduc, l, + loopStack.emplace_back(condTidLvl, sliceDrivenInfo, l, builder.getInsertionBlock(), iv, loopTag); // Emit extra locals.
- emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls); + emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls); return l; } @@ -872,16 +864,17 @@ // NOTE: we can also prepare for next lvl here in advance // Push the loop into stack - loopStack.emplace_back(ArrayRef(tid), ArrayRef(lvl), - ArrayRef(), ArrayRef(), - ArrayRef(), forOp, builder.getInsertionBlock(), - coords[tid][lvl], nullptr); + loopStack.emplace_back(ArrayRef(makeTensorLevel(tid, lvl)), + ArrayRef(), forOp, + builder.getInsertionBlock(), coords[tid][lvl], + nullptr); return forOp; } void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc, - TensorId tid, Level lvl, + TensorLevel tidLvl, AffineExpr lvlExpr) { + auto [tid, lvl] = toTidLvlPair(tidLvl); assert(isDenseDLT(lvlTypes[tid][lvl])); // For dense levels, the level-coordinate also serves as the position. Value lvlCrd = genAffine(builder, loc, lvlExpr); @@ -889,16 +882,15 @@ } Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( - OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef lvls, bool needsUniv, MutableArrayRef reduc) { + OpBuilder &builder, Location loc, ArrayRef tidLvls, + bool needsUniv, MutableArrayRef reduc) { // NOTE: the slice driven tensor-related reduction variable must // appear before normal tensors. - assert(tids.size() == lvls.size()); SmallVector types; SmallVector operands; // Construct the while-loop with a parameter for each coordinate. const Type indexType = builder.getIndexType(); - for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + for (auto [tid, lvl] : toTidLvlPairRange(tidLvls)) { // TODO: support coiteration with slice driven tensors.
const auto lvlTp = lvlTypes[tid][lvl]; assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented"); @@ -940,7 +932,7 @@ builder.setInsertionPointToStart(&whileOp.getBefore().front()); Value cond; unsigned o = 0; - for (auto [t, lvl] : llvm::zip(tids, lvls)) { + for (auto [t, lvl] : toTidLvlPairRange(tidLvls)) { const TensorId tid = t; // Why `t` can not be captured by lambda? const auto lvlTp = lvlTypes[tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || @@ -974,7 +966,7 @@ SmallVector> slicesPreds; unsigned i = 0; - for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + for (auto [tid, lvl] : toTidLvlPairRange(tidLvls)) { // Prepares for next level. const auto lvlTp = lvlTypes[tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || @@ -1025,7 +1017,7 @@ Value min; // Finds the minimum coordinate if (!needsUniv) { - for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + for (auto [tid, lvl] : toTidLvlPairRange(tidLvls)) { const auto lvlTp = lvlTypes[tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || isCompressedWithHiDLT(lvlTp)) { @@ -1045,12 +1037,11 @@ } // Sets up the loop stack. - loopStack.emplace_back(tids, lvls, ArrayRef(), ArrayRef(), - ArrayRef(), whileOp, builder.getInsertionBlock(), - min, loopTag); + loopStack.emplace_back(tidLvls, ArrayRef(), whileOp, + builder.getInsertionBlock(), min, loopTag); assert(loopStack.size() == loopSeqStack.size()); - for (auto [tid, dstLvl] : llvm::zip(tids, lvls)) { + for (auto [tid, dstLvl] : toTidLvlPairRange(tidLvls)) { const auto reassoc = getCollapseReassociation(tid, dstLvl); assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); // TODO: Refactors this into smaller functions. @@ -1097,7 +1088,7 @@ } // Emits extra locals - emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls); + emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls); // Updates reduction variables assert(after->getNumArguments() == o + reduc.size() + (needsUniv ?
1 : 0)); @@ -1158,15 +1149,12 @@ llvm_unreachable("Unrecognized level-type!"); } -void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, - Location loc, - ArrayRef tids, - ArrayRef lvls) { +void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls( + OpBuilder &builder, Location loc, ArrayRef tidLvls) { // Initialize dense positions. Note that we generate dense coordinates of the // output tensor unconditionally, since they may not appear in the lattice, // but may be needed for linearized codegen. - assert(tids.size() == lvls.size()); - for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + for (auto [tid, lvl] : toTidLvlPairRange(tidLvls)) { if (isDenseDLT(lvlTypes[tid][lvl])) { // Slice-driven dense level should have be handled already. if (!dependentLvlMap[tid][lvl].empty()) @@ -1193,8 +1181,7 @@ MutableArrayRef reduc) { const LoopInfo &loopInfo = loopStack.back(); rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock); - for (auto [tid, lvl, reduced] : llvm::zip( - loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) { + for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) { SliceInfo &info = sliceStack[tid].back(); assert(isDenseDLT(lvlTypes[tid][lvl])); assert(*info.slicedOnLvl == lvl && !reduced); @@ -1271,7 +1258,7 @@ // Finished iterating a tensor, clean up // We only do the clean up on for loop as while loops do not necessarily // finish the iteration on a sparse tensor - for (auto [tid, lvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) { + for (auto [tid, lvl] : toTidLvlPairRange(loopInfo.tidLvls)) { // Reset to null. coords[tid][lvl] = Value(); posits[tid][lvl] = Value(); @@ -1296,8 +1283,7 @@ unsigned o = 0; SmallVector operands; unsigned delta = 0; - for (auto [tid, lvl, resolved] : llvm::zip( - loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) { + for (auto [tid, lvl, resolved] : loopInfo.sliceDrivenInfo) { // TODO: handle dense.
assert(isCompressedDLT(lvlTypes[tid][lvl])); levelReducedDep[tid][lvl]--; @@ -1309,7 +1295,7 @@ // fully reduced while op for iterating one slices. // FIXME: since we didn't implement coiteration, this must be iteration // just on fully resolved slice. - assert(loopInfo.slicedTids.size() == 1 && loopInfo.tids.empty()); + assert(loopInfo.sliceDrivenInfo.size() == 1 && loopInfo.tidLvls.empty()); // The if guard to filter out out-range coordinates. assert(llvm::isa(builder.getInsertionBlock()->getParentOp())); posits[tid][lvl] = whileOp->getResult(o++); @@ -1326,7 +1312,7 @@ }; Value one = C_IDX(1); - for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) { + for (auto [tid, dstLvl] : toTidLvlPairRange(loopInfo.tidLvls)) { const auto lvlTp = lvlTypes[tid][dstLvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || isCompressedWithHiDLT(lvlTp)) { @@ -1394,7 +1380,6 @@ // Clean up the values, it would help use to discover potential bug at a // earlier stage (instead of silently using a wrong value). const LoopInfo &loopInfo = loopStack.back(); - assert(loopInfo.tids.size() == loopInfo.lvls.size()); SmallVector red; if (llvm::isa(loopInfo.loop)) { exitWhileLoop(rewriter, loc, reduc); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -951,14 +951,12 @@ for (Level l = 0; l < lvlRank; l++) { // TODO: provide utility function for loop sequences that only contains // one for loop? - // FIXME(wrengr): what is this "ld" supposed to be really? - const Level ld = op.getOrder() ?
op.getOrder()->getDimPosition(l) : l; - const SmallVector tids{0}; - loopEmitter.enterNewLoopSeq(rewriter, loc, tids, ld); + const SmallVector tidLvls{ + loopEmitter.makeTensorLevel(0, l)}; + loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls); // Note that reduc will be taken care of by loop emitter and get updated // in place. - - loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tids, l, reduc); + loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tidLvls, reduc); } SmallVector lcvs; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1296,13 +1296,16 @@ /// Generates a for-loop on a single index. static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter, - bool isInner, LoopId ldx, ArrayRef tids, - ArrayRef lvls) { + bool isInner, LoopId ldx, + ArrayRef tidLvls) { linalg::GenericOp op = env.op(); Location loc = op.getLoc(); auto iteratorTypes = op.getIteratorTypesArray(); - bool isSparse = llvm::any_of(tids, [ldx, &env](TensorId tid) { - const auto dlt = env.dlt(tid, ldx); + bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) { + // Queries the DLT based on the tensor id and loop idx, as requested by + // `CodegenEnv::dlt(TensorId, LoopIdx)`. The returned DLT from CodegenEnv + // should be consistent with the DLT indexed by the (tid, lvl) pair.
+ const auto dlt = env.dlt(env.toTidLvlPair(tidLvl).first, ldx); return isCompressedDLT(dlt) || isSingletonDLT(dlt); }); @@ -1310,11 +1313,10 @@ Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { if (env.merger().isFilterLoop(ldx)) { - const TensorId tid = tids.front(); - const Level lvl = lvls.front(); + const auto [tid, lvl] = env.toTidLvlPair(tidLvls.front()); // tids/lvls must only have one value because filter loops only // corresponding to the one and only sparse tensor level. - assert(isSparse && tids.size() == 1 && lvls.size() == 1); + assert(isSparse && tidLvls.size() == 1); OpOperand *t = &op->getOpOperand(tid); auto enc = getSparseTensorEncoding(t->get().getType()); // Retrieves the affine expression for the filter loop. @@ -1324,8 +1326,8 @@ return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid, lvl, a, reduc); } - return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tids, lvls, - reduc, isParallel); + return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tidLvls, reduc, + isParallel); }); assert(loop); return loop; @@ -1333,13 +1335,12 @@ /// Emit a while-loop for co-iteration over multiple indices. static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx, - bool needsUniv, ArrayRef tids, - ArrayRef lvls) { + bool needsUniv, ArrayRef tidLvls) { Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { // Construct the while-loop with a parameter for each // index. return env.emitter().enterCoIterationOverTensorsAtLvls( - builder, env.op().getLoc(), tids, lvls, needsUniv, reduc); + builder, env.op().getLoc(), tidLvls, needsUniv, reduc); }); assert(loop); return loop; @@ -1348,16 +1349,15 @@ /// Generates a for-loop or a while-loop, depending on whether it implements /// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at, - bool needsUniv, ArrayRef tids, - ArrayRef lvls, bool isFor) { - assert(tids.size() == lvls.size()); + bool needsUniv, ArrayRef tidLvls, + bool isFor) { const LoopId idx = env.topSortAt(at); if (isFor) { bool isOuter = at == 0; bool isInner = at == env.topSortSize() - 1; - return genFor(env, builder, isOuter, isInner, idx, tids, lvls); + return genFor(env, builder, isOuter, isInner, idx, tidLvls); } - return genWhile(env, builder, idx, needsUniv, tids, lvls); + return genWhile(env, builder, idx, needsUniv, tidLvls); } /// Generates the induction structure for a while-loop. @@ -1480,8 +1480,7 @@ const LatPointId l0 = env.set(lts)[0]; bool needsUniv = false; - SmallVector tids; - SmallVector lvls; + SmallVector tidLvls; env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid, std::optional lvl, DimLevelType dlt, bool isIdxReduc) { @@ -1493,12 +1492,11 @@ // Only when this is a index reduction loop, can the dlt be undefined. assert(!isUndefDLT(dlt) || isIdxReduc); // sparse/singleton levels, or a dense/sparse index reduction loop. - tids.push_back(tid); - lvls.push_back(*lvl); + tidLvls.emplace_back(env.makeTensorLevel(tid, *lvl)); } }); - env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls); + env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls); // Maintain the universal index only if it is actually // consumed by a subsequent lattice point. @@ -1529,7 +1527,8 @@ // FIXME: `toOrigDim` is deprecated. AffineExpr lvlExpr = lvlExprs[toOrigDim(enc, l)]; if (enc.isDenseLvl(l) && lvlExpr.isa()) - env.emitter().genDenseAffineAddress(builder, loc, tid, l, lvlExpr); + env.emitter().genDenseAffineAddress( + builder, loc, env.makeTensorLevel(tid, l), lvlExpr); else return; // break on first non-dense non-constant level } @@ -1547,24 +1546,23 @@ } /// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs( - CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl &tids, - SmallVectorImpl &lvls, SmallVectorImpl &affineTids, - SmallVectorImpl &affineLvls, SmallVectorImpl &exps) { +static bool +translateBitsToTidLvlPairs(CodegenEnv &env, LatPointId li, LoopId ldx, + SmallVectorImpl &tidLvls, + SmallVectorImpl &affineTidLvls, + SmallVectorImpl &exps) { const BitVector &simple = env.lat(li).simple; const TensorId outTid = env.merger().getOutTensorID(); const std::optional outLvl = env.merger().getLvl(outTid, ldx); unsigned numloopCond = 0; bool hasNonUnique = false; - env.merger().foreachTensorLoopId( li, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, DimLevelType dlt, bool isIdxReduc) { if (simple[b]) { if (isIdxReduc) { - tids.push_back(tid); - lvls.push_back(*lvl); + tidLvls.emplace_back(env.makeTensorLevel(tid, *lvl)); numloopCond++; return; } @@ -1582,12 +1580,10 @@ return; } hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique; - tids.push_back(tid); - lvls.push_back(*lvl); + tidLvls.emplace_back(env.makeTensorLevel(tid, *lvl)); numloopCond++; } else if (isDenseDLT(dlt) || isIdxReduc) { - tids.push_back(tid); - lvls.push_back(*lvl); + tidLvls.emplace_back(env.makeTensorLevel(tid, *lvl)); } else { assert(isUndefDLT(dlt)); linalg::GenericOp op = env.op(); @@ -1625,8 +1621,7 @@ // computeIterationGraph), another more admissible approach // might be accepting out-of-order access between consecutive // dense levels. - affineTids.push_back(tid); - affineLvls.push_back(l); + affineTidLvls.emplace_back(env.makeTensorLevel(tid, l)); exps.push_back(exp); } } @@ -1638,8 +1633,7 @@ // Note that we generate dense indices of the output tensor // unconditionally, since they may not appear in the lattice, but may be // needed for linearized env.
- tids.push_back(outTid); - lvls.push_back(*outLvl); + tidLvls.emplace_back(env.makeTensorLevel(outTid, *outLvl)); } assert(numloopCond > 0); @@ -1653,29 +1647,28 @@ OpBuilder &builder, LoopOrd at, LatPointId li, bool needsUniv) { // The set of tensors + lvls to generate loops on - SmallVector tids, affineTids; - SmallVector lvls, affineLvls; + SmallVector tidLvls; // The set of dense tensors with non-trivial affine expression that just // becomes invariant and the address shall now be generated at the current // level. + SmallVector affineTidLvls; SmallVector affines; bool isSingleCond = translateBitsToTidLvlPairs( - env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines); + env, li, env.topSortAt(at), tidLvls, affineTidLvls, affines); // Emit the for/while-loop control. - Operation *loop = - genLoop(env, builder, at, needsUniv, tids, lvls, isSingleCond); + Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls, isSingleCond); Location loc = env.op().getLoc(); - for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) { - env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp); + for (auto [tidLvl, exp] : llvm::zip(affineTidLvls, affines)) { + env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp); } // Until now, we have entered every pair in {cond, extra, // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent // on constant affines expression may now be determined. - auto allTids = llvm::concat(tids, affineTids); - auto allLvls = llvm::concat(lvls, affineLvls); - for (auto [tid, lvl] : llvm::zip(allTids, allLvls)) { + auto allTidLvls = llvm::concat(tidLvls, affineTidLvls); + for (TensorLevel tidLvl : allTidLvls) { + auto [tid, lvl] = env.toTidLvlPair(tidLvl); if (tid != env.merger().getOutTensorID()) genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); }