diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -56,6 +56,13 @@
 /// for the corresponding `SmallVector` object.
 using LatSetId = unsigned;
 
+/// A pair of a tensor level and its corresponding DimLevelType.
+using LvlDLTPair = std::pair<Level, DimLevelType>;
+
+/// A pair of a loop id and its coefficient. E.g., for the affine expression
+/// `2 * d0` in an affine map, loop id = 0 and coefficient = 2.
+using LoopCoeffPair = std::pair<LoopId, unsigned>;
+
 /// Tensor expression. Represents an MLIR expression in tensor index notation.
 struct TensorExp final {
   enum class Kind;
@@ -509,22 +516,22 @@
 
   /// Establishes the two-way map that i <-> <t, lvl, dlt>.
   void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl,
-                                   DimLevelType dlt) {
+                                   DimLevelType dlt, unsigned coefficient) {
     assert(isValidLoopId(i) && isValidLevel(t, lvl));
-    assert(!loopToDependencies[i][t].has_value()); // must be the first def
-    loopToDependencies[i][t] = std::make_pair(lvl, dlt);
-    levelToDependentLoop[t][lvl].push_back(i);
+    assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
+    loopToUnresolvedLvls[i][t] = std::make_pair(lvl, dlt);
+    levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
   }
 
   /// Whether the loop has dependent slice.
   bool hasDependentLvl(LoopId i, TensorId t) {
     assert(isValidTensorId(t) && isValidLoopId(i));
-    return loopToDependencies[i][t].has_value();
+    return loopToUnresolvedLvls[i][t].has_value();
   }
 
   /// Returns the list of loop indices which appear in the non-trivial index
   /// expression on t_l, e.g., A[i+j] => {i, j}
-  std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
+  std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
     assert(isValidLevel(t, lvl));
     return levelToDependentLoop[t][lvl];
   }
@@ -541,7 +548,7 @@
     const TensorId t = tensor(b);
     const LoopId i = loop(b);
     assert(isValidTensorId(t) && isValidLoopId(i));
-    return loopToDependencies[i][t].has_value();
+    return loopToUnresolvedLvls[i][t].has_value();
   }
 
   /// Checks whether the TensorLoopId represents a sparse tensor level contains
@@ -556,12 +563,12 @@
 
   Level getLoopDependentLevel(TensorLoopId b) const {
     assert(isLvlWithNonTrivialIdxExp(b));
-    return loopToDependencies[loop(b)][tensor(b)]->first;
+    return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
   }
 
   DimLevelType getLoopDependentLevelType(TensorLoopId b) const {
     assert(isLvlWithNonTrivialIdxExp(b));
-    return loopToDependencies[loop(b)][tensor(b)]->second;
+    return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
   }
 
   /// Convenience getters to immediately access the stored nodes.
@@ -715,13 +722,13 @@
   /// It is currently only set for non-trivial index expressions.
   /// E.g., A[i+j] => i and j will have dependencies {A0, dlt(A0)} to indicate
   /// that i and j are used in the non-trivial index expression on A0.
-  std::vector<std::vector<std::optional<std::pair<Level, DimLevelType>>>>
-      loopToDependencies;
+  std::vector<std::vector<std::optional<LvlDLTPair>>> loopToUnresolvedLvls;
 
   /// The inverse map of ldxToDependencies from tensor level -> dependent loop
-  /// E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j}
-  /// to compute its indices.
-  std::vector<std::vector<std::vector<LoopId>>> levelToDependentLoop;
+  /// E.g., for A[2i+j], we have A0 => {(i, 2), (j, 1)}, to indicate that A0
+  /// uses both {i, j} to compute its indices, and that the coefficients on
+  /// loops i and j are 2 and 1 respectively.
+  std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
 
   /// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
   std::vector<std::pair<TensorId, Level>> loopBounds;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -96,6 +96,9 @@
            loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
     return loopEmitter.makeTensorLevel(t, l);
   }
+  TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
+    return makeTensorLevel(tlPair.first, tlPair.second);
+  }
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
     return loopEmitter.unpackTensorLevel(tl);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -29,16 +29,16 @@
 }
 
 /// Makes target array's elements sorted according to the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
                                   ArrayRef<LoopId> order) {
   std::sort(target.begin(), target.end(),
-            [&order](const LoopId &l, const LoopId &r) {
+            [&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
               assert(std::addressof(l) == std::addressof(r) || l != r);
               int idxL = -1, idxR = -1;
               for (int i = 0, e = order.size(); i < e; i++) {
-                if (order[i] == l)
+                if (order[i] == l.first)
                   idxL = i;
-                if (order[i] == r)
+                if (order[i] == r.first)
                   idxR = i;
               }
               assert(idxL >= 0 && idxR >= 0);
@@ -104,13 +104,17 @@
       /*isSparseOut=*/sparseOut != nullptr, topSort,
       // TODO: compute the map and pass it to loop emitter directly instead of
       // passing in a callback.
-      [this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
-        // Translates from a list of loop index to a list of [tid, dim] pair.
-        std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
-        std::vector<std::pair<TensorId, Level>> ret;
+      /*dependentLvlGetter=*/
+      [this](TensorId t,
+             Level lvl) -> std::vector<std::pair<TensorLevel, unsigned>> {
+        // Translates a list of loop indices to [tidlvl, coefficient] pairs.
+        std::vector<LoopCoeffPair> &rLoops = merger().getDependentLoops(t, lvl);
+        std::vector<std::pair<TensorLevel, unsigned>> ret;
         ret.reserve(rLoops.size());
-        for (LoopId l : rLoops)
-          ret.emplace_back(this->merger().getLoopDefiningLvl(l));
+        for (auto [loop, coeff] : rLoops) {
+          TensorLevel tl = makeTensorLevel(merger().getLoopDefiningLvl(loop));
+          ret.emplace_back(tl, coeff);
+        }
         return ret;
       });
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -84,18 +84,22 @@
   using SynTensorBoundSetter =
      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
 
-  // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
-  // index on sparse tensors.
-  // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
-  // d0 and d1 (for affine expression reduction).
+  // Map from [tid, lvl] to a list of dependent [tidlvl, coefficient] for
+  // subscript expressions on sparse tensors.
+  //
+  // E.g., the affine index (2 * d0 + d1) depends on the two tidlvls that
+  // define d0 and d1 (for affine expression reduction), with coefficients 2
+  // and 1 on d0 and d1 respectively.
   // If the list is empty, it means that there is no affine expression on the
-  // input [tid, dim].
+  // input [tid, lvl].
+  //
   // NOTE: The caller is responsible to ensure that the order of the returned
   // list to be consistent with the topological order of the iteration graph,
   // otherwise the loop emitter might reduce a wrong dependent index variable
   // when generating slice-driven loops.
   using DependentLvlGetter =
-      function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
+      function_ref<std::vector<std::pair<TensorLevel, unsigned>>(TensorId,
+                                                                 Level)>;
 
   LoopEmitter() = default;
 
@@ -335,9 +339,9 @@
     // Whether this is the tensor that has not yet been sliced.
     bool isInitialTensor() const { return !slicedOnLvl.has_value(); }
 
-    Value minCrd;     // the minimum coordinate of the slice.
-    Value offset;     // the offset of the current slice.
-    Value isNonEmpty; // whether the slice is empty.
+    Value minCrd;     // the minimum coordinate of the slice.
+    Value offset;     // the *absolute* offset of the current slice.
+    Value isNonEmpty; // whether the slice is non-empty.
     std::optional<Level> slicedOnLvl; // the level on which the slice is done
     unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
   };
@@ -645,10 +649,12 @@
   bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
 
   /// Generates code to get the next non-empty slices of tid on lvl.
-  void genSliceNextInduction(OpBuilder &builder, Location loc,
-                             const Operation *whileOp, TensorId tid, Level lvl,
-                             SmallVectorImpl<Value> &operands,
-                             unsigned &retIdx);
+  /// Returns a tuple of values for <NonEmpty, MinCrd, AbsOffset> (see
+  /// SliceInfo) respectively.
+  std::tuple<Value, Value, Value> genSliceNextInduction(OpBuilder &builder,
+                                                        Location loc,
+                                                        TensorId tid,
+                                                        Level lvl);
 
   /// A optional string attribute that should be attached to the loop
   /// generated by loop emitter, it might help following passes to identify
@@ -707,9 +713,9 @@
   std::vector<std::vector<Value>> sliceOffsets;
   std::vector<std::vector<Value>> sliceStrides;
 
-  // Map from [tid, level] to a list of dependent [tid, level].
-  // See comments for `DependentDimGetter`.
-  std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
+  // Map from [tid, level] to a list of dependent [tidlevel, coefficient].
+  // See comments for `DependentLvlGetter`.
+  std::vector<std::vector<std::vector<std::pair<TensorLevel, unsigned>>>>
       dependentLvlMap;
 
   // The cached position buffer for the slices, they serve the same purpose as
@@ -718,8 +724,9 @@
   // to avoid iteration from the beginning.
   std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
 
-  // The cached size for each slices.
-  std::vector<std::vector<std::vector<Value>>> sliceSizes;
+  // The (size, stride) of each conceptual slice used by index-reduction
+  // loops.
+  std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;
 
   // The number of reduced dependencies on a tensor level so far.
   std::vector<std::vector<unsigned>> levelReducedDep;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -264,14 +264,11 @@
 
 Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
-  Value crd = C_IDX(0);
   // A load on the coordinates array yields the coordinate.
   const Value mem = coordinatesBuffers[tid][lvl];
   /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
   const Value pos = posits[tid][lvl];
-  const Value off = genIndexLoad(builder, loc, mem, pos);
-  // Linearized the coordinates within the same collapse reassociation.
-  crd = ADDI(crd, off);
+  const Value crd = genIndexLoad(builder, loc, mem, pos);
   return crd;
 }
 
@@ -317,9 +314,10 @@
   // Index-reduction related fields.
   this->dependentLvlMap.assign(
-      numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
   this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
-  this->sliceSizes.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceMeta.assign(
+      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
   this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
   this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
 
@@ -367,10 +365,10 @@
 
     // Slice-driven loops related initialization.
     levelReducedDep[tid].assign(lvlRank, 0);
-    dependentLvlMap[tid].assign(lvlRank,
-                                std::vector<std::pair<TensorId, Level>>());
+    dependentLvlMap[tid].assign(
+        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
     slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
-    sliceSizes[tid].assign(lvlRank, std::vector<Value>());
+    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
     sliceStack[tid].emplace_back(/*minCrd=*/Value(), /*offset=*/Value(),
                                  /*isNonEmpty*/ Value(), std::nullopt, 0);
@@ -380,8 +378,8 @@
       unsigned depends = dependentLvlMap[tid][l].size();
       if (depends == 0)
         continue;
-      // We need `depends - 1` slices to fully the affine expression.
-      sliceSizes[tid][l].assign(depends - 1, nullptr);
+      sliceMeta[tid][l].assign(depends, std::make_pair(nullptr, 0));
+      // We need `depends - 1` slices to fully reduce the affine expression.
       slicePosBuffer[tid][l].assign(depends - 1, nullptr);
     }
   }
@@ -502,15 +500,20 @@
       Level lvlRank = SparseTensorType(rtp).getLvlRank();
       for (Level lvl = 0; lvl < lvlRank; lvl++) {
         if (!dependentLvlMap[t][lvl].empty()) {
-          ArrayRef<std::pair<TensorId, Level>> depLvls = dependentLvlMap[t][lvl];
+          ArrayRef<std::pair<TensorLevel, unsigned>> depLvls =
+              dependentLvlMap[t][lvl];
           // Needs at least two operands to form a non-trivial affine expression.
-          assert(depLvls.size() > 1);
+          assert(depLvls.size() == sliceMeta[t][lvl].size());
 
           Value size = c0;
-          for (unsigned e = depLvls.size() - 1; e >= 1; e--) {
-            auto [dt, dd] = depLvls[e];
-            size = ADDI(size, lvlSizes[dt][dd]);
-            sliceSizes[t][lvl][e - 1] = size;
+          for (int e = depLvls.size() - 1; e >= 0; e--) {
+            auto [dt, dl] = unpackTensorLevel(depLvls[e].first);
+            unsigned stride = depLvls[e].second;
+            Value stridedSize = lvlSizes[dt][dl];
+            if (stride != 1)
+              stridedSize = MULI(stridedSize, C_IDX(stride));
+            size = ADDI(size, stridedSize);
+            sliceMeta[t][lvl][e] = std::make_pair(size, stride);
           }
         }
       }
@@ -729,8 +732,9 @@
       // crdHi is a loop invariant, hosit the computation outside the loop.
       if (llvm::isa_and_nonnull<scf::ForOp>(loop))
         builder.setInsertionPoint(loop);
-      crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset,
-                   sliceSizes[tid][lvl].back());
+      auto [size, stride] = sliceMeta[tid][lvl].back();
+      assert(stride == 1 && "Not yet implemented");
+      crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, size);
     }
     assert(crdHi);
     return genSparseReducedAffineCond(builder, loc,
@@ -984,7 +988,7 @@
   if (sparseConds.size() > 1)
     return false;
 
-  // We also need a while loop for levels with affine index expression for
+  // We also need a while loop for levels with affine index expression and
   // non-unique levels when deduplication is required.
   if (sparseConds.size() == 1) {
     auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first);
@@ -1042,7 +1046,9 @@
   if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
     bool unReduc = isAffineIdxUnRedCond(loopCondKind);
     assert(unReduc == !depFullyReduced(tid, lvl));
-    hi = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
+    auto [size, stride] = sliceMeta[tid][lvl][sliceStack[tid].back().depth];
+    assert(stride == 1 && "Not yet implemented");
+    hi = size;
     if (unReduc) {
       // Adjust for loop hi for dense slice-driven loop.
       hi = SUBI(lvlSizes[tid][lvl], hi);
@@ -1215,6 +1221,8 @@
     SliceInfo &info = sliceStack[tid].back();
     // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
     sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
+    // FIXME: The offset and position iterator need to be adjusted when the
+    // slice is strided.
     if (unReduc) {
       assert(*info.slicedOnLvl == lvl);
       // Update the slice information as we enter the new loop.
@@ -1361,7 +1369,9 @@
   while (curLvl < leafLvl && isDenseDLT(lvlTypes[tid][curLvl])) {
     // One step forward in parent level results in forwarding `slice.size` step
     // in child dense level.
-    fcnt = MULI(sliceSizes[tid][curLvl].back(), fcnt);
+    auto [size, stride] = sliceMeta[tid][curLvl].back();
+    assert(stride == 1 && "Not yet implemented");
+    fcnt = MULI(size, fcnt);
     curLvl++;
   }
@@ -1420,7 +1430,18 @@
       // TODO: support coiterating multiple slices
       assert(loopInfo.trivialTidLvls.empty() &&
              loopInfo.sliceDrivenInfo.size() == 1);
-      genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o);
+      auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
+          genSliceNextInduction(builder, loc, tid, lvl);
+      // Update the while loop induction operands.
+      operands.push_back(nxNonEmpty);
+      operands.push_back(nxMinCrd);
+      operands.push_back(nxAbsOffset);
+
+      // Update the slice stack.
+      SliceInfo &info = sliceStack[tid].back();
+      info.isNonEmpty = whileOp.getResult(o++);
+      info.minCrd = whileOp.getResult(o++);
+      info.offset = whileOp.getResult(o++);
       continue;
     }
@@ -1566,7 +1587,10 @@
     Value size, TensorId tid, Level lvl, ValueRange userReduc,
     LoopBodyBuilder bodyBuilder) {
   Value c1 = C_IDX(1);
-  Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back());
+  auto [sliceSz, stride] = sliceMeta[tid][lvl].back();
+  assert(stride == 1 && "Not yet implemented");
+  Value sliceHi = ADDI(offset, sliceSz);
+
   SmallVector<Value> reduc{posLo}; // loop lower bounds
   const unsigned numMetaReduc = reduc.size();
@@ -1663,6 +1687,8 @@
     reduc.back() = ADDI(reduc.back(), C_IDX(1));
   };
 
+  // FIXME: Need special handling when the previous unresolved slice is strided:
+  // we probably need to filter out coordinates that are not on the stride.
   if (firstResLvl.has_value()) {
     // Overwrite position when the first level is fully resolved.
     pos = posits[firstResLvl->first][firstResLvl->second];
@@ -1694,10 +1720,13 @@
             // non-consecutive segments.
             builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
                                             ADDI(iv, c2).getResult());
+
+            auto [size, stride] = sliceMeta[tid][firstLvl].back();
+            assert(stride == 1 && "Not yet implemented");
             ValueRange itArgs =
                 genSliceLvlTraverseLoop(
-                    builder, loc, loopLo, loopHi, offset,
-                    sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
+                    builder, loc, loopLo, loopHi, offset, size, tid, firstLvl,
+                    iterArgs,
                     [&](OpBuilder &builder, Location, Value iv,
                         MutableArrayRef<Value> reduc) {
                       ip = builder.saveInsertionPoint();
@@ -1710,8 +1739,9 @@
   } else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
     assert(firstLvl == 0); // This must be the first level.
     Value lb = frontSlice.offset;
-    Value sliceSz =
-        sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
+    auto [sliceSz, stride] =
+        sliceMeta[tid][*frontSlice.slicedOnLvl][frontSlice.depth];
+    assert(stride == 1 && "Not yet implemented");
     Value ub = ADDI(lb, sliceSz);
     outerMost = builder.create<scf::ForOp>(
         loc, lb, ub, c1, innerArgs,
@@ -1735,7 +1765,8 @@
       Level sliceLvl = *slice->slicedOnLvl;
      assert(isDenseDLT(lvlTypes[tid][sliceLvl]));
       Value offset = slice->offset;
-      Value sliceSz = sliceSizes[tid][sliceLvl][slice->depth - 1];
+      auto [sliceSz, stride] = sliceMeta[tid][sliceLvl][slice->depth];
+      assert(stride == 1 && "Not yet implemented");
       lbs.push_back(offset);
       ubs.push_back(ADDI(offset, sliceSz));
       steps.push_back(c1);
@@ -1788,7 +1819,8 @@
                                  lvl, /*depth=*/1);
     return;
   }
-  Value size = sliceSizes[tid][lvl][0];
+  auto [nxSz, stride] = sliceMeta[tid][lvl][1];
+  assert(stride == 1 && "Not yet implemented");
   Value sPtrBuf = slicePosBuffer[tid][lvl][0];
   Value pHi, pLo;
   if (lvl == 0) {
@@ -1816,7 +1848,7 @@
   Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
 
   // FIXME: We need the relative offset related to the base slice.
-  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, /*depth=*/1);
 }
@@ -1845,7 +1877,12 @@
                                   TensorId tid, Level lvl) {
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
   unsigned depth = levelReducedDep[tid][lvl];
-  Value size = sliceSizes[tid][lvl][depth]; // Dense slice begin is trivial
+  // TODO: handle the case when the current slice stride is not one.
+  assert(sliceMeta[tid][lvl][depth].second == 1 && "Not yet implemented");
+
+  // The remaining slice size after reduction.
+  Value remSz = sliceMeta[tid][lvl][depth + 1].first;
+  // Dense slice begin is trivial.
   if (isDenseDLT(lvlTypes[tid][lvl])) {
     sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
                                  depth + 1);
@@ -1941,7 +1978,7 @@
   builder.create<memref::StoreOp>(loc, result[2], sPtrBuf, c0);
   builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);
 
   // FIXME: we need the relative offset related to the base slice.
-  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
 }
@@ -2005,9 +2042,12 @@
     // TODO: Maybe using allocaScopeOp inside the loop to resolve the issue?
     for (Level curLevel = lvl;
          curLevel >= 1 && !lvlFullyResolved(tid, curLevel - 1); curLevel--) {
-      auto depth = remDepOnLevel(tid, curLevel - 1);
-      assert(sliceSizes[tid][lvl].size() >= depth);
-      Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
+      // We only handle cases where all the previously unresolved levels are
+      // fully reduced.
+      assert(depFullyReduced(tid, curLevel - 1));
+      assert(!sliceMeta[tid][curLevel - 1].empty());
+      auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
+      assert(stride == 1 && "Not yet implemented");
       bufSize = MULI(bufSize, sz);
     }
     // For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
@@ -2042,18 +2082,15 @@
   }
 }
 
-void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
-                                        const Operation *op, TensorId tid,
-                                        Level lvl,
-                                        SmallVectorImpl<Value> &operands,
-                                        unsigned &retIdx) {
+std::tuple<Value, Value, Value>
+LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
+                                   TensorId tid, Level lvl) {
   if (!isCompressedDLT(lvlTypes[tid][lvl]))
     llvm_unreachable("TODO");
 
   // else generate code to compute next non empty slice.
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
 
-  auto whileOp = llvm::cast<scf::WhileOp>(op);
   SliceInfo &info = sliceStack[tid].back();
   assert(info.slicedOnLvl == lvl);
   //
@@ -2182,9 +2219,12 @@
     builder.setInsertionPointAfter(forOp.loops.front());
     // minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0
     Value tmp = ADDI(forOp.results.front(), c1);
-    Value minOffset = SUBI(tmp, sliceSizes[tid][lvl][info.depth - 1]);
-    Value p = CMPI(uge, tmp, sliceSizes[tid][lvl][info.depth - 1]);
+    auto [size, stride] = sliceMeta[tid][lvl][info.depth];
+    assert(stride == 1 && "Not yet implemented");
+    Value minOffset = SUBI(tmp, size);
+    Value p = CMPI(uge, tmp, size);
     minOffset = SELECT(p, minOffset, c0);
+
     SmallVector<Value> yields;
     yields.assign(forOp.results.begin(), forOp.results.end());
     yields.push_back(minOffset);
@@ -2200,7 +2240,9 @@
 
   Value maxPred = CMPI(ugt, minOffset, nxOffset);
   Value nextAbsOffset = SELECT(maxPred, minOffset, nxOffset);
-  Value sliceUB = ADDI(nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]);
+  auto [size, stride] = sliceMeta[tid][lvl][info.depth];
+  assert(stride == 1 && "Not yet implemented");
+  Value sliceUB = ADDI(nextAbsOffset, size);
 
   // FIXME: this only works if there is only one parent.
   assert(info.depth - 1 == 0);
@@ -2211,15 +2253,7 @@
   assert(info.depth - 1 == 0);
   Value nextRelOffset = nextAbsOffset;
   nextRelOffset = SELECT(nextNonEmpty, nextRelOffset, c0);
-
-  operands.push_back(nextNonEmpty);
-  operands.push_back(nextMinCrd);
-  operands.push_back(nextAbsOffset); // we push the absolute offset.
-
-  // Update the slice stack.
-  info.isNonEmpty = whileOp.getResult(retIdx++);
-  info.minCrd = whileOp.getResult(retIdx++);
-  info.offset = whileOp.getResult(retIdx++);
+  return std::make_tuple(nextNonEmpty, nextMinCrd, nextAbsOffset);
 }
 
 #undef CMPI
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -282,10 +282,14 @@
 ///
 /// TODO: constant should be easy to handle.
 static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
-                          AffineExpr a, DimLevelType dlt,
-                          bool isSubExp = false) {
+                          AffineExpr a, DimLevelType dlt, bool isSubExp = false,
+                          int64_t coefficient = 1) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
+    // Only allow positive coefficients on AffineDimExpr.
+    if (coefficient <= 0)
+      return false;
+
     const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
     if (!isUndefDLT(merger.getLvlType(tensor, ldx)))
       return false; // used more than once, e.g., A[i][i]
@@ -293,8 +297,10 @@
     // TODO: Generalizes the following two cases. A[i] (with trivial index
     // expression) can be treated as a special affine index expression. We do
     // not necessarily need to differentiate them.
-    if (!isSubExp)
+    if (!isSubExp) {
+      assert(coefficient == 1);
       merger.setLevelAndType(tensor, ldx, lvl, dlt);
+    }
 
     if (isSubExp) {
       // The current loops appears in more than one affine expressions on the
@@ -312,14 +318,26 @@
         // else increase min(d0_1, d0_2).
         return false;
       }
-      merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt);
+      merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt, coefficient);
     }
     return true;
   }
   case AffineExprKind::Constant:
-  case AffineExprKind::Mul:
-    // TODO: Support Mul and Constant AffineExp for slice-based codegen
-    return false;
+    // TODO: Support Constant AffineExp for slice-based codegen
+  case AffineExprKind::Mul: {
+    // TODO: Support a standalone index expression like `2 * d0`; for now we
+    // only support it inside compound expressions like `2 * d0 + d1`.
+    if (!isSubExp)
+      return false;
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
+    if (rhs.isa<AffineConstantExpr>())
+      std::swap(lhs, rhs);
+    // Must be in form of `constant * d`.
+    assert(lhs.isa<AffineConstantExpr>() && rhs.isa<AffineDimExpr>());
+    int64_t coefficient = lhs.cast<AffineConstantExpr>().getValue();
+    return findDepIdxSet(merger, tensor, lvl, rhs, dlt, isSubExp, coefficient);
+  }
   case AffineExprKind::Add: {
     auto binOp = a.cast<AffineBinaryOpExpr>();
     return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), dlt, true) &&
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -232,11 +232,11 @@
                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
       lvlToLoop(numTensors,
                 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
-      loopToDependencies(
-          numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
-                        numTensors, std::nullopt)),
-      levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
-                                           maxLvlRank, std::vector<LoopId>())),
+      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlDLTPair>>(
+                                         numTensors, std::nullopt)),
+      levelToDependentLoop(numTensors,
+                           std::vector<std::vector<LoopCoeffPair>>(
+                               maxLvlRank, std::vector<LoopCoeffPair>())),
       loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
 
 //===----------------------------------------------------------------------===//
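
For illustration, the suffix sum that populates sliceMeta in initializeLoopEmit
can be sketched with plain integers standing in for MLIR Values. This is a
minimal sketch, not part of the patch: computeSliceMeta and its types are
hypothetical, and only the arithmetic mirrors the hunk above.

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Each dependency contributes lvlSize * stride to the running size.
    // Walking the dependent list back to front, entry e records the
    // coordinate span of the conceptual slice that remains once the first e
    // dependencies have been reduced, together with that entry's stride.
    std::vector<std::pair<int64_t, unsigned>>
    computeSliceMeta(const std::vector<std::pair<int64_t, unsigned>> &deps) {
      std::vector<std::pair<int64_t, unsigned>> meta(deps.size());
      int64_t size = 0;
      for (int e = static_cast<int>(deps.size()) - 1; e >= 0; e--) {
        auto [lvlSize, stride] = deps[e];
        size += lvlSize * static_cast<int64_t>(stride);
        meta[e] = {size, stride};
      }
      return meta;
    }

For A[2*i + j] with |i| = 10 and |j| = 5, computeSliceMeta({{10, 2}, {5, 1}})
yields {(25, 2), (5, 1)}: a span of 2*10 + 5 = 25 coordinates is needed before
i is reduced, and 5 once only j remains.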
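Similarly, the DimId/Mul/Add dispatch in findDepIdxSet boils down to
decomposing an expression such as `2 * d0 + d1` into (loop, coefficient)
pairs. Below is a self-contained sketch with a toy expression type; Expr and
collectCoeffs are hypothetical stand-ins for AffineExpr and the real routine,
which additionally records the result in the Merger.

    #include <cstdint>
    #include <utility>
    #include <vector>

    // A toy affine expression: a dimension, constant * dimension, or a sum.
    struct Expr {
      enum class Kind { Dim, Mul, Add } kind;
      unsigned pos = 0;          // Dim: the loop id (d0, d1, ...).
      int64_t constant = 1;      // Mul: the constant factor.
      const Expr *lhs = nullptr; // Mul: the dimension; Add: left operand.
      const Expr *rhs = nullptr; // Add: right operand.
    };

    // Returns false on unsupported shapes: a non-positive coefficient, or a
    // bare `2 * d0` that is not part of a larger sum.
    bool collectCoeffs(const Expr &a,
                       std::vector<std::pair<unsigned, int64_t>> &out,
                       bool isSubExp = false, int64_t coefficient = 1) {
      switch (a.kind) {
      case Expr::Kind::Dim:
        if (coefficient <= 0)
          return false;
        out.emplace_back(a.pos, coefficient);
        return true;
      case Expr::Kind::Mul:
        // Must be `constant * d`, and only inside a compound expression.
        if (!isSubExp)
          return false;
        return collectCoeffs(*a.lhs, out, isSubExp, a.constant);
      case Expr::Kind::Add:
        return collectCoeffs(*a.lhs, out, true) &&
               collectCoeffs(*a.rhs, out, true);
      }
      return false;
    }

On `2 * d0 + d1` this produces {(0, 2), (1, 1)}, matching the
{(i, 2), (j, 1)} encoding stored in levelToDependentLoop.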