diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -453,6 +453,9 @@
     return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
   }

+  /// Forwards the "tree iterator" of the fully-reduced sliced levels rooted
+  /// at `lvl` by `fcnt` positions.
+  void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
+                                       TensorId tid, Level lvl, Value fcnt);
+
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"

 using namespace mlir;
 using namespace mlir::sparse_tensor;
@@ -41,6 +42,13 @@
 // File local helper functions.
 //===----------------------------------------------------------------------===//

+// For index reduction loops, since the tensor is sliced into non-contiguous
+// fragments, we need a tuple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
+// specifies the range of the fragment, and pPtr specifies the index of the
+// corresponding fragment in the child level (i.e., a pointer into the sliced
+// position array).
+static constexpr unsigned kSliceIterWidth = 3;
+
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
@@ -123,6 +131,27 @@
   return ifOp.getResult(0);
 }

+// Dumps the content of an index-typed memref for debugging purposes; only
+// referenced from temporarily inserted debugging code.
+LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
+                                                  Location loc, Value memref) {
+  memref = builder.create<memref::CastOp>(
+      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
+  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
+                 ValueRange{memref}, EmitCInterface::On);
+}
+
+// Loads the position pointer (the `idx` metadata slot at index 1).
+static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
+  return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
+}
+
+// Updates the position pointer (the `idx` metadata slot at index 1).
+static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
+                              Value pPtr) {
+  builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
+}
+
+// Loads the start of the next level's position pointer range for the tuple
+// at `pPtr`.
+static Value loadSliceNextPosPtrStart(OpBuilder &builder, Location loc,
+                                      Value sPosBuf, Value pPtr) {
+  return genIndexLoad(builder, loc, sPosBuf, ADDI(pPtr, C_IDX(4)));
+}
+
 std::pair<Value, Value>
 LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc,
                                     Value crd, TensorId tid, Level lvl) {
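Taken together with kSliceIterWidth above and the "{memSize, idx} at head" note in the buffer-allocation hunk further down, the helpers just added imply a per-level position buffer of the shape [memSize, idx | pLo0, pHi0, pPtr0 | pLo1, pHi1, pPtr1 | ...]. The following minimal host-side model makes the indexing concrete; all names here are illustrative and not part of the patch.

#include <cstdint>
#include <vector>

static constexpr unsigned kTupleWidth = 3; // mirrors kSliceIterWidth
static constexpr unsigned kHeadMeta = 2;   // the {memSize, idx} header slots

// buf = [memSize, idx | pLo0, pHi0, pPtr0 | pLo1, pHi1, pPtr1 | ...]
int64_t loadMemSize(const std::vector<int64_t> &buf) { return buf[0]; }
// Counterparts of loadSlicePosPtr / updateSlicePosPtr above.
int64_t loadPosPtr(const std::vector<int64_t> &buf) { return buf[1]; }
void updatePosPtr(std::vector<int64_t> &buf, int64_t p) { buf[1] = p; }

// Start of the i-th [pLo, pHi, pPtr] tuple within the buffer.
unsigned tupleBegin(unsigned i) { return kHeadMeta + i * kTupleWidth; }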
@@ -571,18 +600,6 @@
       // If this is a unresolved-slice-driven loop, pops out the slice.
       assert(sliceStack[tid].back().slicedOnLvl == lvl);
       sliceStack[tid].pop_back();
-    } else {
-      if (!isDenseDLT(lvlTypes[tid][lvl])) {
-        // Else this is a resolved-slice, and advance posit similar to TACO.
-        Value c1 = C_IDX(1), c2 = C_IDX(2);
-        // pIdx += 2, we finished the current lvl, advance the pointer index of
-        // the previous level by two to skip the [pLo, pHi] for current level.
-        Value sPtrBuf = slicePosBuffer[tid][lvl].back();
-        Value curP = genIndexLoad(builder, loc, sPtrBuf, c1);
-        // TODO: we could probably use an SSA value for it.
-        Value nexP = ADDI(curP, c2);
-        builder.create<memref::StoreOp>(loc, nexP, sPtrBuf, c1);
-      }
     }
   }
   loopSeqStack.pop_back();
@@ -1297,11 +1314,11 @@
   // Pushes sliced levels to build correct LoopInfo.
   bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
   SliceInfo &info = sliceStack[tid].back();
+  // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
+  sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
   if (unReduc) {
-    // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
-    sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/false);
-    // Update the slice information as we enter the new loop.
     assert(*info.slicedOnLvl == lvl);
+    // Update the slice information as we enter the new loop.
     info.minCrd = info.offset = iv;
     info.isNonEmpty = constantI1(builder, loc, true);
     levelReducedDep[tid][lvl]++;
@@ -1331,27 +1348,31 @@
   }
 }

-void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
-                              MutableArrayRef<Value> reduc) {
+void LoopEmitter::exitForLoop(RewriterBase &builder, Location loc,
+                              MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) {
-    SliceInfo &info = sliceStack[tid].back();
-    assert(isDenseDLT(lvlTypes[tid][lvl]));
-    assert(*info.slicedOnLvl == lvl && !reduced);
-    (void)reduced;
-    // Resets slices pointers as the resolved slices are invalidated after we
-    // moves forward to the next slice.
-    invalidateSliceIterIdx(rewriter, loc, tid, lvl);
-    info.minCrd = info.offset = info.isNonEmpty = Value();
-    levelReducedDep[tid][lvl]--;
+    if (!reduced) {
+      SliceInfo &info = sliceStack[tid].back();
+      assert(isDenseDLT(lvlTypes[tid][lvl]));
+      assert(*info.slicedOnLvl == lvl);
+      // Resets the slice pointers, as the resolved slices are invalidated
+      // after we move forward to the next slice.
+      invalidateSliceIterIdx(builder, loc, tid, lvl);
+      info.minCrd = info.offset = info.isNonEmpty = Value();
+      levelReducedDep[tid][lvl]--;
+    } else {
+      forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, C_IDX(1));
+    }
   }
   if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
     if (!reduc.empty()) {
       assert(reduc.size() == forOp.getNumResults());
-      rewriter.create<scf::YieldOp>(loc, reduc);
+      builder.create<scf::YieldOp>(loc, reduc);
     }
     // Exit the loop.
-    rewriter.setInsertionPointAfter(forOp);
+    builder.setInsertionPointAfter(forOp);
     // In-place update reduction variables.
     for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
       reduc[i] = forOp.getResult(i);
@@ -1387,22 +1408,22 @@
     assert(numUsers == 1);
 #endif // NDEBUG

-    rewriter.setInsertionPointAfter(redExp);
-    auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
+    builder.setInsertionPointAfter(redExp);
+    auto redOp = builder.create<scf::ReduceOp>(loc, curVal);
     // Attach to the reduction op.
     Block *redBlock = &redOp.getRegion().getBlocks().front();
-    rewriter.setInsertionPointToEnd(redBlock);
-    Operation *newRed = rewriter.clone(*redExp);
+    builder.setInsertionPointToEnd(redBlock);
+    Operation *newRed = builder.clone(*redExp);
     // Replaces arguments of the reduction expression by using the block
     // arguments from scf.reduce.
-    rewriter.updateRootInPlace(
+    builder.updateRootInPlace(
         newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
     // Erases the out-dated reduction expression.
-    rewriter.eraseOp(redExp);
-    rewriter.setInsertionPointToEnd(redBlock);
-    rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
+    builder.eraseOp(redExp);
+    builder.setInsertionPointToEnd(redBlock);
+    builder.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
   }
-  rewriter.setInsertionPointAfter(parOp);
+  builder.setInsertionPointAfter(parOp);
   // In-place update reduction variables.
   for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
     reduc[i] = parOp.getResult(i);
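With the width-3 tuples, the inline pointer advance deleted from the co-iteration exit earlier in this patch (pIdx += 2) is subsumed by forwardsReducedSliceLevelTreeIt, which exitForLoop now calls for fully-reduced slices; unlike the old code, the new helper must also forward the child levels via the stored pPtr. Below is a host-side sketch of just the pointer arithmetic, assuming the buffer layout modeled earlier (illustrative names only).

#include <cstdint>
#include <vector>

static constexpr unsigned kTupleWidth = 3; // [pLo, pHi, pPtr]

// Before this patch: skip one [pLo, pHi] pair at the current level.
void advancePair(std::vector<int64_t> &posBuf) { posBuf[1] += 2; }

// After this patch (the fcnt == 1 case used by exitForLoop): skip whole
// [pLo, pHi, pPtr] tuples; the child levels are then forwarded separately
// by chasing the stored pPtr entries.
void advanceTuples(std::vector<int64_t> &posBuf, int64_t fcnt) {
  posBuf[1] += fcnt * kTupleWidth;
}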
@@ -1421,6 +1442,54 @@
   }
 }

+void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
+                                                  Location loc, TensorId tid,
+                                                  Level rootLvl, Value fcnt) {
+  auto stt = getSparseTensorType(tensors[tid]);
+
+  // Finds a [rootLvl, leafLvl) range in which all the levels are fully
+  // reduced (but not resolved). Since we forward an iterator at a higher
+  // level of the tree, the subtree rooted at it needs to be pruned.
+  Level leafLvl = rootLvl + 1;
+  while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty()) {
+    assert(depFullyReduced(tid, leafLvl));
+    leafLvl++;
+  }
+
+  Level curLvl = rootLvl + 1;
+  // Prunes all dense subtrees.
+  while (curLvl < leafLvl && isDenseDLT(lvlTypes[tid][curLvl])) {
+    fcnt = MULI(sliceSizes[tid][curLvl].back(), fcnt);
+    curLvl++;
+  }
+
+  Value nxPosPtr = nullptr;
+  if (curLvl < leafLvl) {
+    assert(!isDenseDLT(lvlTypes[tid][curLvl]));
+    Value sPosBuf = slicePosBuffer[tid][curLvl].back();
+    Value fPosPtr = MULI(fcnt, C_IDX(kSliceIterWidth));     // forward ptr
+    Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
+    Value cPosPtr = ADDI(fPosPtr, pPosPtr);                 // current ptr
+    updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
+    // Loads the position pointer start for the next level.
+    nxPosPtr = genIndexLoad(builder, loc, sPosBuf, ADDI(cPosPtr, C_IDX(1)));
+    curLvl++;
+  }
+
+  // TODO: This is not always needed (it is only needed when the level can be
+  // skipped without traversing the child levels).
+  for (; curLvl < leafLvl; curLvl++) {
+    assert(nxPosPtr);
+    if (!isDenseDLT(lvlTypes[tid][curLvl])) {
+      nxPosPtr = MULI(nxPosPtr, C_IDX(kSliceIterWidth));
+      Value sPosBuf = slicePosBuffer[tid][curLvl].back();
+      updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
+      nxPosPtr = genIndexLoad(builder, loc, sPosBuf, ADDI(nxPosPtr, C_IDX(1)));
+    }
+  }
+}
+
 void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                 MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
@@ -1448,17 +1517,25 @@
       continue;
     }

+    Value forwarded = nullptr;
     if (loopInfo.trivialTidLvls.empty() &&
         loopInfo.sliceDrivenInfo.size() == 1) {
       // Forwards the position iterator.
       operands.push_back(ADDI(posits[tid][lvl], one));
+      forwarded = constantI1(builder, loc, true);
     } else {
       const Value pos = posits[tid][lvl];
       const Value nxPos = ADDI(posits[tid][lvl], one);
-      Value cmp = CMPI(eq, coords[tid][lvl], iv);
-      operands.push_back(SELECT(cmp, nxPos, pos));
+      forwarded = CMPI(eq, coords[tid][lvl], iv);
+      operands.push_back(SELECT(forwarded, nxPos, pos));
+    }
+    // Forwards the sliced level tree iterator when the position is advanced.
+    {
+      OpBuilder::InsertionGuard guard(builder);
+      auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
+                                            /*else=*/false);
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
     }
-
     // The coordinate is invalid now.
     coords[tid][lvl] = nullptr;
@@ -1656,7 +1733,6 @@
 }

 // Generates a loop nest that traverse all the unresolved levels in between.
-// TODO: it can only handle all compressed tensors.
 //
 // for(int i = 0; i < slicePos.size(); i+=2) {
 //   loopLo = slicePos[i];
@@ -1683,6 +1759,15 @@
   OpBuilder::InsertPoint ip;
   SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
   scf::ForOp outerMost = nullptr; // the outtermost loop.
+
+  // Wraps the body builder and inserts an extra counting instruction at the
+  // end of it.
+  auto wrapped = [bodyBuilder](OpBuilder &builder, Location loc, Value iv,
+                               MutableArrayRef<Value> reduc) {
+    bodyBuilder(builder, loc, iv, reduc.drop_back());
+    // Increments the counter.
+    reduc.back() = ADDI(reduc.back(), C_IDX(1));
+  };
+
   if (firstResLvl.has_value()) {
     // Overwrite position when the first level is fully resolved.
     pos = posits[firstResLvl->first][firstResLvl->second];
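The wrapped lambda above is a small adapter: it hides the trailing iteration argument (a segment counter, pushed in the next hunk) from the user-provided callback and bumps that counter once per iteration. The same idiom in standalone form, with plain integers instead of mlir::Value and illustrative names only:

#include <algorithm>
#include <functional>
#include <vector>

using BodyBuilder = std::function<void(int iv, std::vector<int> &reduc)>;

// Wraps `body` so that one extra counter can be threaded through the loop
// without the callback ever seeing it.
BodyBuilder wrapWithCounter(BodyBuilder body) {
  return [body = std::move(body)](int iv, std::vector<int> &reduc) {
    // The user callback sees all but the last iteration argument.
    std::vector<int> userReduc(reduc.begin(), reduc.end() - 1);
    body(iv, userReduc);
    std::copy(userReduc.begin(), userReduc.end(), reduc.begin());
    reduc.back() += 1; // bump the hidden counter
  };
}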
@@ -1692,13 +1777,18 @@
     Level firstLvl = *frontSlice.slicedOnLvl;
     if (!lvlFullyResolved(tid, firstLvl)) {
       if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
+        // An extra counter that tracks how many segments there are in the
+        // child compressed level.
+        innerArgs.push_back(c0);
+        // Overrides the user-provided builder.
+        bodyBuilder = wrapped;
         unsigned depth = frontSlice.depth - 1;
         Value offset = frontSlice.offset;
         Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
         Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
         outerMost = builder.create<scf::ForOp>(
-            loc, c2, mSz, c2, innerArgs,
-            [this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
+            loc, c2, mSz, C_IDX(kSliceIterWidth), innerArgs,
+            [this, c1, c2, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
              &innerArgs](OpBuilder &builder, Location loc, Value iv,
                          ValueRange iterArgs) {
               // generate traversal for each level.
@@ -1715,6 +1805,9 @@
                         innerArgs.assign(reduc.begin(), reduc.end());
                       })
                       .second;
+              // Marks down the pPtr for the next level.
+              builder.create<memref::StoreOp>(loc, itArgs.back(), sPtrBuf,
+                                              ADDI(iv, c2).getResult());
               YIELD(itArgs);
             });
       } else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
@@ -1855,8 +1948,7 @@
                                        TensorId tid, Level lvl) {
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
   unsigned depth = levelReducedDep[tid][lvl];
-  Value size = sliceSizes[tid][lvl][depth];
-  // Dense slice begin is trivial
+  Value size = sliceSizes[tid][lvl][depth]; // Dense slice begin is trivial.
   if (isDenseDLT(lvlTypes[tid][lvl])) {
     sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
                                  depth + 1);
@@ -1902,9 +1994,8 @@
   ValueRange result = genUnResolvedSliceTreeTraverse(
       builder, loc, tid, unResSlices, firstResLvl, reduc,
-      [this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc,
-                                        Value iv,
-                                        MutableArrayRef<Value> reduc) {
+      [this, c1, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
+                                    MutableArrayRef<Value> reduc) {
         Value &nonEmpty = reduc[0];
         Value &minCrd = reduc[1];
         Value &curMemSz = reduc[2];
@@ -1942,8 +2033,8 @@
         builder.create<memref::StoreOp>(loc, sPLo, sPtrBuf, curMemSz);
         Value nxtMemSize = ADDI(curMemSz, c1);
         builder.create<memref::StoreOp>(loc, sPHi, sPtrBuf, nxtMemSize);
-        // curMemSize += 2
-        curMemSz = ADDI(curMemSz, c2);
+        // curMemSize += kSliceIterWidth
+        curMemSz = ADDI(curMemSz, C_IDX(kSliceIterWidth));
       });

   Value isNonEmpty = result[0];
@@ -1962,6 +2053,7 @@
   Value c1 = C_IDX(1), c2 = C_IDX(2);

   if (depFullyReduced(tid, lvl)) {
+
     // Do not need to prepare for slice driven loop on dense level after it is
     // fully reduced.
     if (isDenseDLT(lvlTypes[tid][lvl]))
       return;
@@ -1969,8 +2061,9 @@
     // If constraints on the tensor is fully resolved. We do not need to
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
     Value pLoPtr =
-        genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), c1);
+        loadSlicePosPtr(builder, loc, slicePosBuffer[tid][lvl].back());
     pLoPtr = ADDI(pLoPtr, c2);
     Value pHiPtr = ADDI(pLoPtr, c1);
     posits[tid][lvl] =
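For a fully-reduced slice, the loads above fall back to direct [pLo, pHi] indexing: the pointer kept in slot 1 appears to be an offset into the tuple area, so the two {memSize, idx} header slots are skipped before dereferencing. As a host-side model, with illustrative names and the same layout assumption as before:

#include <cstdint>
#include <vector>

// buf = [memSize, idx | pLo, pHi, pPtr | ...]; buf[1] selects a tuple.
int64_t loadPLo(const std::vector<int64_t> &buf) {
  int64_t pLoPtr = buf[1] + 2; // loadSlicePosPtr, then skip the header
  return buf[pLoPtr];
}

int64_t loadPHi(const std::vector<int64_t> &buf) {
  int64_t pLoPtr = buf[1] + 2;
  return buf[pLoPtr + 1]; // pHi sits right after pLo
}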
@@ -2022,10 +2115,10 @@
     Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
     bufSize = MULI(bufSize, sz);
   }
-  // For a pair of [pLo, pHi]. Note that we can not compress pHi because
+  // For a tuple of [pLo, pHi, pPtr]. Note that we cannot compress pHi because
   // slice creates segments in the index buffer so that the pHi for the
   // current level is no longer the pLo for the next level.
-  bufSize = MULI(bufSize, c2);
+  bufSize = MULI(bufSize, C_IDX(kSliceIterWidth));
   // Additional two metadata {memSize, idx} at head.
   bufSize = ADDI(bufSize, c2);
   llvm::for_each(
@@ -2049,8 +2142,7 @@
                                         TensorId tid, Level lvl) {
   for (unsigned i = 0; i <= lvl; i++) {
     if (!isDenseDLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
-      builder.create<memref::StoreOp>(loc, C_IDX(0),
-                                      slicePosBuffer[tid][i].back(), C_IDX(1));
+      updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
     }
   }
 }
@@ -2103,7 +2195,7 @@
   YIELD(reduc);

   // else /*minCrd == offset*/ {
-  //   for (i = 0; i < slicePos.size(); i+=2) {
+  //   for (i = 0; i < slicePos.size(); i+=kSliceIterWidth) {
   //     if (crd[pos[slicePos[i]]] == minCrd) {
   //       slicePos[i]++;
   //     }
@@ -2119,7 +2211,7 @@
   reduc[1] = constantI1(builder, loc, false); // isNonEmpty
   auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
   auto forOp = scf::buildLoopNest(
-      builder, loc, pSt, mSz, c2, loopArgs,
+      builder, loc, pSt, mSz, C_IDX(kSliceIterWidth), loopArgs,
       [this, tid, lvl, c1, sPtrBuf, &info](OpBuilder &builder, Location loc,
                                            ValueRange ivs,
                                            ValueRange iterArgs) -> scf::ValueVector {
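Moving from two to three slots per fragment also changes the allocation: the fragment count is bounded by the product of the enclosing slice sizes, each fragment now takes kSliceIterWidth slots, and the {memSize, idx} header adds two more. A host-side model of the sizing computation in the hunk above, illustrative only:

#include <cstdint>
#include <vector>

static constexpr unsigned kTupleWidth = 3; // mirrors kSliceIterWidth

int64_t slicePosBufSize(const std::vector<int64_t> &enclosingSliceSizes) {
  int64_t fragments = 1;
  for (int64_t sz : enclosingSliceSizes) // upper bound on the fragment count
    fragments *= sz;
  return fragments * kTupleWidth + 2; // tuples plus the {memSize, idx} header
}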
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void \
-// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \
 // DEFINE: FileCheck %s
 //
 // RUN: %{compile} | %{run}
@@ -152,21 +152,20 @@
       : tensor<6x6xi32>, vector<6x6xi32>
     vector.print %v : vector<6x6xi32>

-    // FIXME: DCSR still wrong
     //
     // Should be the same as dense output
-    // C_HECK:      ( ( 0, 0, -1, -6, -1, 6 ),
-    // C_HECK-SAME:   ( -1, 0, 1, 0, 1, 0 ),
-    // C_HECK-SAME:   ( 0, -1, 1, 0, 0, 0 ),
-    // C_HECK-SAME:   ( -1, 0, 0, 0, 0, 0 ),
-    // C_HECK-SAME:   ( 0, 0, 3, 6, -3, -6 ),
-    // C_HECK-SAME:   ( 2, -1, 3, 0, -3, 0 ) )
+    // CHECK:      ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME:   ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME:   ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME:   ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME:   ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME:   ( 2, -1, 3, 0, -3, 0 ) )
     //
-    // %all_sparse_DCSR = sparse_tensor.convert %2
-    //   : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
-    // %v2 = vector.transfer_read %all_sparse_DCSR[%c0, %c0], %i0
-    //   : tensor<6x6xi32>, vector<6x6xi32>
-    // vector.print %v2 : vector<6x6xi32>
+    %all_sparse_DCSR = sparse_tensor.convert %2
+      : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
+    %v2 = vector.transfer_read %all_sparse_DCSR[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v2 : vector<6x6xi32>

     //
     // Should be the same as dense output
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void \
-// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \
 // DEFINE: FileCheck %s
 //
 // RUN: %{compile} | %{run}
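Both integration tests now link %mlir_runner_utils alongside %mlir_c_runner_utils. Presumably this is what lets the printMemrefInd call emitted by the dumpIndexMemRef helper resolve at runtime, since the printMemref* entry points live in the non-C runner utils library; it also keeps the helper usable while debugging the slice position buffers. A hypothetical call site, not part of the patch's final code paths:

// Assumes the MLIR headers already included by LoopEmitter.cpp;
// dumpIndexMemRef is the static helper added above, and sPosBuf is any
// index-typed slice position buffer.
static void inspectSlicePosBuffer(OpBuilder &builder, Location loc,
                                  Value sPosBuf) {
  dumpIndexMemRef(builder, loc, sPosBuf); // lowers to a printMemrefInd call
}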