diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -217,6 +217,20 @@ Location loc, RewriterBase &rewriter, SparseElementsAttr attr, function_ref, Value)> callback); +/// Converts the vector indices and stores them into the memory pointed to by +/// `ind`, applying the (optional) `offset` to the index at `offsetDim`. +void storeIndices(OpBuilder &builder, Location loc, unsigned rank, Value ind, + ValueRange ivs, unsigned offsetDim = 0, + Value offset = Value()); + +/// Reshapes the linear values buffer for an annotated all-dense sparse tensor +/// to match the shape of the corresponding dense tensor to support direct +/// access of the buffer through indices. +Value reshapeValuesToLevels(OpBuilder &builder, Location loc, + SparseTensorEncodingAttr enc, + const SmallVectorImpl &dimSizes, + Value valuesBuffer, Value idxBuffer); + //===----------------------------------------------------------------------===// // Inlined constant generators. 
// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -526,6 +526,38 @@ } } +void sparse_tensor::storeIndices(OpBuilder &builder, Location loc, + unsigned rank, Value ind, ValueRange ivs, + unsigned offsetDim, Value offset) { + for (unsigned i = 0; i < rank; i++) { + Value idx = ivs[i]; + if (offsetDim == i && offset) + idx = builder.create(loc, idx, offset); + builder.create(loc, idx, ind, + constantIndex(builder, loc, i)); + } +} + +Value sparse_tensor::reshapeValuesToLevels( + OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc, + const SmallVectorImpl &dimSizes, Value valuesBuffer, + Value idxBuffer) { + // Use `idxBuffer` to store the level sizes. + unsigned rank = enc.getDimLevelType().size(); + SmallVector lvlSizes; + for (unsigned i = 0; i < dimSizes.size(); i++) + lvlSizes.push_back(dimSizes[toOrigDim(enc, i)]); + storeIndices(builder, loc, rank, idxBuffer, lvlSizes); + // The memref ReshapeOp requires the sizes buffer to have a static + // shape. 
+ idxBuffer = builder.create( + loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer); + SmallVector shape(rank, ShapedType::kDynamic); + Type elemTp = valuesBuffer.getType().cast().getElementType(); + return builder.create(loc, MemRefType::get(shape, elemTp), + valuesBuffer, idxBuffer); +} + Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc, Value tensor, uint64_t d) { RankedTensorType srcTp = tensor.getType().cast(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -428,20 +428,6 @@ return ivs; } -/// Converts the vector indices and store it into the memory pointed by -/// `ind`, apply (optional) `offset` on `offsetDim`. -static void storeIndices(OpBuilder &builder, Location loc, unsigned rank, - Value ind, ValueRange ivs, unsigned offsetDim = 0, - Value offset = Value()) { - for (unsigned i = 0; i < rank; i++) { - Value idx = ivs[i]; - if (offsetDim == i && offset) - idx = builder.create(loc, idx, offset); - builder.create(loc, idx, ind, - constantIndex(builder, loc, i)); - } -} - /// Inserts a value stored in `elemPtr` into a dense tensor created by /// allocDenseTensor(). static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc, @@ -1375,19 +1361,8 @@ dst = genValuesCall(rewriter, loc, MemRefType::get({ShapedType::kDynamic}, elemTp), {dst}); - // Use the dstIdx to store the level sizes. - SmallVector lvlSizes; - for (unsigned i = 0; i < sizes.size(); i++) - lvlSizes.push_back(sizes[toOrigDim(encDst, i)]); - storeIndices(rewriter, loc, rank, dstIdx, lvlSizes); - // The memref ReshapeOp requires the sizes buffer to have a static - // shape. 
- Value typedBuffer = rewriter.create( - loc, MemRefType::get({rank}, rewriter.getIndexType()), dstIdx); - SmallVector shape(rank, ShapedType::kDynamic); - dst = rewriter.create( - loc, MemRefType::get(shape, elemTp), dst, typedBuffer); + dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, dst, dstIdx); } else { dstPerm = params.getDim2LvlMap(); elemPtr = genAllocaScalar(rewriter, loc, elemTp);