diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -968,7 +968,8 @@ matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values size. - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SmallVector fields; /* getValMemSize is now declared only on MutSparseTensorDescriptor, so this read-only query builds a mutable descriptor over the local `fields` buffer (presumably filled from the tensor tuple by getMutDescriptorFromTensorTuple -- confirm). */ + auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc())); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -153,7 +153,7 @@ /// instead relies on this class to access the right value for the right field. template class SparseTensorDescriptorImpl { -private: +protected: // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl & // for mutable descriptors. 
// Using SmallVector for mutable descriptor allows users to reuse it as a tmp @@ -219,21 +219,6 @@ return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim); } - Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const { - return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, - dim); - } - - Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const { - return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, - dim); - } - - Value getValMemSize(OpBuilder &builder, Location loc) const { - return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize, - std::nullopt); - } - Value getPtrMemRef(unsigned ptrDim) const { return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim); } @@ -261,25 +246,68 @@ return fields[fidx]; } + ValueRange getMemRefFields() const { + ValueRange ret = fields; + // Drop the trailing metadata (storage specifier) field; every field before it is a memref. + return ret.slice(0, fields.size() - 1); + } + + Type getMemRefElementType(SparseTensorFieldKind kind, + Optional dim) const { + return getMemRefField(kind, dim) + .getType() + .template cast() + .getElementType(); + } + + RankedTensorType getTensorType() const { return rType; } + ValueArrayRef getFields() const { return fields; } + +protected: + RankedTensorType rType; + ValueArrayRef fields; +}; + +class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl { /* Mutable descriptor: adds the memory-size getters and the field/specifier setters on top of the shared read-only accessors. */ +public: + MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers) + : SparseTensorDescriptorImpl(tp, buffers) {} + + /// + /// Getters: get the value for required field. 
+ /// + + Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, + dim); + } + + Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, + dim); + } + + Value getValMemSize(OpBuilder &builder, Location loc) const { + return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize, + std::nullopt); + } + /// /// Setters: update the value for required field (only enabled for /// MutSparseTensorDescriptor). /// template - void setMemRefField(SparseTensorFieldKind kind, Optional dim, - std::enable_if_t v) { + void setMemRefField(SparseTensorFieldKind kind, Optional dim, T v) { /* The enable_if guard is dropped: these setters are restricted to mutable descriptors only by being members of MutSparseTensorDescriptor. */ fields[getMemRefFieldIndex(kind, dim)] = v; } - template - void setMemRefField(unsigned fidx, std::enable_if_t v) { + template void setMemRefField(unsigned fidx, T v) { assert(fidx < fields.size() - 1); fields[fidx] = v; } - template - void setField(unsigned fidx, std::enable_if_t v) { + template void setField(unsigned fidx, T v) { assert(fidx < fields.size()); fields[fidx] = v; } @@ -287,42 +315,19 @@ template void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, Optional dim, - std::enable_if_t v) { + T v) { SparseTensorSpecifier md(fields.back()); md.setSpecifierField(builder, loc, v, kind, dim); fields.back() = md; } template - void setDimSize(OpBuilder &builder, Location loc, unsigned dim, - std::enable_if_t v) { + void setDimSize(OpBuilder &builder, Location loc, unsigned dim, T v) { setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v); } - - ValueRange getMemRefFields() const { - ValueRange ret = fields; - // drop the last metadata fields - return ret.slice(0, fields.size() - 1); - } - - Type getMemRefElementType(SparseTensorFieldKind kind, - Optional dim) const { - return getMemRefField(kind, dim) - .getType() - .template cast() - 
.getElementType(); - } - - RankedTensorType getTensorType() const { return rType; } - ValueArrayRef getFields() const { return fields; } - -private: - RankedTensorType rType; - ValueArrayRef fields; }; using SparseTensorDescriptor = SparseTensorDescriptorImpl; -using MutSparseTensorDescriptor = SparseTensorDescriptorImpl; /// Returns the "tuple" value of the adapted tensor. inline UnrealizedConversionCastOp getTuple(Value tensor) {