diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -593,7 +593,5 @@
 Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                    Value tensor) {
-  SmallVector<Value> fields;
-  auto desc = getMutDescriptorFromTensorTuple(tensor, fields);
-  return desc.getValMemSize(builder, loc);
-}
\ No newline at end of file
+  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -517,7 +517,7 @@
 /// Generates insertion finalization code.
 static void genEndInsert(OpBuilder &builder, Location loc,
-                         MutSparseTensorDescriptor desc) {
+                         SparseTensorDescriptor desc) {
   RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
   for (unsigned d = 0; d < rank; d++) {
@@ -727,8 +727,7 @@
     // Replace the sparse tensor deallocation with field deallocations.
     Location loc = op.getLoc();
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     for (auto input : desc.getMemRefFields())
       // Deallocate every buffer used to store the sparse tensor handler.
       rewriter.create<memref::DeallocOp>(loc, input);
@@ -746,8 +745,7 @@
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -957,8 +955,7 @@
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     rewriter.replaceOp(op, desc.getAOSMemRef());
 
     return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -202,20 +202,9 @@
 /// field in a consistent way.
 /// Users should not make assumption on how a sparse tensor is laid out but
 /// instead relies on this class to access the right value for the right field.
-template <bool mut>
+template <typename ValueArrayRef>
 class SparseTensorDescriptorImpl {
 protected:
-  // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
-  // for mutable descriptors.
-  // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
-  // buffers to append value for some special cases, though users should be
-  // responsible to restore the buffer to legal states after their use. It is
-  // probably not a clean way, but it is the most efficient way to avoid copying
-  // the fields into another SmallVector. If a more clear way is wanted, we
-  // should change it to MutableArrayRef instead.
-  using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
-                                                  ValueRange>::type;
-
   SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
       : rType(tp.cast<RankedTensorType>()), fields(fields) {
     assert(getSparseTensorEncoding(tp) &&
@@ -223,8 +212,8 @@
                fields.size());
     // We should make sure the class is trivially copyable (and should be small
     // enough) such that we can pass it by value.
-    static_assert(
-        std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+    static_assert(std::is_trivially_copyable_v<
+                  SparseTensorDescriptorImpl<ValueArrayRef>>);
   }
 
 public:
@@ -262,12 +251,12 @@
   Value getMemRefField(SparseTensorFieldKind kind,
                        std::optional<unsigned> dim) const {
-    return fields[getMemRefFieldIndex(kind, dim)];
+    return getField(getMemRefFieldIndex(kind, dim));
   }
 
   Value getMemRefField(unsigned fidx) const {
     assert(fidx < fields.size() - 1);
-    return fields[fidx];
+    return getField(fidx);
   }
 
   Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
@@ -293,6 +282,31 @@
         .getElementType();
   }
 
+  Value getField(unsigned fidx) const {
+    assert(fidx < fields.size());
+    return fields[fidx];
+  }
+
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata fields.
+    return ret.slice(0, fields.size() - 1);
+  }
+
+  std::pair<unsigned, unsigned>
+  getIdxMemRefIndexAndStride(unsigned idxDim) const {
+    StorageLayout layout(getSparseTensorEncoding(rType));
+    return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
+                                         idxDim);
+  }
+
+  Value getAOSMemRef() const {
+    auto enc = getSparseTensorEncoding(rType);
+    unsigned cooStart = getCOOStart(enc);
+    assert(cooStart < enc.getDimLevelType().size());
+    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
+  }
+
   RankedTensorType getTensorType() const { return rType; }
   ValueArrayRef getFields() const { return fields; }
 
@@ -301,25 +315,32 @@
   ValueArrayRef fields;
 };
 
-class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+/// Uses ValueRange for immutable descriptors.
+class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
 public:
-  MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
-      : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+  SparseTensorDescriptor(Type tp, ValueRange buffers)
+      : SparseTensorDescriptorImpl<ValueRange>(tp, buffers) {}
 
-  Value getField(unsigned fidx) const {
-    assert(fidx < fields.size());
-    return fields[fidx];
-  }
+  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
+                           unsigned idxDim) const;
+};
 
-  ValueRange getMemRefFields() const {
-    ValueRange ret = fields;
-    // Drop the last metadata fields.
-    return ret.slice(0, fields.size() - 1);
-  }
+/// Uses SmallVectorImpl<Value> & for mutable descriptors.
+/// Using SmallVector for the mutable descriptor allows users to reuse it as
+/// a temporary buffer to append values in some special cases, though users
+/// are responsible for restoring the buffer to a legal state after such use.
+/// It is probably not the cleanest way, but it is the most efficient way to
+/// avoid copying the fields into another SmallVector. If a cleaner way is
+/// wanted, we should change it to MutableArrayRef instead.
+class MutSparseTensorDescriptor
+    : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
+public:
+  MutSparseTensorDescriptor(Type tp, SmallVectorImpl<Value> &buffers)
+      : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(tp, buffers) {}
 
   ///
-  /// Setters: update the value for required field (only enabled for
-  /// MutSparseTensorDescriptor).
+  /// Additional setters, only available on the mutable descriptor: update
+  /// the value of the required field.
   ///
   void setMemRefField(SparseTensorFieldKind kind, std::optional<unsigned> dim,
                       Value v) {
@@ -348,29 +369,6 @@
   void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
     setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
   }
-
-  std::pair<unsigned, unsigned>
-  getIdxMemRefIndexAndStride(unsigned idxDim) const {
-    StorageLayout layout(getSparseTensorEncoding(rType));
-    return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
-                                         idxDim);
-  }
-
-  Value getAOSMemRef() const {
-    auto enc = getSparseTensorEncoding(rType);
-    unsigned cooStart = getCOOStart(enc);
-    assert(cooStart < enc.getDimLevelType().size());
-    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
-  }
-};
-
-class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
-public:
-  SparseTensorDescriptor(Type tp, ValueArrayRef buffers)
-      : SparseTensorDescriptorImpl<false>(tp, buffers) {}
-
-  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
-                           unsigned idxDim) const;
 };
 
 /// Returns the "tuple" value of the adapted tensor.
@@ -385,6 +383,11 @@
       .getResult(0);
 }
 
+inline Value genTuple(OpBuilder &builder, Location loc,
+                      SparseTensorDescriptor desc) {
+  return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
+}
+
 inline Value genTuple(OpBuilder &builder, Location loc,
                       MutSparseTensorDescriptor desc) {
   return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
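Note on the design this patch adopts: below is a minimal, self-contained C++ sketch of the idea, using hypothetical stand-in names (Value, DescriptorImpl, Descriptor, MutDescriptor), not the MLIR sources. The base class is parameterized directly on the field-storage reference type instead of on a bool plus a std::conditional typedef; the immutable flavor is a cheap view, so read-only call sites can drop the temporary SmallVector<Value> that the old getMutDescriptorFromTensorTuple pattern required, while both flavors remain trivially copyable.

#include <cassert>
#include <type_traits>
#include <vector>

using Value = int; // stand-in for mlir::Value

// Base class templated on the storage-reference type, mirroring
// SparseTensorDescriptorImpl<ValueArrayRef> after this patch.
template <typename ValueArrayRef>
class DescriptorImpl {
protected:
  explicit DescriptorImpl(ValueArrayRef fields) : fields(fields) {
    // Both instantiations hold only a reference/view of the caller's field
    // list, so the descriptor stays trivially copyable and cheap to pass by
    // value (the same property the patched static_assert checks).
    static_assert(std::is_trivially_copyable_v<DescriptorImpl<ValueArrayRef>>);
  }

public:
  // Read-only accessor shared by both flavors.
  Value getField(unsigned fidx) const {
    assert(fidx < fields.size());
    return fields[fidx];
  }

protected:
  ValueArrayRef fields;
};

// Immutable flavor: instantiated with a read-only view, so read-only users
// no longer materialize a temporary vector first.
class Descriptor : public DescriptorImpl<const std::vector<Value> &> {
public:
  explicit Descriptor(const std::vector<Value> &buffers)
      : DescriptorImpl(buffers) {}
};

// Mutable flavor: wraps a caller-owned vector so setters write through.
class MutDescriptor : public DescriptorImpl<std::vector<Value> &> {
public:
  explicit MutDescriptor(std::vector<Value> &buffers)
      : DescriptorImpl(buffers) {}
  void setField(unsigned fidx, Value v) {
    assert(fidx < fields.size());
    fields[fidx] = v;
  }
};

int main() {
  std::vector<Value> buffers = {1, 2, 3};
  Descriptor ro(buffers);    // cheap read-only view, no copy
  MutDescriptor rw(buffers); // mutations write through to `buffers`
  rw.setField(0, 42);
  assert(ro.getField(0) == 42);
  return 0;
}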