diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -434,6 +434,8 @@
     return hasOutput && tid == tensors.size() - 1;
   }
 
+  bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; }
+
   /// Setups [lo, hi] for iterating tensor[dim], it assumes that tensor[0
   /// ...dims-1] has already been setup.
   void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
@@ -462,6 +464,7 @@
   // Whether the loop emitter needs to treat the last tensor as the output
   // tensor.
   bool hasOutput;
+  bool isSparseOut;
   /// Input and (optional) output tensors.
   std::vector<Value> tensors;
   /// The dim type array for each tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -97,10 +97,11 @@
 SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors,
                                                  bool hasOutput,
                                                  bool isSparseOut)
-    : hasOutput(hasOutput), tensors(tensors.begin(), tensors.end()),
-      dimTypes(tensors.size()), pidxs(tensors.size()), coord(tensors.size()),
-      highs(tensors.size()), ptrBuffer(tensors.size()),
-      idxBuffer(tensors.size()), valBuffer(tensors.size()), loopStack() {
+    : hasOutput(hasOutput), isSparseOut(isSparseOut),
+      tensors(tensors.begin(), tensors.end()), dimTypes(tensors.size()),
+      pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
+      ptrBuffer(tensors.size()), idxBuffer(tensors.size()),
+      valBuffer(tensors.size()), loopStack() {
   for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
     auto t = tensors[tid];
     // a scalar or 0-dimension tensors
@@ -246,7 +247,7 @@
     coord[tid][dim] = iv;
     // generate pidx for dense dim (pidx = i * sz + j)
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc)
+    if (enc && !isSparseOutput(tid))
       pidxs[tid][dim] = genAddress(builder, loc, tid, dim, iv);
   }
 
@@ -353,7 +354,7 @@
     pidxs[tid][dim] = min;
     // generate pidx for dense dim (pidx = i * sz + j)
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc)
+    if (enc && !isSparseOutput(tid))
       pidxs[tid][dim] = genAddress(builder, loc, tid, dim, min);
   }
   // NOTE: we can also prepares for next dim here in advance
@@ -419,7 +420,7 @@
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     assert(isDenseDLT(dimTypes[tid][dim]));
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc) {
+    if (enc && !isSparseOutput(tid)) {
       bool validPidx = dim == 0 || pidxs[tid][dim - 1];
       if (!validPidx) {
         // We might not find the pidx for the sparse output tensor as it is
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1130,13 +1130,13 @@
     assert(all.test(b));
     assert(merger.index(b) == idx);
     if (isUndefDLT(merger.getDimLevelType(b))) {
-      // This could be a synthetic tensor (for invariants and sparse output
-      // tensor).
-      // In both cases, we mean to generate loops over output tensor.
-      // e.g.,
-      // out[i][j] = invariant;
-      if (merger.getSynTensorID() == tid)
-        tid = merger.getOutTensorID();
+      // An undefined dlt in the lattices means we probably intend to iterate
+      // based on the dim of the output tensor. E.g., this could be a
+      // synthetic tensor (for invariants and sparse output tensor):
+      //   out[i][j] = invariant;
+      // or a broadcast:
+      //   out[i][j] = in[i]
+      tid = merger.getOutTensorID();
     }
     auto dim = codegen.loopIdxToDim[tid][idx];
     if (dim != INVALID_ID) {
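
Note for reviewers: the guard added in the three CodegenUtils.cpp hunks keeps
the loop emitter from materializing a linearized dense address
(pidx = i * sz + j) for the sparse output tensor, whose positions come from
insertions rather than from address arithmetic. The following standalone
sketch illustrates the idea only; the Tensor struct, genAddress, and
isSparseOutput helpers here are simplified stand-ins for the loop-emitter
members, not the actual MLIR API.

#include <cstddef>
#include <cstdio>
#include <vector>

struct Tensor {
  bool hasSparseEncoding; // models getSparseTensorEncoding(...) != nullptr
  bool isOutput;          // models isOutputTensor(tid)
  std::vector<size_t> dimSizes;
};

// For a dense dim, the position index is linearized from the parent
// position: pidx = parentPidx * size(dim) + iv, i.e. pidx = i * sz + j.
size_t genAddress(const Tensor &t, size_t parentPidx, size_t dim, size_t iv) {
  return parentPidx * t.dimSizes[dim] + iv;
}

// The new guard: skip address computation for the sparse output tensor.
bool isSparseOutput(const Tensor &t, bool isSparseOut) {
  return t.isOutput && isSparseOut;
}

int main() {
  Tensor in{/*hasSparseEncoding=*/true, /*isOutput=*/false, {4, 8}};
  Tensor out{/*hasSparseEncoding=*/true, /*isOutput=*/true, {4, 8}};
  bool isSparseOut = true;

  size_t i = 2, j = 5;
  for (const Tensor *t : {&in, &out}) {
    if (t->hasSparseEncoding && !isSparseOutput(*t, isSparseOut)) {
      size_t pidx0 = genAddress(*t, /*parentPidx=*/0, /*dim=*/0, i);
      size_t pidx1 = genAddress(*t, pidx0, /*dim=*/1, j);
      std::printf("pidx = %zu\n", pidx1); // 2 * 8 + 5 = 21
    } else {
      std::printf("skip address computation (sparse output)\n");
    }
  }
}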
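
The Sparsification.cpp hunk also widens the undefined-dlt redirect: instead of
remapping only the synthetic tensor id, any lattice bit with an undefined dim
level type now iterates over the corresponding dim of the output tensor, which
covers broadcasts such as out[i][j] = in[i] where j is undef for the input. A
toy model of that decision (assumed names, not the Merger API):

#include <cstdio>
#include <vector>

enum class DimLevelType { Dense, Compressed, Undef };

struct Merger {
  unsigned outTensorID;
  std::vector<DimLevelType> dltOfBit; // dlt per lattice bit
  unsigned getOutTensorID() const { return outTensorID; }
  DimLevelType getDimLevelType(unsigned b) const { return dltOfBit[b]; }
};

bool isUndefDLT(DimLevelType dlt) { return dlt == DimLevelType::Undef; }

int main() {
  // Tensor 0 is the input "in", tensor 1 is the output "out".
  Merger merger{/*outTensorID=*/1, {DimLevelType::Dense, DimLevelType::Undef}};

  // Bit 1 models index j in out[i][j] = in[i]: j is undef for the input,
  // so the loop over j must be sized by the output tensor instead.
  unsigned tid = 0, b = 1;
  if (isUndefDLT(merger.getDimLevelType(b)))
    tid = merger.getOutTensorID();
  std::printf("loop over tensor %u\n", tid); // prints 1 (the output)
}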