diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -194,14 +194,17 @@
   /// Gets the total number of tensors that loopEmitter is operating on.
   unsigned getNumTensors() const { return tensors.size(); }
 
+  /// Gets the TensorId for the synthetic tensor.
+  TensorId getSynTensorId() const { return tensors.size(); }
+
   /// Compresses a TensorId and Level into a TensorLevel.
   TensorLevel makeTensorLevel(TensorId t, Level l) const {
-    return l * getNumTensors() + t;
+    return l * (getNumTensors() + 1) + t;
   }
 
   /// De-compresses a TensorLevel back to a pair of TensorId and Level.
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
-    unsigned nt = getNumTensors();
+    unsigned nt = getNumTensors() + 1;
     return std::make_pair(tidLvl % nt, tidLvl / nt);
   }
 
@@ -319,6 +322,8 @@
                                   Location loc, Value crd,
                                   TensorId tid, Level lvl);
 
+  bool isSynTensor(TensorId tid) const { return tid == getNumTensors(); }
+
   bool isOutputTensor(TensorId tid) const {
     return hasOutput && tid == getNumTensors() - 1;
   }
@@ -408,9 +413,11 @@
   /// TODO: why not do this computation when we first store the reassoc,
   /// instead of doing it every time we look it up?
   SmallVector<Level, 2> getCollapseReassociation(TensorId tid, Level dstLvl) {
-    assert(tid < getNumTensors() && "Invalid TensorId");
-    assert(collapseReassoc.size() == getNumTensors());
+    assert(tid < getNumTensors() + 1 && "Invalid TensorId");
+    assert(collapseReassoc.size() == getNumTensors() + 1);
     if (const auto reassoc = collapseReassoc[tid]) {
+      assert(!isSynTensor(tid) && !isOutputTensor(tid) &&
+             "Output/Synthetic tensor should not have reassociation");
       // TODO: store the dstLvlRank in the LoopEmitter so that we can
       // check `dstLvl < dstLvlRank` at the top; and only here need to
       // assert that `reassoc.size() == dstLvlRank`.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -232,7 +232,10 @@
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
 
-  const unsigned numTensors = ts.size();
+  const unsigned numRealTensor = ts.size();
+  const unsigned synTensorId = numRealTensor;
+  const unsigned numTensors = numRealTensor + 1;
+
   this->tensors.assign(ts.begin(), ts.end());
   this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
   this->lvlSizes.assign(numTensors, std::vector<Value>());
@@ -265,33 +268,43 @@
 
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
-    const Value t = tensors[tid];
-    // a scalar or 0-dimension tensors
-    if (isZeroRankedTensorOrScalar(t.getType()))
-      continue;
-
-    auto rtp = getRankedTensorType(t);
-    if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
-        isUniqueCOOType(rtp) && reshape) {
-      // TODO: Supports more kinds of sparse tensors.
-      // FIXME: We should instead lower reshape operations on sparse tensors to
-      // view change.
-      collapseReassoc[tid] = reshape.getReassociation();
-      rtp = reshape.getSrcType();
-      // Overwrites the tensor to the source tensor of reshape operations.
-      tensors[tid] = reshape.getSrc();
-    }
-    const SparseTensorType stt(rtp);
-    const Level lvlRank = stt.getLvlRank();
-    // We always treat sparse output tensor as dense so that we always iterate
-    // it based on lvl size.
-    if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
-      const auto enc = stt.getEncoding();
-      isSparseSlices[tid] = enc.isSlice();
-      for (auto lvlTp : enc.getLvlTypes())
-        lvlTypes[tid].push_back(lvlTp);
-    } else {
+    Level lvlRank;
+    if (tid == synTensorId) {
+      // Synthetic tensor (conceptually) is an all-dense tensor with rank equal
+      // to the total number of loops (each level can potentially be mapped to
+      // one of the loops being generated).
+      lvlRank = numLoops;
       lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+    } else {
+      const Value t = tensors[tid];
+      // a scalar or 0-dimension tensors
+      if (isZeroRankedTensorOrScalar(t.getType()))
+        continue;
+
+      auto rtp = getRankedTensorType(t);
+      if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
+          isUniqueCOOType(rtp) && reshape) {
+        // TODO: Supports more kinds of sparse tensors.
+        // FIXME: We should instead lower reshape operations on sparse tensors
+        // to view change.
+        collapseReassoc[tid] = reshape.getReassociation();
+        rtp = reshape.getSrcType();
+        // Overwrites the tensor to the source tensor of reshape operations.
+        tensors[tid] = reshape.getSrc();
+      }
+      const SparseTensorType stt(rtp);
+      lvlRank = stt.getLvlRank();
+
+      // We always treat sparse output tensor as dense so that we always iterate
+      // it based on lvl size.
+      if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+        const auto enc = stt.getEncoding();
+        isSparseSlices[tid] = enc.isSlice();
+        for (auto lvlTp : enc.getLvlTypes())
+          lvlTypes[tid].push_back(lvlTp);
+      } else {
+        lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+      }
     }
 
     // Initialize using empty value.
@@ -314,7 +327,7 @@
     sliceStack[tid].emplace_back(/*minCrd=*/Value(), /*offset=*/Value(),
                                  /*isNonEmpty*/ Value(), std::nullopt, 0);
 
-    if (dimGetter) {
+    if (dimGetter && !isSynTensor(tid)) {
       auto reassoc = collapseReassoc[tid];
       Level dstRank = reassoc ? reassoc.size() : lvlRank;
       for (Level l = 0; l < dstRank; l++) {
@@ -461,15 +474,28 @@
   assert(loopSeqStack.size() == loopStack.size());
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
+
+  bool hasSynTensor = false;
+  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
      bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
      slicedTids.emplace_back(tid, lvl, fullyRed);
     } else {
-      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      if (isSynTensor(tid)) {
+        hasSynTensor = true;
+      } else {
+        loopBoundDefLevel = std::make_pair(tid, lvl);
+        prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      }
     }
   }
 
+  if (hasSynTensor && loopBoundDefLevel.has_value()) {
+    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
+    highs[getSynTensorId()][getCurrentDepth()] =
+        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
+  }
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
@@ -1137,6 +1163,9 @@
   // output tensor unconditionally, since they may not appear in the lattice,
   // but may be needed for linearized codegen.
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+    if (isSynTensor(tid))
+      continue;
+
     if (isDenseDLT(lvlTypes[tid][lvl])) {
       // Slice-driven dense level should have be handled already.
       if (!dependentLvlMap[tid][lvl].empty())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1488,8 +1488,15 @@
                                            std::optional<Level> lvl,
                                            DimLevelType dlt, bool isIdxReduc) {
     assert(env.merger().loop(b) == idx);
-    if (isDenseDLT(dlt) || isUndefDLT(dlt))
+    if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+      if (tid == env.merger().getSynTensorID()) {
+        // Need the loop emitter to set up loop bounds for the synthetic tensor
+        // too if there is a loop condition imposed on the synthetic tensor.
+        tidLvls.push_back(
+            env.makeTensorLevel(tid, env.emitter().getCurrentDepth()));
+      }
       needsUniv = true;
+    }
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
         isCompressedWithHiDLT(dlt) || isIdxReduc) {
       // Only when this is a index reduction loop, can the dlt be undefined.
@@ -1573,13 +1580,24 @@
         // iterate based on the level of output tensor. E.g., this
         // could be a synthetic tensor (for invariants and sparse
         // output tensor).
-        // out[i][j] = invariant; or a broadcast
-        // out[i][j] = in[i] (j is undef for input)
-        tid = outTid;
-        lvl = outLvl;
-        // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-        if (!lvl)
-          return;
+        if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+          // Coiterating with an invariant, and this is a reduction loop,
+          // e.g., out = prod(in[i][j] op invariant);
+          // In this case, we cannot infer the loop bound from the output
+          // (whose level is reduced). Instead we use the synthetic tensor
+          // to infer the bound.
+          // The level of the synthetic tensor is the current loop depth;
+          // the rank of the synthetic tensor equals the number of loops.
+          lvl = env.emitter().getCurrentDepth();
+        } else {
+          // or a broadcast
+          // out[i][j] = in[i] (j is undef for input)
+          tid = outTid;
+          lvl = outLvl;
+          // Skips invalid lvl (e.g., when this is a zero ranked tensor).
+          if (!lvl)
+            return;
+        }
       }
       hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
       tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
@@ -1669,7 +1687,8 @@
   auto allTidLvls =
       llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
   for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
-    if (tid != env.merger().getOutTensorID())
+    if (tid != env.merger().getOutTensorID() &&
+        tid != env.merger().getSynTensorID())
       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
   }
 
@@ -1796,7 +1815,7 @@
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
-    Value val = env.emitter().getValBuffer().back(); // value array
+    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
   }
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
@@ -140,7 +140,9 @@
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0
     ]> : tensor<32xf32>
 
-    // Convert constants to annotated tensors.
+    // Convert constants to annotated tensors. Note that this
+    // particular conversion only stores nonzero elements,
+    // so we will have no explicit zeros, only implicit zeros.
     %d0_i32 = sparse_tensor.convert %c_0_i32
       : tensor<32xi32> to tensor<32xi32, #DV>
     %d0_f32 = sparse_tensor.convert %c_0_f32
@@ -158,6 +160,10 @@
     %s1_f32 = sparse_tensor.convert %c_1_f32
       : tensor<32xf32> to tensor<32xf32, #SV>
 
+    // Special case: construct a sparse vector with an explicit zero.
+    %v0 = arith.constant sparse< [ [1] ], [ 0 ] > : tensor<32xi32>
+    %s0 = sparse_tensor.convert %v0: tensor<32xi32> to tensor<32xi32, #SV>
+
     // Call the kernels.
     %0 = call @prod_dreduction_i32(%d0_i32, %ri) : (tensor<32xi32, #DV>, tensor<i32>) -> tensor<i32>
     %1 = call @prod_dreduction_f32(%d0_f32, %rf) : (tensor<32xf32, #DV>, tensor<f32>) -> tensor<f32>
@@ -167,19 +173,23 @@
     %5 = call @prod_dreduction_f32(%d1_f32, %rf) : (tensor<32xf32, #DV>, tensor<f32>) -> tensor<f32>
     %6 = call @prod_sreduction_i32(%s1_i32, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
     %7 = call @prod_sreduction_f32(%s1_f32, %rf) : (tensor<32xf32, #SV>, tensor<f32>) -> tensor<f32>
+    %8 = call @prod_sreduction_i32(%s0, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
 
     // Verify results. Note that the custom reduction gave permission
     // to treat an explicit vs implicit zero differently to compute the
-    // full product reduction. A "standard" product reduction would
-    // have to return 0 for any implicit zero occurrence too.
+    // full product reduction over stored elements. A "standard" product
+    // reduction would have to return 0 for any implicit zero occurrence
+    // too. An explicit zero nullifies the product, though, as requested.
     //
     // CHECK: 0
+    // CHECK: 0
     // CHECK: 3087
     // CHECK: 14
     // CHECK: 3087
     // CHECK: 168
     // CHECK: 3087
     // CHECK: 168
+    // CHECK: 0
     //
     call @dump_i32(%0) : (tensor<i32>) -> ()
     call @dump_f32(%1) : (tensor<f32>) -> ()
@@ -189,6 +199,7 @@
     call @dump_f32(%5) : (tensor<f32>) -> ()
     call @dump_i32(%6) : (tensor<i32>) -> ()
     call @dump_f32(%7) : (tensor<f32>) -> ()
+    call @dump_i32(%8) : (tensor<i32>) -> ()
 
     // Release the resources.
     bufferization.dealloc_tensor %d0_i32 : tensor<32xi32, #DV>
@@ -199,6 +210,7 @@
     bufferization.dealloc_tensor %d1_f32 : tensor<32xf32, #DV>
     bufferization.dealloc_tensor %s1_i32 : tensor<32xi32, #SV>
     bufferization.dealloc_tensor %s1_f32 : tensor<32xf32, #SV>
+    bufferization.dealloc_tensor %s0 : tensor<32xi32, #SV>
 
     return
   }
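
The TensorLevel arithmetic changed in LoopEmitter.h above is the part of this patch that is easiest to misread, so here is a minimal, self-contained C++ sketch of the packing scheme. It is illustrative only and not part of the patch: the TidLvlCodec struct and its member names are made up, standing in for the corresponding LoopEmitter members. It shows why the stride grows from getNumTensors() to getNumTensors() + 1 once the synthetic tensor occupies the one-past-the-end TensorId, and that packing and unpacking remain inverses for that id as well.

// Standalone sketch (illustrative only); TidLvlCodec is a hypothetical
// stand-in for the LoopEmitter members touched by this patch.
#include <cassert>
#include <cstdio>
#include <utility>

using TensorId = unsigned;
using Level = unsigned;
using TensorLevel = unsigned;

struct TidLvlCodec {
  unsigned numTensors; // number of real tensors managed by the emitter

  // The synthetic tensor takes the one-past-the-end TensorId.
  TensorId synTensorId() const { return numTensors; }

  // Stride is numTensors + 1 so that tid == synTensorId() round-trips too.
  TensorLevel makeTensorLevel(TensorId t, Level l) const {
    return l * (numTensors + 1) + t;
  }
  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
    const unsigned nt = numTensors + 1;
    return std::make_pair(tidLvl % nt, tidLvl / nt);
  }
};

int main() {
  TidLvlCodec codec{/*numTensors=*/3};
  // Pack level 2 of the synthetic tensor (tid == 3 here) and unpack it again.
  TensorLevel tl = codec.makeTensorLevel(codec.synTensorId(), /*l=*/2);
  auto [tid, lvl] = codec.unpackTensorLevel(tl);
  assert(tid == codec.synTensorId() && lvl == 2);
  std::printf("packed=%u tid=%u lvl=%u\n", tl, tid, lvl);
  return 0;
}

With the old stride of getNumTensors() (3 in this sketch), packing the synthetic id 3 at level 2 would yield 2 * 3 + 3 = 9, which unpacks to (tid 0, lvl 3); hence the + 1 in both makeTensorLevel and unpackTensorLevel.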