diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -78,6 +78,9 @@
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
+  using SynTensorBoundSetter =
+      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
+
   // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
   // index on sparse tensors.
   // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
@@ -114,7 +117,8 @@
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
   void initializeLoopEmit(OpBuilder &builder, Location loc,
-                          OutputUpdater updater = nullptr);
+                          OutputUpdater updater = nullptr,
+                          SynTensorBoundSetter synSetter = nullptr);
 
   /// Generates code to compute an affine expression whose variables are
   /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -376,8 +376,15 @@
     loopIdToOrd[topSort[n]] = n;
 }
 
-void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
-                                     LoopEmitter::OutputUpdater updater) {
+void LoopEmitter::initializeLoopEmit(
+    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
+    LoopEmitter::SynTensorBoundSetter synSetter) {
+
+  // Set the upper bound of every synthetic tensor level via the callback.
+  if (synSetter)
+    for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++)
+      highs[getSynTensorId()][i] = synSetter(builder, loc, i);
+
   // For every manifest tensor:
   // * get the values buffer.
   // * For every level:
@@ -534,27 +541,17 @@
 
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
-  bool hasSynTensor = false;
-  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
     } else {
-      if (isSynTensor(tid)) {
-        hasSynTensor = true;
-      } else {
-        loopBoundDefLevel = std::make_pair(tid, lvl);
+      if (!isSynTensor(tid)) {
         prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
       }
     }
   }
 
-  if (hasSynTensor && loopBoundDefLevel.has_value()) {
-    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
-    highs[getSynTensorId()][getCurrentDepth()] =
-        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
-  }
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -832,6 +832,21 @@
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
 
+  SmallVector<Range> loopRange =
+      llvm::cast<linalg::LinalgOp>(op.getOperation())
+          .createLoopRanges(builder, loc);
+
+  assert(loopRange.size() == env.merger().getStartingFilterLoopId());
+  SmallVector<Range> sortedRange;
+  for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
+    LoopId ldx = env.topSortAt(i);
+    // FIXME: Get rid of filter loops, since we have a better algorithm to deal
+    // with affine index expressions.
+    if (ldx < env.merger().getStartingFilterLoopId()) {
+      sortedRange.push_back(loopRange[ldx]);
+    }
+  }
+
   env.emitter().initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
@@ -865,6 +880,16 @@
                                 ValueRange{init});
         }
         return init;
+      },
+      [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
+        assert(l < env.topSortSize());
+        // FIXME: Remove filter loops, since we have a better algorithm to
+        // deal with affine index expressions.
+        if (l >= env.merger().getStartingFilterLoopId())
+          return Value();
+
+        return mlir::getValueOrCreateConstantIndexOp(b, loc,
+                                                     sortedRange[l].size);
       });
 }
@@ -1587,7 +1612,9 @@
         // iterate based on the level of output tensor.  E.g., this
         // could be a synthetic tensor (for invariants and sparse
         // output tensor).
-        if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+        auto itType = env.op().getIteratorTypesArray()[ldx];
+        if (linalg::isReductionIterator(itType) &&
+            env.merger().getSynTensorID() == tid) {
           // Coiterating with an invariant, and this is a reduction loop
           // e.g., out = prod(in[i][j] op invariant);
           // In this case, we can not infer the loop bound from output
@@ -1662,7 +1689,14 @@
     tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
   }
 
-  assert(numloopCond > 0);
+  if (numloopCond == 0) {
+    // Corner case where the loop bound is defined by an *unused* operand. In
+    // this case, we just generate a dense "fake" loop by iterating over the
+    // synthetic tensor.
+    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
+                                          env.emitter().getCurrentDepth()));
+    numloopCond++;
+  }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
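
Note (editorial, not part of the patch): a minimal sketch of how another LoopEmitter client could supply the new SynTensorBoundSetter callback. The callback signature follows the header change above and `constantIndex` is the SparseTensor CodegenUtils helper; the `loopEmitter` instance and the `extents` array are hypothetical stand-ins.

  // Hedged sketch: bound every level of the synthetic tensor with a
  // compile-time constant extent (`extents` is a made-up placeholder).
  int64_t extents[] = {32, 64};
  loopEmitter.initializeLoopEmit(
      builder, loc,
      /*updater=*/nullptr,
      /*synSetter=*/[&](OpBuilder &b, Location l, Level lvl) -> Value {
        return constantIndex(b, l, extents[lvl]);
      });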