diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -285,6 +285,9 @@
   /// Returns true if any set bit corresponds to sparse dimension level type.
   bool hasAnySparse(const BitVector &bits) const;
 
+  /// Returns true if any set bit corresponds to a slice-driven dimension.
+  bool hasAnySlice(const BitVector &bits) const;
+
   /// Gets the dimension level type of the `t`th tensor on `i`th loop.
   DimLevelType getDimLevelType(unsigned t, unsigned i) const {
     assert(t < numTensors && i < numLoops);
@@ -321,6 +324,9 @@
     loopIdxToDim[t][i] = dim;
     assert(dim < numLoops);
     dimToLoopIdx[t][dim] = i;
+    // Maybe we should favor constant dimensions when there are multiple
+    // choices.
+    loopBounds[i] = std::make_pair(t, dim);
   }
 
   void setLoopDependentSliceDim(unsigned l, unsigned t, unsigned dim) {
@@ -343,15 +349,36 @@
     return sliceToRelatedldx[t][d];
   }
 
+  // Return the defining [tid, dim] for the loop.
+  std::pair getLoopDefiningDim(unsigned l) const {
+    return loopBounds[l];
+  }
+
+  /// Returns true if bit `b` is a loop-index dependent slice dimension.
+  bool isSliceDim(unsigned b) const {
+    return ldxToDependentSlice[tensor(b)][index(b)].has_value();
+  }
+
   // Iterates the bits of a lattice, for each set bit, converts it into the
   // corresponding tensor dimension and invokes the callback.
   void foreachTidDimPairInBits(
       const BitVector &bits,
       function_ref dim,
-                   DimLevelType dlt)>
+                   DimLevelType dlt, bool isSlice)>
           cb) {
-    for (unsigned b : bits.set_bits())
-      cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
+    for (unsigned b : bits.set_bits()) {
+      unsigned t = tensor(b), l = index(b);
+      bool isSlice = false;
+      std::optional dim = getDimNum(b);
+      if (isSliceDim(b)) {
+        // This must be an undefined dim.
+        assert(!dim.has_value());
+        isSlice = true;
+        // Slice the tid along the dependent dim to iterate current loop.
+ dim = ldxToDependentSlice[t][l]; + } + cb(b, t, dim, getDimLevelType(b), isSlice); + } } // Has sparse output tensor setter. @@ -418,6 +445,8 @@ std::vector>> ldxToDependentSlice; // Map from the dependent slice [tid, dim] pair to a list of loop idx. std::vector>> sliceToRelatedldx; + // Map from loop index to the [tid, dim] pair that defines the loop boundary. + std::vector> loopBounds; llvm::SmallVector tensorExps; llvm::SmallVector latPoints; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -61,11 +61,22 @@ SmallVector tensors; for (OpOperand &t : linalgOp->getOpOperands()) tensors.push_back(t.get()); - loopEmitter.initialize(tensors, - StringAttr::get(linalgOp.getContext(), - linalg::GenericOp::getOperationName()), - /*hasOutput=*/true, - /*isSparseOut=*/sparseOut != nullptr, topSort); + loopEmitter.initialize( + tensors, + StringAttr::get(linalgOp.getContext(), + linalg::GenericOp::getOperationName()), + /*hasOutput=*/true, + /*isSparseOut=*/sparseOut != nullptr, topSort, + [this](unsigned t, + unsigned d) -> std::vector> { + // Translates from a list of loop index to a list of [tid, dim] pair. + std::vector rLoops = this->merger().getRelatedLoops(t, d); + std::vector> ret; + ret.reserve(rLoops.size()); + for (unsigned l : rLoops) + ret.push_back(this->merger().getLoopDefiningDim(l)); + return ret; + }); } std::optional CodegenEnv::genLoopBoundary( diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -51,6 +51,15 @@ /// initializing the loop emitter (e.g., to fill a dense output with zeros). 
using OutputUpdater = function_ref; + // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression + // index on sparse tensors. + // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines + // d0 and d1 (for affine expression reduction). + // If the list is empty, it means that there is no affine expression on the + // input [tid, dim]. + using DependentDimGetter = + function_ref>(unsigned, + unsigned)>; LoopEmitter() = default; @@ -65,11 +74,13 @@ /// loop emitter. void initialize(ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false, bool isSparseOut = false, - ArrayRef topSort = {}); + ArrayRef topSort = {}, + DependentDimGetter getter = nullptr); explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false, bool isSparseOut = false, - ArrayRef topSort = {}); + ArrayRef topSort = {}, + DependentDimGetter getter = nullptr); /// Starts a loop emitting session by generating all the buffers needed to /// iterate tensors. @@ -242,8 +253,8 @@ MutableArrayRef reduc); /// Exits a while loop, returns the reduction results. - void exitCoIterationLoop(OpBuilder &builder, Location loc, - MutableArrayRef reduc); + void exitWhileLoop(OpBuilder &builder, Location loc, + MutableArrayRef reduc); /// A optional string attribute that should be attached to the loop /// generated by loop emitter, it might help following passes to identify @@ -273,8 +284,13 @@ std::vector> sliceOffsets; std::vector> sliceStrides; - /// Loop Stack, stores the information of all the nested loops that are - /// alive. + // Map from [tid, dim] to a list of dependent [tid, dim]. + // See comments for `DependentDimGetter`. + std::vector>>> + dependentDimMap; + + // Loop Stack, stores the information of all the nested loops that are + // alive. 
std::vector loopStack; /// Loop Sequence Stack, stores the unversial index for the current loop diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -139,13 +139,15 @@ } LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput, - bool isSparseOut, ArrayRef topSort) { - initialize(tensors, loopTag, hasOutput, isSparseOut, topSort); + bool isSparseOut, ArrayRef topSort, + DependentDimGetter getter) { + initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, getter); } void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag, bool hasOutput, bool isSparseOut, - ArrayRef topSort) { + ArrayRef topSort, + DependentDimGetter dimGetter) { // First initializes fields. this->loopTag = loopTag; this->hasOutput = hasOutput; @@ -164,8 +166,12 @@ this->valBuffer.assign(tensors.size(), nullptr); this->loopStack.reserve(topSort.size()); this->sparsiferLoopLvlMap.assign(topSort.size(), 0); + this->dependentDimMap.assign( + tensors.size(), + std::vector>>()); for (size_t tid = 0, e = tensors.size(); tid < e; tid++) { + auto t = tensors[tid]; // a scalar or 0-dimension tensors if (isZeroRankedTensorOrScalar(t.getType())) @@ -191,6 +197,12 @@ highs[tid].assign(rank, Value()); ptrBuffer[tid].assign(rank, Value()); idxBuffer[tid].assign(rank, Value()); + + dependentDimMap[tid].assign(rank, + std::vector>()); + if (dimGetter) + for (unsigned i = 0; i < rank; i++) + dependentDimMap[tid][i] = dimGetter(tid, i); } // FIXME: This map should be maintained outside loop emitter. 
@@ -797,8 +809,8 @@
   }
 }
 
-void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
-                                      MutableArrayRef reduc) {
+void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
+                                MutableArrayRef reduc) {
   const LoopLevelInfo &loopInfo = loopStack.back();
   auto whileOp = llvm::cast(loopInfo.loop);
   builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
@@ -860,7 +872,7 @@
   assert(loopInfo.tids.size() == loopInfo.dims.size());
   SmallVector red;
   if (llvm::isa(loopInfo.loop)) {
-    exitCoIterationLoop(rewriter, loc, reduc);
+    exitWhileLoop(rewriter, loc, reduc);
   } else {
     exitForLoop(rewriter, loc, reduc);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1375,10 +1375,11 @@
   SmallVector tids;
   SmallVector dims;
   env.merger().foreachTidDimPairInBits(
-      env.lat(l0).bits, [&](unsigned b, unsigned tid,
-                            std::optional dim, DimLevelType dlt) {
+      env.lat(l0).bits,
+      [&](unsigned b, unsigned tid, std::optional dim,
+          DimLevelType dlt, bool isSlice) {
         assert(env.merger().index(b) == idx);
-        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+        if (!isSlice && (isDenseDLT(dlt) || isUndefDLT(dlt))) {
           needsUniv = true;
         } else {
           // sparse/singleton dim levels.
@@ -1395,7 +1396,8 @@
     unsigned lsize = env.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = env.set(lts)[i];
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
+      if (!env.merger().hasAnySparse(env.lat(li).simple) &&
+          !env.merger().hasAnySlice(env.lat(li).simple))
         return true;
     }
   }
@@ -1448,14 +1450,20 @@
   // Converts bits to array + dim pair
   env.merger().foreachTidDimPairInBits(
       all, [&, idx](unsigned b, unsigned tid, std::optional dim,
-                    DimLevelType dlt) {
+                    DimLevelType dlt, bool isSlice) {
+        if (isSlice) {
+          // We need to coiterate using slice unconditionally.
+ tids.push_back(tid); + dims.push_back(*dim); + numloopCond++; + return; + } if (simple.test(b)) { if (isUndefDLT(dlt)) { // An undefined dlt in the lattices, we probably mean to iterate // based on the dim of output tensor. - // E.g., this could be a synthetic tensor (for invariants and sparse - // output tensor). - // out[i][j] = invariant; or a broadcast + // E.g., this could be a synthetic tensor (for invariants and + // sparse output tensor). out[i][j] = invariant; or a broadcast // out[i][j] = in[i] (j is undef for input) tid = env.merger().getOutTensorID(); dim = env.merger().getDimNum(tid, idx); @@ -1481,16 +1489,16 @@ if (!stt.hasEncoding()) return; + auto enc = stt.getEncoding(); ArrayRef affines = op.getMatchingIndexingMap(operand).getResults(); - const Level lvlRank = stt.getLvlRank(); - assert(affines.size() == static_cast(lvlRank)); - for (Level l = 0; l < lvlRank; l++) { - // FIXME: `toOrigDim` is deprecated. - AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)]; - // Skip simple affine expression and non dense dimensions (which has - // it own filter loop). - if (exp.isa() || !stt.isDenseLvl(l)) + assert(affines.size() == stt.getLvlRank()); + for (unsigned l = 0, e = stt.getLvlRank(); l < e; l++) { + AffineExpr exp = affines[toOrigDim(enc, l)]; + // Skip simple affine expression and non dense dimensions (which + // has it own filter loop). + if (exp.isa() || + !isDenseDLT(enc.getDimLevelType()[l])) continue; // Constant affine expression are handled in genLoop @@ -1502,8 +1510,8 @@ // affine expression. This is also the best place we can do it // to avoid putting it inside inner loops. 
// NOTE: It assumes that the levels of the input tensor are - // initialized in order (and it is also currently guaranteed by - // computeIterationGraph), another more admissible approach + // initialized in order (and it is also currently guaranteed + // by computeIterationGraph), another more admissible approach // might be accepting out-of-order access between consecutive // dense levels. affineTids.push_back(tid); @@ -1531,8 +1539,9 @@ } /// Starts a single loop in current sequence. -static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at, - unsigned li, bool needsUniv) { +static std::pair startLoop(CodegenEnv &env, + OpBuilder &builder, unsigned at, + unsigned li, bool needsUniv) { // The set of tensors + dims to generate loops on SmallVector tids, dims; // The set of dense tensors with non-trivial affine expression that just @@ -1540,11 +1549,12 @@ // level. SmallVector affineTids, affineDims; SmallVector affines; - bool isFor = translateBitsToTidDimPairs( + bool isSingleCond = translateBitsToTidDimPairs( env, li, env.topSortAt(at), tids, dims, affineTids, affineDims, affines); // Emit the for/while-loop control. - Operation *loop = genLoop(env, builder, at, needsUniv, tids, dims, isFor); + Operation *loop = + genLoop(env, builder, at, needsUniv, tids, dims, isSingleCond); for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) { env.emitter().genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(), tid, dim, exp); @@ -1560,7 +1570,7 @@ genConstantDenseAddressFromLevel(env, builder, tid, dim + 1); } - return loop; + return std::make_pair(loop, isSingleCond); } /// Ends a single loop in current sequence. Returns new values for needsUniv. @@ -1625,20 +1635,19 @@ for (unsigned i = 0; i < lsize; i++) { // Start a loop. 
     unsigned li = env.set(lts)[i];
-    Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
+    auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
     Value redInput = env.getReduc();
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
-    bool isWhile = dyn_cast(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       unsigned lj = env.set(lts)[j];
       unsigned ej = env.lat(lj).exp;
       if (li == lj || env.merger().latGT(li, lj)) {
         // Recurse into body of each branch.
-        if (isWhile) {
+        if (!isSingleCond) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
           genStmt(env, rewriter, ej, at + 1);
           endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
@@ -1753,6 +1762,12 @@
   if (!isAdmissible)
     return failure(); // inadmissible expression, reject
 
+  // Sort the related loop array such that they are in the same order as they
+  // appear on the topoOrder.
+  // TODO: since we only handle affine addition for slice based codegen, and
+  // addition is associative, the order in which we evaluate the expression
+  // does not matter. However, to support multiplication, the order of the
+  // loop index should match the evaluation order of the affine expression AST.
   for (OpOperand &t : env.op()->getOpOperands()) {
     unsigned rank = env.op().getMatchingIndexingMap(&t).getNumResults();
     for (unsigned i = 0; i < rank; i++) {
@@ -1765,6 +1780,9 @@
   // Recursively generates code if admissible.
   env.startEmit();
   genBuffers(env, rewriter);
+  // TODO: Constant affine expression should be handled differently when using
+  // slice-based codegen, it does not matter now because we already reject the
+  // constant expression at an earlier stage.
genInitConstantDenseAddress(env, rewriter); genStmt(env, rewriter, env.getTensorExp(), 0); genResult(env, rewriter); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -218,7 +218,8 @@ ldxToDependentSlice(numTensors, std::vector>( numLoops, std::nullopt)), sliceToRelatedldx(numTensors, std::vector>( - numLoops, std::vector())) {} + numLoops, std::vector())), + loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} //===----------------------------------------------------------------------===// // Lattice methods. @@ -354,7 +355,7 @@ } BitVector simple = latPoints[p0].bits; - bool reset = isSingleton && hasAnySparse(simple); + bool reset = isSingleton && (hasAnySparse(simple) || hasAnySlice(simple)); unsigned be = simple.size(); unsigned offset = 0; // relative to the end if (!reset) @@ -371,7 +372,8 @@ // keep the rightmost bit (which could possibly be a synthetic tensor). for (unsigned b = be - 1 - offset, i = 0; i < be; b = b == 0 ? 
be - 1 : b - 1, i++) { - if (simple[b] && (!isCompressedDLT(getDimLevelType(b)) && + if (simple[b] && (!isSliceDim(b) && // keep the bit for slice-based loop + !isCompressedDLT(getDimLevelType(b)) && !isSingletonDLT(getDimLevelType(b)))) { if (reset) simple.reset(b); @@ -397,7 +399,7 @@ bool Merger::onlyDenseDiff(unsigned i, unsigned j) { BitVector tmp = latPoints[j].bits; tmp ^= latPoints[i].bits; - return !hasAnySparse(tmp); + return !hasAnySparse(tmp) && !hasAnySlice(tmp); } bool Merger::expContainsTensor(unsigned e, unsigned t) const { @@ -543,6 +545,13 @@ return false; } +bool Merger::hasAnySlice(const BitVector &bits) const { + for (unsigned b = 0, be = bits.size(); b < be; b++) + if (bits[b] && isSliceDim(b)) + return true; + return false; +} + #ifndef NDEBUG //===----------------------------------------------------------------------===// @@ -755,9 +764,10 @@ unsigned t = tensor(b); unsigned i = index(b); DimLevelType dlt = dimTypes[t][i]; - llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); - if (ldxToDependentSlice[t][i]) - llvm::dbgs() << "_D_" << *ldxToDependentSlice[t][i]; + if (isSliceDim(b)) + llvm::dbgs() << " DEP_" << t << "_" << i; + else + llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); } } }