diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -89,9 +89,11 @@
   TensorLevel makeTensorLevel(TensorId t, Level l) const {
     // Make sure LoopEmitter, GenericOp, and Merger agree on the number of
-    // tensors. Merger has one more synthetic tensor for loop invariants.
-    assert(loopEmitter.getNumTensors() == linalgOp->getNumOperands() &&
-           loopEmitter.getNumTensors() == latticeMerger.getNumTensors() - 1);
+    // tensors.
+    assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
+           loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
+           loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
+           loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
     return loopEmitter.makeTensorLevel(t, l);
   }
 
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -192,20 +192,31 @@
   }
 
-  /// Gets the total number of tensors that loopEmitter is operating on.
-  unsigned getNumTensors() const { return tensors.size(); }
+  /// Gets the total number of manifest tensors (excluding the synthetic
+  /// tensor).
+  unsigned getNumManifestTensors() const { return tensors.size(); }
+
+  /// Gets the total number of tensors that loopEmitter is operating on.
+  unsigned getNumTensors() const {
+    // Manifest tensors with one synthetic tensor at the end.
+    return getNumManifestTensors() + 1;
+  }
 
   /// Gets the TensorId for synthetic tensor.
   TensorId getSynTensorId() const { return tensors.size(); }
 
+  /// Gets the TensorId for output tensor.
+  TensorId getOutTensorId() const {
+    assert(hasOutput);
+    return getNumManifestTensors() - 1;
+  }
+
   /// Compresses a TensorId and Level into a TensorLevel.
   TensorLevel makeTensorLevel(TensorId t, Level l) const {
-    // TODO: getNumTensor() should include synthetic tensor.
-    return l * (getNumTensors() + 1) + t;
+    return l * getNumTensors() + t;
   }
 
   /// De-compresses a TensorLevel back to a pair of TensorId and Level.
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
-    unsigned nt = getNumTensors() + 1;
+    unsigned nt = getNumTensors();
     return std::make_pair(tidLvl % nt, tidLvl / nt);
   }
@@ -323,10 +334,10 @@
                            Location loc, Value crd, TensorId tid, Level lvl);
 
-  bool isSynTensor(TensorId tid) const { return tid == getNumTensors(); }
+  bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); }
 
   bool isOutputTensor(TensorId tid) const {
-    return hasOutput && tid == getNumTensors() - 1;
+    return hasOutput && tid == getOutTensorId();
   }
 
   bool isSparseOutput(TensorId tid) const {
@@ -414,8 +425,8 @@
   /// TODO: why not do this computation when we first store the reassoc,
   /// instead of doing it every time we look it up?
   SmallVector<Level, 2> getCollapseReassociation(TensorId tid, Level dstLvl) {
-    assert(tid < getNumTensors() + 1 && "Invalid TensorId");
-    assert(collapseReassoc.size() == getNumTensors() + 1);
+    assert(tid < getNumTensors() && "Invalid TensorId");
+    assert(collapseReassoc.size() == getNumTensors());
     if (const auto reassoc = collapseReassoc[tid]) {
       assert(!isSynTensor(tid) && !isOutputTensor(tid) &&
              "Output/Synthetic tensor should not have reassociation");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -235,8 +235,9 @@
   const unsigned numManifestTensors = ts.size();
   const unsigned synTensorId = numManifestTensors;
   const unsigned numTensors = numManifestTensors + 1;
-
+  // The tensors array (len == numManifestTensors).
   this->tensors.assign(ts.begin(), ts.end());
+  // Arrays with len == numTensors.
   this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
   this->lvlSizes.assign(numTensors, std::vector<Value>());
   this->highs.assign(numTensors, std::vector<Value>());
@@ -355,13 +356,14 @@
 void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
                                      LoopEmitter::OutputUpdater updater) {
-  // For every tensor:
+  // For every manifest tensor:
   //  * get the values buffer.
   //  * For every level:
   //    * get the positions and coordinates buffers
   //    * get/compute the level-size, which is also used as the upper-bound
   //      on positions.
-  for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) {
+  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
+       t++) {
     const Value tensor = tensors[t];
     const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
     if (!rtp)
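
Note on the new id layout: with this change, getNumTensors() counts the trailing synthetic tensor, so the div/mod packing in makeTensorLevel/unpackTensorLevel drops its previous "+ 1" correction, and the output/synthetic ids line up with the Merger's convention (output tensor last among the manifest tensors, synthetic tensor appended after them). Below is a minimal standalone sketch of the scheme; TensorIdLayout and the plain unsigned aliases are illustrative stand-ins for this note, not the dialect's actual types.

#include <cassert>
#include <utility>

// Illustrative stand-ins for the dialect's TensorId/Level/TensorLevel.
using TensorId = unsigned;
using Level = unsigned;
using TensorLevel = unsigned;

struct TensorIdLayout {
  unsigned numManifestTensors; // operands of the linalg op
  bool hasOutput;

  // Total count now includes the trailing synthetic tensor.
  unsigned getNumTensors() const { return numManifestTensors + 1; }
  // The synthetic tensor is appended after all manifest tensors.
  TensorId getSynTensorId() const { return numManifestTensors; }
  // The output tensor is the last manifest tensor.
  TensorId getOutTensorId() const {
    assert(hasOutput);
    return numManifestTensors - 1;
  }

  // Every valid t satisfies t < getNumTensors(), so (t, l) packs into
  // one integer and is recovered exactly by mod/div.
  TensorLevel makeTensorLevel(TensorId t, Level l) const {
    return l * getNumTensors() + t;
  }
  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
    unsigned nt = getNumTensors();
    return {tl % nt, tl / nt};
  }
};

int main() {
  TensorIdLayout layout{/*numManifestTensors=*/3, /*hasOutput=*/true};
  // Manifest ids 0..2 (output == 2), synthetic id == 3, nt == 4.
  TensorLevel tl = layout.makeTensorLevel(layout.getSynTensorId(), /*l=*/5);
  auto [t, l] = layout.unpackTensorLevel(tl); // round-trips to (3, 5)
  assert(t == layout.getSynTensorId() && l == 5);
  return 0;
}

Note that the encoded values themselves are unchanged: the old modulus getNumTensors() + 1 and the new getNumTensors() both equal numManifestTensors + 1. The patch only moves the "+ 1" into getNumTensors() itself, so LoopEmitter and Merger now report the same tensor count.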