diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -399,11 +399,15 @@
   /// to sparse level-type.
   bool hasAnySparse(const BitVector &bits) const;
 
-  /// Gets the level-type of the `t`th tensor on `i`th loop.
+  /// Returns true if any set bit corresponds to a slice-driven level type.
+  bool hasAnySlice(const BitVector &bits) const;
+
+  /// Gets the dimension level type of the `t`th tensor on the `i`th loop.
   DimLevelType getDimLevelType(TensorId t, LoopId i) const {
     assert(t < numTensors && i < numLoops);
     return lvlTypes[t][i];
   }
+  DimLevelType getDimLevelType(TensorLoopId b) const { return getDimLevelType(tensor(b), loop(b)); }
@@ -433,18 +437,6 @@
     lvlToLoop[t][lvl] = i;
   }
 
-  /// Iterates over a set of `TensorLoopId`s, invoking the callback
-  /// for each `TensorLoopId` and passing it the corresponding tensor
-  /// identifier, level, and level-type.
-  void
-  foreachTensorLoopId(const BitVector &bits,
-                      function_ref<void(TensorLoopId, TensorId,
-                                        std::optional<Level>, DimLevelType)>
-                          callback) const {
-    for (const TensorLoopId b : bits.set_bits())
-      callback(b, tensor(b), getLvl(b), getDimLevelType(b));
-  }
-
   /// Sets whether the output tensor is sparse or not.
   void setHasSparseOut(bool s) { hasSparseOut = s; }
 
@@ -469,6 +461,39 @@
     return sliceToRelatedldx[t][d];
   }
 
+  /// Returns the defining [tid, dim] pair for loop `l`.
+  std::pair<unsigned, unsigned> getLoopDefiningDim(unsigned l) const {
+    return loopBounds[l];
+  }
+
+  /// Returns true if `b` is a slice-driven [tid, dim] pair.
+  bool isSliceDim(unsigned b) const {
+    return ldxToDependentSlice[tensor(b)][loop(b)].has_value();
+  }
+
+  /// Iterates over a set of `TensorLoopId`s, invoking the callback
+  /// for each `TensorLoopId` and passing it the corresponding tensor
+  /// identifier, level, level-type, and whether the loop is slice-driven.
+  void foreachTensorLoopId(
+      const BitVector &bits,
+      function_ref<void(TensorLoopId, TensorId, std::optional<Level>,
+                        DimLevelType, bool)>
+          callback) {
+    for (unsigned b : bits.set_bits()) {
+      unsigned t = tensor(b), l = loop(b);
+      bool isSlice = false;
+      std::optional<Level> lvl = getLvl(b);
+      if (isSliceDim(b)) {
+        // This must be an undefined dim.
+        assert(!lvl.has_value());
+        isSlice = true;
+        // Slice the tid along the dependent dim to iterate the current loop.
+        lvl = ldxToDependentSlice[t][l];
+      }
+      callback(b, t, lvl, getDimLevelType(b), isSlice);
+    }
+  }
+
   /// Convenience getters to immediately access the stored nodes.
   /// Typically it is inadvisable to keep the reference around, as in
   /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
@@ -536,6 +561,8 @@
   std::vector<std::vector<std::optional<unsigned>>> ldxToDependentSlice;
   // Map from the dependent slice [tid, dim] pair to a list of loop idx.
   std::vector<std::vector<std::vector<unsigned>>> sliceToRelatedldx;
+  // Map from loop index to the [tid, dim] pair that defines the loop boundary.
+  std::vector<std::pair<unsigned, unsigned>> loopBounds;
 
   llvm::SmallVector<TensorExp> tensorExps;
   llvm::SmallVector<LatPoint> latPoints;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -61,11 +61,22 @@
   SmallVector<Value> tensors;
   for (OpOperand &t : linalgOp->getOpOperands())
     tensors.push_back(t.get());
-  loopEmitter.initialize(tensors,
-                         StringAttr::get(linalgOp.getContext(),
-                                         linalg::GenericOp::getOperationName()),
-                         /*hasOutput=*/true,
-                         /*isSparseOut=*/sparseOut != nullptr, topSort);
+  loopEmitter.initialize(
+      tensors,
+      StringAttr::get(linalgOp.getContext(),
+                      linalg::GenericOp::getOperationName()),
+      /*hasOutput=*/true,
+      /*isSparseOut=*/sparseOut != nullptr, topSort,
+      [this](unsigned t,
+             unsigned d) -> std::vector<std::pair<unsigned, unsigned>> {
+        // Translates a list of loop indices into a list of [tid, dim] pairs.
+        std::vector<unsigned> rLoops = this->merger().getRelatedLoops(t, d);
+        std::vector<std::pair<unsigned, unsigned>> ret;
+        ret.reserve(rLoops.size());
+        for (unsigned l : rLoops)
+          ret.push_back(this->merger().getLoopDefiningDim(l));
+        return ret;
+      });
 }
 
 std::optional<Operation *> CodegenEnv::genLoopBoundary(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -76,6 +76,15 @@
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
 
+  // Map from [tid, dim] to a list of dependent [tid, dim] pairs for an
+  // affine expression index on sparse tensors.
+  // E.g., for the affine index (d0 + d1), it depends on the two [tid, dim]
+  // pairs that define d0 and d1 (for affine expression reduction).
+  // If the list is empty, there is no affine expression on the input
+  // [tid, dim].
+  using DependentDimGetter =
+      function_ref<std::vector<std::pair<unsigned, unsigned>>(unsigned,
+                                                              unsigned)>;
 
   LoopEmitter() = default;
 
@@ -89,11 +98,13 @@
   /// to `LoopId`.
   void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
                   bool hasOutput = false, bool isSparseOut = false,
-                  ArrayRef<LoopId> topSort = {});
+                  ArrayRef<LoopId> topSort = {},
+                  DependentDimGetter getter = nullptr);
 
   explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
                        bool hasOutput = false, bool isSparseOut = false,
-                       ArrayRef<LoopId> topSort = {});
+                       ArrayRef<LoopId> topSort = {},
+                       DependentDimGetter getter = nullptr);
 
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
@@ -295,8 +306,8 @@
                    MutableArrayRef<Value> reduc);
 
   /// Exits a while loop, returns the reduction results.
-  void exitCoIterationLoop(OpBuilder &builder, Location loc,
-                           MutableArrayRef<Value> reduc);
+  void exitWhileLoop(OpBuilder &builder, Location loc,
+                     MutableArrayRef<Value> reduc);
 
   //
   // View-based-reshape methods.
@@ -380,6 +391,15 @@
   std::vector<std::vector<Value>> sliceOffsets;
   std::vector<std::vector<Value>> sliceStrides;
 
+  // Map from [tid, dim] to a list of dependent [tid, dim] pairs.
+  // See comments for `DependentDimGetter`.
+  std::vector<std::vector<std::vector<std::pair<unsigned, unsigned>>>>
+      dependentDimMap;
+
+  //
+  // View-based reshape related fields and methods.
+  //
+
   /// Collapse reassociations related to a specific tensor.
   // TODO: support expand.
   std::vector<ArrayAttr> collapseReassoc;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -208,12 +208,14 @@
 }
 
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
-                         bool isSparseOut, ArrayRef<LoopId> topSort) {
-  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
+                         bool isSparseOut, ArrayRef<LoopId> topSort,
+                         DependentDimGetter dimGetter) {
+  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
-                             bool isSparseOut, ArrayRef<LoopId> topSort) {
+                             bool isSparseOut, ArrayRef<LoopId> topSort,
+                             DependentDimGetter dimGetter) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
   this->hasOutput = hasOutput;
@@ -242,6 +244,10 @@
   this->loopStack.reserve(numLoops);
   this->loopSeqStack.reserve(numLoops);
 
+  // TODO: dim or level?
+  this->dependentDimMap.assign(
+      numTensors, std::vector<std::vector<std::pair<unsigned, unsigned>>>());
+
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
     const Value t = tensors[tid];
@@ -283,6 +289,12 @@
     coordinatesBuffers[tid].assign(lvlRank, Value());
     sliceOffsets[tid].assign(lvlRank, Value());
     sliceStrides[tid].assign(lvlRank, Value());
+
+    dependentDimMap[tid].assign(lvlRank,
+                                std::vector<std::pair<unsigned, unsigned>>());
+    if (dimGetter)
+      for (Level i = 0; i < lvlRank; i++)
+        dependentDimMap[tid][i] = dimGetter(tid, i);
   }
 
   // Construct the inverse of the `topSort` from the sparsifier.
@@ -997,8 +1009,8 @@
   }
 }
 
-void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
-                                      MutableArrayRef<Value> reduc) {
+void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
+                                MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
   builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
@@ -1082,7 +1094,7 @@
   assert(loopInfo.tids.size() == loopInfo.lvls.size());
   SmallVector<Value> red;
   if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
-    exitCoIterationLoop(rewriter, loc, reduc);
+    exitWhileLoop(rewriter, loc, reduc);
   } else {
     exitForLoop(rewriter, loc, reduc);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1471,10 +1471,11 @@
   SmallVector<TensorId> tids;
   SmallVector<Level> lvls;
   env.merger().foreachTensorLoopId(
-      env.lat(l0).bits, [&](TensorLoopId b, TensorId tid,
-                            std::optional<Level> lvl, DimLevelType dlt) {
+      env.lat(l0).bits,
+      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+          DimLevelType dlt, bool isSlice) {
         assert(env.merger().loop(b) == idx);
-        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+        if (!isSlice && (isDenseDLT(dlt) || isUndefDLT(dlt))) {
          needsUniv = true;
        } else {
          // sparse/singleton levels.
@@ -1491,7 +1492,8 @@
   unsigned lsize = env.set(lts).size();
   for (unsigned i = 1; i < lsize; i++) {
     const LatPointId li = env.set(lts)[i];
-    if (!env.merger().hasAnySparse(env.lat(li).simple))
+    if (!env.merger().hasAnySparse(env.lat(li).simple) &&
+        !env.merger().hasAnySlice(env.lat(li).simple))
       return true;
   }
 }
@@ -1546,30 +1548,38 @@
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
+  // Converts the set bits to lists of [tid, dim] pairs.
   env.merger().foreachTensorLoopId(all, [&, ldx](TensorLoopId b, TensorId tid,
-                                                 std::optional<Level> lvl,
-                                                 DimLevelType dlt) {
+                                                 std::optional<Level> dim,
+                                                 DimLevelType dlt,
+                                                 bool isSlice) {
+    if (isSlice) {
+      // We need to coiterate using the slice unconditionally.
+      tids.push_back(tid);
+      lvls.push_back(*dim);
+      numloopCond++;
+      return;
+    }
     if (simple.test(b)) {
       if (isUndefDLT(dlt)) {
-        // An undefined dlt in the lattices, we probably mean to
-        // iterate based on the level of output tensor. E.g., this
-        // could be a synthetic tensor (for invariants and sparse
-        // output tensor).
-        // out[i][j] = invariant; or a broadcast
+        // An undefined dlt in the lattices means we probably intend to
+        // iterate based on the dim of the output tensor.
+        // E.g., this could be a synthetic tensor (for invariants and
+        // sparse output tensor). out[i][j] = invariant; or a broadcast
         // out[i][j] = in[i] (j is undef for input)
-        tid = outTid;
-        lvl = outLvl;
-        // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-        if (!lvl)
+        tid = env.merger().getOutTensorID();
+        dim = env.merger().getLvl(tid, ldx);
+        // Skips an invalid dim (e.g., when this is a zero-ranked tensor).
+        if (!dim)
           return;
       }
       hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
       tids.push_back(tid);
-      lvls.push_back(*lvl);
+      lvls.push_back(*dim);
       numloopCond++;
     } else if (isDenseDLT(dlt)) {
       tids.push_back(tid);
-      lvls.push_back(*lvl);
+      lvls.push_back(*dim);
     } else {
       assert(isUndefDLT(dlt));
       linalg::GenericOp op = env.op();
@@ -1582,29 +1592,28 @@
       if (!stt.hasEncoding())
         return;
 
+      auto enc = stt.getEncoding();
       ArrayRef<AffineExpr> affines =
           op.getMatchingIndexingMap(operand).getResults();
-      const Level lvlRank = stt.getLvlRank();
-      assert(affines.size() == static_cast<size_t>(lvlRank));
-      for (Level l = 0; l < lvlRank; l++) {
-        // FIXME: `toOrigDim` is deprecated.
-        AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
-        // Skip simple affine expression and non-dense levels (which
-        // have their own filter loop).
-        if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
+      assert(affines.size() == stt.getLvlRank());
+      for (unsigned l = 0, e = stt.getLvlRank(); l < e; l++) {
+        AffineExpr exp = affines[toOrigDim(enc, l)];
+        // Skip simple affine expressions and non-dense dimensions (which
+        // have their own filter loops).
+        if (exp.isa<AffineDimExpr>() || !isDenseDLT(enc.getDimLevelType()[l]))
           continue;
 
        // Constant affine expressions are handled in genLoop.
        if (!exp.isa<AffineConstantExpr>()) {
-          bool isAtLoop = false;
-          if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
+          bool atLevel = false;
+          if (isInvariantAffine(env, exp, ldx, atLevel) && atLevel) {
            // If the compound affine is invariant and we are right at the
            // level, we need to generate the address according to the
            // affine expression. This is also the best place we can do it
            // to avoid putting it inside inner loops.
            // NOTE: It assumes that the levels of the input tensor are
-            // initialized in order (and it is also currently guaranteed by
-            // computeIterationGraph), another more admissible approach
+            // initialized in order (and it is also currently guaranteed
+            // by computeIterationGraph), another more admissible approach
             // might be accepting out-of-order access between consecutive
             // dense levels.
             affineTids.push_back(tid);
@@ -1631,23 +1640,25 @@
 }
 
 /// Starts a single loop in current sequence.
-static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
-                            LatPointId li, bool needsUniv) {
-  // The set of tensors + lvls to generate loops on
+static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
+                                              OpBuilder &builder, unsigned at,
+                                              unsigned li, bool needsUniv) {
+  // The set of tensors + dims to generate loops on.
   SmallVector<TensorId> tids, affineTids;
   SmallVector<Level> lvls, affineLvls;
   // The set of dense tensors with non-trivial affine expression that just
   // becomes invariant and the address shall now be generated at the current
   // level.
   SmallVector<AffineExpr> affines;
-  bool isFor = translateBitsToTidLvlPairs(
+  bool isSingleCond = translateBitsToTidLvlPairs(
       env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines);
 
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(env, builder, at, needsUniv, tids, lvls, isFor);
-  Location loc = env.op().getLoc();
-  for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) {
-    env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp);
+  Operation *loop =
+      genLoop(env, builder, at, needsUniv, tids, lvls, isSingleCond);
+  for (auto [tid, dim, exp] : llvm::zip(affineTids, affineLvls, affines)) {
+    env.emitter().genDenseAffineAddress(builder, env.op().getLoc(), tid, dim,
+                                        exp);
   }
 
   // Until now, we have entered every <tid, dim> pair in {cond, extra,
@@ -1660,7 +1671,7 @@
     genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
   }
 
-  return loop;
+  return std::make_pair(loop, isSingleCond);
 }
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
@@ -1723,20 +1734,19 @@
   for (unsigned i = 0; i < lsize; i++) {
     // Start a loop.
     const LatPointId li = env.set(lts)[i];
-    Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
+    auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
 
     // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
     Value redInput = env.getReduc();
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
-    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       const LatPointId lj = env.set(lts)[j];
       const ExprId ej = env.lat(lj).exp;
       if (li == lj || env.merger().latGT(li, lj)) {
         // Recurse into body of each branch.
-        if (isWhile) {
+        if (!isSingleCond) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
           genStmt(env, rewriter, ej, at + 1);
           endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
@@ -1855,6 +1865,12 @@
     if (!isAdmissible)
       return failure(); // inadmissible expression, reject
 
+    // Sort the related loop array so that the loops appear in the same
+    // order as they do in the topological order.
+    // TODO: since we only handle affine addition for slice-based codegen,
+    // and addition is associative, the evaluation order does not matter.
+    // However, to support multiplication, the loop indices should match
+    // the evaluation order of the affine expression AST.
     for (OpOperand &t : env.op()->getOpOperands()) {
       unsigned rank = env.op().getMatchingIndexingMap(&t).getNumResults();
       for (unsigned i = 0; i < rank; i++) {
@@ -1867,6 +1883,9 @@
     // Recursively generates code if admissible.
     env.startEmit();
     genBuffers(env, rewriter);
+    // TODO: Constant affine expressions should be handled differently when
+    // using slice-based codegen; it does not matter now because we already
+    // reject constant expressions at an earlier stage.
     genInitConstantDenseAddress(env, rewriter);
     genStmt(env, rewriter, env.getExprId(), 0);
     genResult(env, rewriter);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -224,7 +224,8 @@
       ldxToDependentSlice(numTensors, std::vector<std::optional<unsigned>>(
                                           numLoops, std::nullopt)),
       sliceToRelatedldx(numTensors, std::vector<std::vector<unsigned>>(
-                                        numLoops, std::vector<unsigned>())) {}
+                                        numLoops, std::vector<unsigned>())),
+      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
 
 //===----------------------------------------------------------------------===//
 // Lattice methods.
@@ -361,7 +362,7 @@
   }
 
   BitVector simple(latPoints[p0].bits);
-  bool reset = isSingleton && hasAnySparse(simple);
+  bool reset = isSingleton && (hasAnySparse(simple) || hasAnySlice(simple));
   const TensorLoopId be = simple.size();
   TensorLoopId offset = 0; // relative to the end
   if (!reset)
@@ -378,7 +379,7 @@
   // keep the rightmost bit (which could possibly be a synthetic tensor).
   for (TensorLoopId b = be - 1 - offset, i = 0; i < be;
        b = b == 0 ? be - 1 : b - 1, i++) {
-    if (simple[b]) {
+    if (simple[b] && !isSliceDim(b)) {
       const auto dlt = getDimLevelType(b);
       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
         if (reset)
@@ -406,7 +407,7 @@
 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
   BitVector tmp(latPoints[j].bits);
   tmp ^= latPoints[i].bits;
-  return !hasAnySparse(tmp);
+  return !hasAnySparse(tmp) && !hasAnySlice(tmp);
 }
 
 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
@@ -554,6 +555,13 @@
   return false;
 }
 
+bool Merger::hasAnySlice(const BitVector &bits) const {
+  for (unsigned b = 0, be = bits.size(); b < be; b++)
+    if (bits[b] && isSliceDim(b))
+      return true;
+  return false;
+}
+
 #ifndef NDEBUG
 
 //===----------------------------------------------------------------------===//
@@ -766,9 +774,10 @@
       const TensorId t = tensor(b);
       const LoopId i = loop(b);
       const auto dlt = lvlTypes[t][i];
-      llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
-      if (ldxToDependentSlice[t][i])
-        llvm::dbgs() << "_D_" << *ldxToDependentSlice[t][i];
+      if (isSliceDim(b))
+        llvm::dbgs() << " DEP_" << t << "_" << i;
+      else
+        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
     }
   }
 }
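
For reference, a minimal standalone sketch (not part of the patch, and not the MLIR API) of the slice substitution the new Merger::foreachTensorLoopId performs. It assumes the Merger's usual bit encoding b = loop * numTensors + tensor; the `dep` table stands in for ldxToDependentSlice and all values are made up for illustration.

// sketch_foreach_slice.cpp -- illustrative model only.
#include <cstdio>
#include <optional>
#include <vector>

int main() {
  const unsigned numTensors = 2, numLoops = 2;

  // Stand-in for ldxToDependentSlice[t][l]: the level of tensor `t` that
  // must be sliced to drive loop `l` (nullopt: not slice-driven).
  std::vector<std::vector<std::optional<unsigned>>> dep(
      numTensors, std::vector<std::optional<unsigned>>(numLoops));
  dep[0][1] = 0; // e.g. in[d0 + d1]: loop d1 slices level 0 of tensor 0.

  auto tensor = [&](unsigned b) { return b % numTensors; };
  auto loop = [&](unsigned b) { return b / numTensors; };
  auto isSliceDim = [&](unsigned b) {
    return dep[tensor(b)][loop(b)].has_value();
  };

  // Two set bits: (t=0, l=0) is a regular loop; (t=0, l=1) is slice-driven.
  std::vector<bool> bits(numTensors * numLoops, false);
  bits[0 * numTensors + 0] = true;
  bits[1 * numTensors + 0] = true;

  for (unsigned b = 0; b < bits.size(); b++) {
    if (!bits[b])
      continue;
    unsigned t = tensor(b), l = loop(b);
    // In the Merger, a slice-driven bit has no level of its own
    // (getLvl(b) is nullopt); model that by assigning `lvl` only for
    // regular bits, then substituting the dependent level as the patch does.
    std::optional<unsigned> lvl;
    bool isSlice = false;
    if (isSliceDim(b)) {
      isSlice = true;
      lvl = dep[t][l];
    } else {
      lvl = l; // simplification: pretend level == loop index here
    }
    std::printf("b=%u: t=%u l=%u lvl=%u slice=%d\n", b, t, l, *lvl,
                (int)isSlice);
  }
  return 0;
}

Because a slice-driven bit always arrives at the callback with a concrete dependent level and isSlice == true, translateBitsToTidLvlPairs can push the [tid, dim] pair and count a loop condition unconditionally, before the regular dense/undef handling.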
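Similarly, a hedged sketch of the DependentDimGetter wiring set up in CodegenEnv: the loops related to a [tid, dim] carrying a compound affine index (here d0 + d1) are mapped through the new loopBounds table (getLoopDefiningDim) to the [tid, dim] pairs that define those loops. The concrete values below are hypothetical.

// sketch_dim_getter.cpp -- illustrative model only.
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Stand-in for Merger::loopBounds: loop index -> defining [tid, dim].
  std::vector<std::pair<unsigned, unsigned>> loopBounds = {{0, 0}, {1, 0}};
  // Stand-in for getRelatedLoops(t, d): the loops appearing in the affine
  // index (d0 + d1) on some sparse tensor level.
  std::vector<unsigned> relatedLoops = {0, 1};

  // Mirror of the lambda passed to LoopEmitter::initialize.
  std::vector<std::pair<unsigned, unsigned>> ret;
  ret.reserve(relatedLoops.size());
  for (unsigned l : relatedLoops)
    ret.push_back(loopBounds[l]); // getLoopDefiningDim(l)

  for (auto [tid, dim] : ret)
    std::printf("depends on [tid=%u, dim=%u]\n", tid, dim);
  return 0;
}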