diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -399,11 +399,17 @@
   /// to sparse level-type.
   bool hasAnySparse(const BitVector &bits) const;
 
+  /// Returns true if `bits` contains a dependent index reduction condition on
+  /// sparse levels.
+  bool hasSparseIdxReduction(const BitVector &bits) const;
+
   /// Gets the level-type of the `t`th tensor on the `i`th loop.
   DimLevelType getDimLevelType(TensorId t, LoopId i) const {
     assert(t < numTensors && i < numLoops);
     return lvlTypes[t][i];
   }
+
+  /// Gets the level-type of the `TensorLoopId`.
   DimLevelType getDimLevelType(TensorLoopId b) const {
     return getDimLevelType(tensor(b), loop(b));
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,7 +99,6 @@
     topSort.reserve(capacity);
   }
 
-  ArrayRef<LoopId> getTopSort() const { return topSort; };
   ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
   ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
   ArrayRef<LoopId> getCurrentLoopStack() const;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -28,6 +28,23 @@
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
+/// Sorts the elements of `target` into the order in which they appear in the
+/// `order` array.
+static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+                                  ArrayRef<LoopId> order) {
+  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
+    assert(l != r);
+    int idxL = -1, idxR = -1;
+    for (int i = 0, e = order.size(); i < e; i++) {
+      if (order[i] == l)
+        idxL = i;
+      if (order[i] == r)
+        idxR = i;
+    }
+    assert(idxL >= 0 && idxR >= 0);
+    return idxL < idxR;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation environment constructor and general methods
 //===----------------------------------------------------------------------===//
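[Illustration, not part of the patch: a minimal standalone re-statement of what the helper above computes, assuming `LoopId` is an unsigned index type as in the sparse tensor utils; values are made up.]

#include <algorithm>
#include <cassert>
#include <vector>

using LoopId = unsigned; // assumption: LoopId is an unsigned index type

// Sort `target` so its elements appear in the same relative order as they
// appear in `order` (same logic as sortArrayBasedOnOrder above).
static void sortByOrder(std::vector<LoopId> &target,
                        const std::vector<LoopId> &order) {
  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
    int idxL = -1, idxR = -1;
    for (int i = 0, e = static_cast<int>(order.size()); i < e; i++) {
      if (order[i] == l)
        idxL = i;
      if (order[i] == r)
        idxR = i;
    }
    assert(idxL >= 0 && idxR >= 0);
    return idxL < idxR;
  });
}

int main() {
  std::vector<LoopId> deps = {2, 0, 3};             // hypothetical dependences
  const std::vector<LoopId> topSort = {3, 1, 0, 2}; // hypothetical loop order
  sortByOrder(deps, topSort);
  assert((deps == std::vector<LoopId>{3, 0, 2})); // follows topSort's order
  return 0;
}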
@@ -57,15 +74,42 @@
     insChain = sparseOut->get();
     latticeMerger.setHasSparseOut(true);
   }
+
+  // Sort the related loop arrays such that they are in the same order as
+  // they appear in the topological order (topSort).
+  // TODO: since we only handle affine addition for slice-based codegen, and
+  // addition is associative, the order in which we evaluate the expression
+  // does not matter. However, to support multiplication, the order of the
+  // loop indices should match the evaluation order of the affine expression
+  // AST.
+
   // Initialize loop emitter.
-  SmallVector<Value> tensors;
-  for (OpOperand &t : linalgOp->getOpOperands())
+  SmallVector<Value> tensors; // input tensors passed to loop emitter
+  for (OpOperand &t : linalgOp->getOpOperands()) {
     tensors.push_back(t.get());
-  loopEmitter.initialize(tensors,
-                         StringAttr::get(linalgOp.getContext(),
-                                         linalg::GenericOp::getOperationName()),
-                         /*hasOutput=*/true,
-                         /*isSparseOut=*/sparseOut != nullptr, topSort);
+    Level rank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
+    for (Level lvl = 0; lvl < rank; lvl++) {
+      sortArrayBasedOnOrder(
+          latticeMerger.getDependentLoops(t.getOperandNumber(), lvl), topSort);
+    }
+  }
+
+  loopEmitter.initialize(
+      tensors,
+      StringAttr::get(linalgOp.getContext(),
+                      linalg::GenericOp::getOperationName()),
+      /*hasOutput=*/true,
+      /*isSparseOut=*/sparseOut != nullptr, topSort,
+      // TODO: compute the map and pass it to the loop emitter directly,
+      // instead of passing in a callback.
+      [this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
+        // Translates a list of loop indices to a list of [tid, lvl] pairs.
+        std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
+        std::vector<std::pair<TensorId, Level>> ret;
+        ret.reserve(rLoops.size());
+        for (LoopId l : rLoops)
+          ret.emplace_back(this->merger().getLoopDefiningLvl(l));
+        return ret;
+      });
 }
 
 std::optional<Operation *> CodegenEnv::genLoopBoundary(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -76,6 +76,14 @@
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
 
+  // Map from [tid, lvl] to a list of dependent [tid, lvl] pairs used to
+  // reduce a non-trivial affine index expression on sparse tensors.
+  // E.g., for the affine index (d0 + d1), it depends on the two [tid, lvl]
+  // pairs that define d0 and d1 (for affine expression reduction).
+  // If the list is empty, there is no affine expression on the input
+  // [tid, lvl].
+  using DependentLvlGetter =
+      function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
 
   LoopEmitter() = default;
 
@@ -89,11 +97,13 @@
   /// to `LoopId`.
   void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
                   bool hasOutput = false, bool isSparseOut = false,
-                  ArrayRef<LoopId> topSort = {});
+                  ArrayRef<LoopId> topSort = {},
+                  DependentLvlGetter getter = nullptr);
 
   explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
                        bool hasOutput = false, bool isSparseOut = false,
-                       ArrayRef<LoopId> topSort = {});
+                       ArrayRef<LoopId> topSort = {},
+                       DependentLvlGetter getter = nullptr);
 
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
@@ -295,8 +305,8 @@
                    MutableArrayRef<Value> reduc);
 
   /// Exits a while loop, returns the reduction results.
-  void exitCoIterationLoop(OpBuilder &builder, Location loc,
-                           MutableArrayRef<Value> reduc);
+  void exitWhileLoop(OpBuilder &builder, Location loc,
+                     MutableArrayRef<Value> reduc);
 
   //
   // View-based-reshape methods.
@@ -380,6 +390,15 @@
   std::vector<std::vector<Value>> sliceOffsets;
   std::vector<std::vector<Value>> sliceStrides;
 
+  // Map from [tid, lvl] to a list of dependent [tid, lvl] pairs.
+  // See the comments for `DependentLvlGetter`.
+  std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
+      dependentLvlMap;
+
+  //
+  // View-based-reshape related fields and methods.
+  //
+
   /// Collapse reassociations related to a specific tensor.
   // TODO: support expand.
   std::vector<ArrayAttr> collapseReassoc;
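[Illustration, not part of the patch: a hedged sketch of the `DependentLvlGetter` contract, with made-up tensor ids and levels. For an access like in[d0 + d1], the level carrying the compound index reports the [tid, lvl] pairs that define d0 and d1; trivial indices report an empty list.]

#include <utility>
#include <vector>

using TensorId = unsigned; // assumption: unsigned ids, as in the sparse utils
using Level = unsigned;

// Hypothetical getter: level 0 of tensor 0 carries a non-trivial affine
// index (d0 + d1), so it reports the defining [tid, lvl] pairs (made up
// here); every other [tid, lvl] reports an empty list, meaning "no affine
// expression to reduce".
static std::vector<std::pair<TensorId, Level>> exampleGetter(TensorId t,
                                                             Level lvl) {
  if (t == 0 && lvl == 0)
    return {{1, 0}, {2, 1}}; // d0 defined by [1, 0], d1 by [2, 1]
  return {};
}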
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -208,12 +208,14 @@
 }
 
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
-                         bool isSparseOut, ArrayRef<LoopId> topSort) {
-  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
+                         bool isSparseOut, ArrayRef<LoopId> topSort,
+                         DependentLvlGetter dimGetter) {
+  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
-                             bool isSparseOut, ArrayRef<LoopId> topSort) {
+                             bool isSparseOut, ArrayRef<LoopId> topSort,
+                             DependentLvlGetter dimGetter) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
   this->hasOutput = hasOutput;
@@ -242,6 +244,9 @@
   this->loopStack.reserve(numLoops);
   this->loopSeqStack.reserve(numLoops);
 
+  this->dependentLvlMap.assign(
+      numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
     const Value t = tensors[tid];
@@ -283,6 +288,12 @@
     coordinatesBuffers[tid].assign(lvlRank, Value());
     sliceOffsets[tid].assign(lvlRank, Value());
     sliceStrides[tid].assign(lvlRank, Value());
+
+    dependentLvlMap[tid].assign(lvlRank,
+                                std::vector<std::pair<TensorId, Level>>());
+    if (dimGetter)
+      for (Level l = 0; l < lvlRank; l++)
+        dependentLvlMap[tid][l] = dimGetter(tid, l);
   }
 
   // Construct the inverse of the `topSort` from the sparsifier.
@@ -997,8 +1008,8 @@
   }
 }
 
-void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
-                                      MutableArrayRef<Value> reduc) {
+void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
+                                MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
   builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
@@ -1082,7 +1093,7 @@
   assert(loopInfo.tids.size() == loopInfo.lvls.size());
   SmallVector<Value> red;
   if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
-    exitCoIterationLoop(rewriter, loc, reduc);
+    exitWhileLoop(rewriter, loc, reduc);
   } else {
     exitForLoop(rewriter, loc, reduc);
   }
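[Illustration, not part of the patch: a self-contained sketch of the `dependentLvlMap` layout filled in by `initialize()` above — indexed as [tid][lvl], holding whatever the getter reports. All shapes and the getter are made up.]

#include <utility>
#include <vector>

using TensorId = unsigned;
using Level = unsigned;
using DepPair = std::pair<TensorId, Level>;

int main() {
  // Made-up shapes: two tensors, of level-rank 2 and 1 respectively.
  const unsigned numTensors = 2;
  const std::vector<Level> lvlRank = {2, 1};
  auto dimGetter = [](TensorId t, Level l) -> std::vector<DepPair> {
    if (t == 0 && l == 1)
      return {{1, 0}}; // hypothetical dependence
    return {};
  };

  // Same shape as the dependentLvlMap field: [tid][lvl] -> dependences.
  std::vector<std::vector<std::vector<DepPair>>> dependentLvlMap;
  dependentLvlMap.assign(numTensors, std::vector<std::vector<DepPair>>());
  for (TensorId tid = 0; tid < numTensors; tid++) {
    dependentLvlMap[tid].assign(lvlRank[tid], std::vector<DepPair>());
    for (Level l = 0; l < lvlRank[tid]; l++)
      dependentLvlMap[tid][l] = dimGetter(tid, l);
  }
  // Now dependentLvlMap[0][1] == {{1, 0}} and all other entries are empty.
  return 0;
}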
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -593,23 +593,6 @@
   }
 }
 
-/// Makes target array's elements appear in the same order as the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
-                                  ArrayRef<LoopId> order) {
-  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
-    assert(l != r);
-    int idxL = -1, idxR = -1;
-    for (int i = 0, e = order.size(); i < e; i++) {
-      if (order[i] == l)
-        idxL = i;
-      if (order[i] == r)
-        idxR = i;
-    }
-    assert(idxL >= 0 && idxR >= 0);
-    return idxL < idxR;
-  });
-}
-
 static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
                                           OpOperand *skip, SortMask mask,
                                           std::vector<std::vector<bool>> &adjM,
@@ -1484,9 +1467,10 @@
   SmallVector<Level> lvls;
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
-                                           DimLevelType dlt, bool) {
+                                           DimLevelType dlt, bool isIdxReduc) {
     assert(env.merger().loop(b) == idx);
-    if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+    // FIXME: Dense index reduction can reuse the universal index as well.
+    if (!isIdxReduc && (isDenseDLT(dlt) || isUndefDLT(dlt))) {
       needsUniv = true;
     } else {
       // sparse/singleton levels.
@@ -1503,7 +1487,8 @@
     unsigned lsize = env.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       const LatPointId li = env.set(lts)[i];
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
+      if (!env.merger().hasAnySparse(env.lat(li).simple) &&
+          !env.merger().hasSparseIdxReduction(env.lat(li).simple))
         return true;
     }
   }
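[Illustration, not part of the patch: the gist of the tightened early-exit above, with stand-in flags for the two Merger queries.]

// A lattice point whose `simple` bits contain no plain sparse level may
// still need co-iteration when one of its levels carries a sparse index
// reduction; `anySparse` and `sparseIdxReduc` stand in for the results of
// Merger::hasAnySparse and Merger::hasSparseIdxReduction.
static bool canSkipRemainingLatPoints(bool anySparse, bool sparseIdxReduc) {
  // Before this patch the test was simply `!anySparse`.
  return !anySparse && !sparseIdxReduc;
}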
@@ -1557,75 +1542,82 @@
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
-                                                std::optional<Level> lvl,
-                                                DimLevelType dlt, bool) {
-    if (simple.test(b)) {
-      if (isUndefDLT(dlt)) {
-        // An undefined dlt in the lattices, we probably mean to
-        // iterate based on the level of output tensor. E.g., this
-        // could be a synthetic tensor (for invariants and sparse
-        // output tensor).
-        // out[i][j] = invariant; or a broadcast
-        // out[i][j] = in[i] (j is undef for input)
-        tid = outTid;
-        lvl = outLvl;
-        // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-        if (!lvl)
-          return;
-      }
-      hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
-      tids.push_back(tid);
-      lvls.push_back(*lvl);
-      numloopCond++;
-    } else if (isDenseDLT(dlt)) {
-      tids.push_back(tid);
-      lvls.push_back(*lvl);
-    } else {
-      assert(isUndefDLT(dlt));
-      linalg::GenericOp op = env.op();
-      if (tid >= op.getNumDpsInputs())
-        // We only handle affine expression on input tensors (for now).
-        return;
-      OpOperand *operand = &op->getOpOperand(tid);
-      const auto stt = getSparseTensorType(operand->get());
-      // Non-annotated dense tensors requires no special handling.
-      if (!stt.hasEncoding())
-        return;
-
-      ArrayRef<AffineExpr> affines =
-          op.getMatchingIndexingMap(operand).getResults();
-      const Level lvlRank = stt.getLvlRank();
-      assert(affines.size() == static_cast<size_t>(lvlRank));
-      for (Level l = 0; l < lvlRank; l++) {
-        // FIXME: `toOrigDim` is deprecated.
-        AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
-        // Skip simple affine expression and non-dense levels (which
-        // have their own filter loop).
-        if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
-          continue;
-        // Constant affine expression are handled in genLoop
-        if (!exp.isa<AffineConstantExpr>()) {
-          bool isAtLoop = false;
-          if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
-            // If the compound affine is invariant and we are right at the
-            // level. We need to generate the address according to the
-            // affine expression. This is also the best place we can do it
-            // to avoid putting it inside inner loops.
-            // NOTE: It assumes that the levels of the input tensor are
-            // initialized in order (and it is also currently guaranteed by
-            // computeIterationGraph), another more admissible approach
-            // might be accepting out-of-order access between consecutive
-            // dense levels.
-            affineTids.push_back(tid);
-            affineLvls.push_back(l);
-            exps.push_back(exp);
-          }
-        }
-      }
-    }
-  });
+  env.merger().foreachTensorLoopId(
+      li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+                   DimLevelType dlt, bool isIdxReduc) {
+        if (simple.test(b)) {
+          if (isIdxReduc) {
+            tids.push_back(tid);
+            lvls.push_back(*lvl);
+            numloopCond++;
+            return;
+          }
+          if (isUndefDLT(dlt)) {
+            // An undefined dlt in the lattice means that we probably want to
+            // iterate based on the level of the output tensor. E.g., this
+            // could be a synthetic tensor (for invariants and sparse
+            // output tensor).
+            // out[i][j] = invariant; or a broadcast
+            // out[i][j] = in[i] (j is undef for input)
+            tid = outTid;
+            lvl = outLvl;
+            // Skips an invalid lvl (e.g., when this is a zero-ranked tensor).
+            if (!lvl)
+              return;
+          }
+          hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
+          tids.push_back(tid);
+          lvls.push_back(*lvl);
+          numloopCond++;
+        } else if (isDenseDLT(dlt)) {
+          tids.push_back(tid);
+          lvls.push_back(*lvl);
+        } else {
+          assert(isUndefDLT(dlt));
+          linalg::GenericOp op = env.op();
+          if (tid >= op.getNumDpsInputs())
+            // We only handle affine expressions on input tensors (for now).
+            return;
+          OpOperand *operand = &op->getOpOperand(tid);
+          const auto stt = getSparseTensorType(operand->get());
+          // Non-annotated dense tensors require no special handling.
+          if (!stt.hasEncoding())
+            return;
+
+          ArrayRef<AffineExpr> affines =
+              op.getMatchingIndexingMap(operand).getResults();
+          const Level lvlRank = stt.getLvlRank();
+          assert(affines.size() == static_cast<size_t>(lvlRank));
+          for (Level l = 0; l < lvlRank; l++) {
+            // FIXME: `toOrigDim` is deprecated.
+            AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
+            // Skip simple affine expressions and non-dense levels (which
+            // have their own filter loop).
+            if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
+              continue;
+
+            // Constant affine expressions are handled in genLoop.
+            if (!exp.isa<AffineConstantExpr>()) {
+              bool isAtLoop = false;
+              if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
+                // If the compound affine expression is invariant and we are
+                // right at the level, we need to generate the address
+                // according to the affine expression. This is also the best
+                // place to do it, to avoid putting it inside inner loops.
+                // NOTE: It assumes that the levels of the input tensor are
+                // initialized in order (and it is also currently guaranteed
+                // by computeIterationGraph); another more admissible approach
+                // might be accepting out-of-order access between consecutive
+                // dense levels.
+                affineTids.push_back(tid);
+                affineLvls.push_back(l);
+                exps.push_back(exp);
+              }
+            }
+          }
+        }
+      });
 
   if (isDenseDLT(env.dlt(outTid, ldx))) {
     // Note that we generate dense indices of the output tensor
@@ -1642,8 +1634,9 @@
 }
 
 /// Starts a single loop in current sequence.
-static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
-                            LatPointId li, bool needsUniv) {
+static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
+                                              OpBuilder &builder, unsigned at,
+                                              unsigned li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on
   SmallVector<TensorId> tids, affineTids;
   SmallVector<Level> lvls, affineLvls;
@@ -1651,11 +1644,12 @@
   // becomes invariant and the address shall now be generated at the current
   // level.
   SmallVector<AffineExpr> affines;
-  bool isFor = translateBitsToTidLvlPairs(
+  bool isSingleCond = translateBitsToTidLvlPairs(
       env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines);
 
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(env, builder, at, needsUniv, tids, lvls, isFor);
+  Operation *loop =
+      genLoop(env, builder, at, needsUniv, tids, lvls, isSingleCond);
   Location loc = env.op().getLoc();
   for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) {
     env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp);
@@ -1671,7 +1665,7 @@
     genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
   }
 
-  return loop;
+  return std::make_pair(loop, isSingleCond);
 }
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
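[Illustration, not part of the patch: a sketch of why startLoop now returns the flag instead of letting the caller test the loop kind, under the assumption that with index reductions even a single-condition loop may be emitted as an scf.while. Names below are stand-ins.]

struct StartedLoop {
  void *loop;        // stand-in for Operation *
  bool isSingleCond; // exactly one loop condition was emitted
};

// "Needs per-branch if statements" no longer coincides with "is a
// while-loop", so the caller keys off the flag instead.
static bool needsPerBranchIfs(const StartedLoop &l) {
  return !l.isSingleCond; // previously: isa<scf::WhileOp>(l.loop)
}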
@@ -1734,20 +1728,19 @@
   for (unsigned i = 0; i < lsize; i++) {
     // Start a loop.
     const LatPointId li = env.set(lts)[i];
-    Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
+    auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
     Value redInput = env.getReduc();
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
-    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       const LatPointId lj = env.set(lts)[j];
       const ExprId ej = env.lat(lj).exp;
       if (li == lj || env.merger().latGT(li, lj)) {
         // Recurse into body of each branch.
-        if (isWhile) {
+        if (!isSingleCond) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
           genStmt(env, rewriter, ej, at + 1);
           endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
@@ -1866,18 +1859,12 @@
     if (!isAdmissible)
       return failure(); // inadmissible expression, reject
 
-    for (OpOperand &t : env.op()->getOpOperands()) {
-      Level rank = env.op().getMatchingIndexingMap(&t).getNumResults();
-      for (Level lvl = 0; lvl < rank; lvl++) {
-        sortArrayBasedOnOrder(
-            env.merger().getDependentLoops(t.getOperandNumber(), lvl),
-            env.getTopSort());
-      }
-    }
-
     // Recursively generates code if admissible.
     env.startEmit();
     genBuffers(env, rewriter);
+    // TODO: Constant affine expressions should be handled differently when
+    // using slice-based codegen; it does not matter now because we already
+    // reject constant expressions at an earlier stage.
     genInitConstantDenseAddress(env, rewriter);
     genStmt(env, rewriter, env.getExprId(), 0);
     genResult(env, rewriter);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -362,7 +362,8 @@
   }
 
   BitVector simple(latPoints[p0].bits);
-  bool reset = isSingleton && hasAnySparse(simple);
+  bool reset =
+      isSingleton && (hasAnySparse(simple) || hasSparseIdxReduction(simple));
   const TensorLoopId be = simple.size();
   TensorLoopId offset = 0; // relative to the end
   if (!reset)
@@ -379,7 +380,9 @@
   // keep the rightmost bit (which could possibly be a synthetic tensor).
   for (TensorLoopId b = be - 1 - offset, i = 0; i < be;
        b = b == 0 ? be - 1 : b - 1, i++) {
-    if (simple[b]) {
+    // FIXME: this needs a better name; a slice on a dense level has the
+    // locate property as well and should be handled correctly!
+    if (simple[b] && !isLvlWithNonTrivialIdxExp(b)) {
       const auto dlt = getDimLevelType(b);
       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
         if (reset)
@@ -407,7 +410,7 @@
 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
   BitVector tmp(latPoints[j].bits);
   tmp ^= latPoints[i].bits;
-  return !hasAnySparse(tmp);
+  return !hasAnySparse(tmp) && !hasSparseIdxReduction(tmp);
 }
 
 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
@@ -555,6 +558,14 @@
   return false;
 }
 
+bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
+  // TODO: return false on dense levels.
+  for (unsigned b = 0, be = bits.size(); b < be; b++)
+    if (bits[b] && isLvlWithNonTrivialIdxExp(b))
+      return true;
+  return false;
+}
+
 #ifndef NDEBUG
 
 //===----------------------------------------------------------------------===//
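[Illustration, not part of the patch: a standalone restatement of the new Merger::hasSparseIdxReduction, assuming a plain bit vector and a made-up per-bit predicate in place of isLvlWithNonTrivialIdxExp.]

#include <cstddef>
#include <vector>

// True iff any set bit denotes a level whose coordinate is computed by a
// non-trivial affine index expression (the index-reduction case).
static bool hasSparseIdxReduction(const std::vector<bool> &bits,
                                  const std::vector<bool> &nonTrivialIdxExp) {
  for (std::size_t b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && nonTrivialIdxExp[b])
      return true;
  return false;
}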