diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -593,7 +593,5 @@ Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc, Value tensor) { - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(tensor, fields); - return desc.getValMemSize(builder, loc); -} \ No newline at end of file + return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -102,11 +102,9 @@ } /// Gets the dimension size for the given sparse tensor at the given -/// original dimension 'dim'. Returns std::nullopt if no sparse encoding is -/// attached to the given tensor type. -static std::optional -sizeFromTensorAtDim(OpBuilder &builder, Location loc, - const SparseTensorDescriptor &desc, unsigned dim) { +/// original dimension 'dim'. +static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc, + SparseTensorDescriptor desc, unsigned dim) { RankedTensorType rtp = desc.getTensorType(); // Access into static dimension can query original type directly. // Note that this is typically already done by DimOp's folding. @@ -119,17 +117,12 @@ return desc.getDimSize(builder, loc, toStoredDim(rtp, dim)); } -// Gets the dimension size at the given stored dimension 'd', either as a +// Gets the dimension size at the given stored level 'lvl', either as a // constant for a static size, or otherwise dynamically through memSizes. 
-Value sizeAtStoredDim(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, unsigned d) { - RankedTensorType rtp = desc.getTensorType(); - unsigned dim = toOrigDim(rtp, d); - auto shape = rtp.getShape(); - if (!ShapedType::isDynamic(shape[dim])) - return constantIndex(builder, loc, shape[dim]); - - return desc.getDimSize(builder, loc, d); +static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc, + SparseTensorDescriptor desc, unsigned lvl) { + return sizeFromTensorAtDim(builder, loc, desc, + toOrigDim(desc.getTensorType(), lvl)); } static void createPushback(OpBuilder &builder, Location loc, @@ -174,7 +167,7 @@ // at this level. We will eventually reach a compressed level or // otherwise the values array for the from-here "all-dense" case. assert(isDenseDim(rtp, r)); - Value size = sizeAtStoredDim(builder, loc, desc, r); + Value size = sizeFromTensorAtLvl(builder, loc, desc, r); linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. @@ -436,7 +429,7 @@ // Construct the new position as: // pos[d] = size * pos[d-1] + i[d] // - Value size = sizeAtStoredDim(builder, loc, desc, d); + Value size = sizeFromTensorAtLvl(builder, loc, desc, d); Value mult = builder.create(loc, size, pos); pos = builder.create(loc, mult, indices[d]); } @@ -517,7 +510,7 @@ /// Generations insertion finalization code. static void genEndInsert(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc) { + SparseTensorDescriptor desc) { RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); for (unsigned d = 0; d < rank; d++) { @@ -654,10 +647,7 @@ auto desc = getDescriptorFromTensorTuple(adaptor.getSource()); auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index); - if (!sz) - return failure(); - - rewriter.replaceOp(op, *sz); + rewriter.replaceOp(op, sz); return success(); } }; @@ -727,8 +717,7 @@ // Replace the sparse tensor deallocation with field deallocations. 
Location loc = op.getLoc(); - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); for (auto input : desc.getMemRefFields()) // Deallocate every buffer used to store the sparse tensor handler. rewriter.create(loc, input); @@ -746,8 +735,7 @@ matchAndRewrite(LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Prepare descriptor. - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); // Generate optional insertion finalization code. if (op.getHasInserts()) genEndInsert(rewriter, op.getLoc(), desc); @@ -780,11 +768,10 @@ // recursively rewrite the new DimOp on the **original** tensor. unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1); auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); - assert(sz); // This for sure is a sparse tensor // Generate a memref for `sz` elements of type `t`. auto genAlloc = [&](Type t) { auto memTp = MemRefType::get({ShapedType::kDynamic}, t); - return rewriter.create(loc, memTp, ValueRange{*sz}); + return rewriter.create(loc, memTp, ValueRange{sz}); }; // Allocate temporary buffers for values/filled-switch and added. // We do not use stack buffers for this, since the expanded size may @@ -957,8 +944,7 @@ // Replace the requested pointer access with corresponding field. // The cast_op is inserted by type converter to intermix 1:N type // conversion. 
- SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); rewriter.replaceOp(op, desc.getAOSMemRef()); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -202,20 +202,9 @@ /// field in a consistent way. /// Users should not make assumption on how a sparse tensor is laid out but /// instead relies on this class to access the right value for the right field. -template +template class SparseTensorDescriptorImpl { protected: - // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl & - // for mutable descriptors. - // Using SmallVector for mutable descriptor allows users to reuse it as a tmp - // buffers to append value for some special cases, though users should be - // responsible to restore the buffer to legal states after their use. It is - // probably not a clean way, but it is the most efficient way to avoid copying - // the fields into another SmallVector. If a more clear way is wanted, we - // should change it to MutableArrayRef instead. - using ValueArrayRef = typename std::conditional &, - ValueRange>::type; - SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields) : rType(tp.cast()), fields(fields) { assert(getSparseTensorEncoding(tp) && @@ -223,8 +212,8 @@ fields.size()); // We should make sure the class is trivially copyable (and should be small // enough) such that we can pass it by value. 
- static_assert( - std::is_trivially_copyable_v>); + static_assert(std::is_trivially_copyable_v< + SparseTensorDescriptorImpl>); } public: @@ -262,12 +251,12 @@ Value getMemRefField(SparseTensorFieldKind kind, std::optional dim) const { - return fields[getMemRefFieldIndex(kind, dim)]; + return getField(getMemRefFieldIndex(kind, dim)); } Value getMemRefField(unsigned fidx) const { assert(fidx < fields.size() - 1); - return fields[fidx]; + return getField(fidx); } Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const { @@ -293,6 +282,31 @@ .getElementType(); } + Value getField(unsigned fidx) const { + assert(fidx < fields.size()); + return fields[fidx]; + } + + ValueRange getMemRefFields() const { + ValueRange ret = fields; + // Drop the last metadata fields. + return ret.slice(0, fields.size() - 1); + } + + std::pair + getIdxMemRefIndexAndStride(unsigned idxDim) const { + StorageLayout layout(getSparseTensorEncoding(rType)); + return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef, + idxDim); + } + + Value getAOSMemRef() const { + auto enc = getSparseTensorEncoding(rType); + unsigned cooStart = getCOOStart(enc); + assert(cooStart < enc.getDimLevelType().size()); + return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart); + } + RankedTensorType getTensorType() const { return rType; } ValueArrayRef getFields() const { return fields; } @@ -301,25 +315,38 @@ ValueArrayRef fields; }; -class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl { +/// Uses ValueRange for immutable descriptors; +class SparseTensorDescriptor : public SparseTensorDescriptorImpl { public: - MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers) - : SparseTensorDescriptorImpl(tp, buffers) {} + SparseTensorDescriptor(Type tp, ValueRange buffers) + : SparseTensorDescriptorImpl(tp, buffers) {} - Value getField(unsigned fidx) const { - assert(fidx < fields.size()); - return fields[fidx]; - } + Value getIdxMemRefOrView(OpBuilder &builder, 
Location loc, + unsigned idxDim) const; +}; - ValueRange getMemRefFields() const { - ValueRange ret = fields; - // Drop the last metadata fields. - return ret.slice(0, fields.size() - 1); +/// Uses SmallVectorImpl & for mutable descriptors. +/// Using SmallVector for mutable descriptor allows users to reuse it as a +/// tmp buffer to append values for some special cases, though users should +/// be responsible to restore the buffer to legal states after their use. It +/// is probably not a clean way, but it is the most efficient way to avoid +/// copying the fields into another SmallVector. If a more clear way is +/// wanted, we should change it to MutableArrayRef instead. +class MutSparseTensorDescriptor + : public SparseTensorDescriptorImpl &> { +public: + MutSparseTensorDescriptor(Type tp, SmallVectorImpl &buffers) + : SparseTensorDescriptorImpl &>(tp, buffers) {} + + // Allow implicit type conversion from mutable descriptors to immutable ones + // (but not vice versa). + /*implicit*/ operator SparseTensorDescriptor() const { + return SparseTensorDescriptor(rType, fields); } /// - /// Setters: update the value for required field (only enabled for - /// MutSparseTensorDescriptor). + /// Adds additional setters for mutable descriptor, update the value for + /// required field. 
/// void setMemRefField(SparseTensorFieldKind kind, std::optional dim, @@ -348,29 +375,6 @@ void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) { setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v); } - - std::pair - getIdxMemRefIndexAndStride(unsigned idxDim) const { - StorageLayout layout(getSparseTensorEncoding(rType)); - return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef, - idxDim); - } - - Value getAOSMemRef() const { - auto enc = getSparseTensorEncoding(rType); - unsigned cooStart = getCOOStart(enc); - assert(cooStart < enc.getDimLevelType().size()); - return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart); - } -}; - -class SparseTensorDescriptor : public SparseTensorDescriptorImpl { -public: - SparseTensorDescriptor(Type tp, ValueArrayRef buffers) - : SparseTensorDescriptorImpl(tp, buffers) {} - - Value getIdxMemRefOrView(OpBuilder &builder, Location loc, - unsigned idxDim) const; }; /// Returns the "tuple" value of the adapted tensor. @@ -386,7 +390,7 @@ } inline Value genTuple(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc) { + SparseTensorDescriptor desc) { return genTuple(builder, loc, desc.getTensorType(), desc.getFields()); }