diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -336,6 +336,9 @@
 /// Infers the result type and generates ToValuesOp.
 Value genToValues(OpBuilder &builder, Location loc, Value tensor);
 
+/// Generates code to retrieve the values size for the sparse tensor.
+Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
+#include "SparseTensorStorageLayout.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -551,4 +552,9 @@
   Type valTp = get1DMemRefType(srcTp.getElementType(),
                                /*withLayout=*/false);
   return builder.create<ToValuesOp>(loc, valTp, tensor);
-}
\ No newline at end of file
+}
+
+Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
+                                   Value tensor) {
+  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -968,8 +968,8 @@
   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Query memSizes for the actually stored values size.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
+    rewriter.replaceOp(
+        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -154,7 +154,7 @@
 /// instead relies on this class to access the right value for the right field.
 template <bool mut>
 class SparseTensorDescriptorImpl {
-private:
+protected:
   // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
   // for mutable descriptors.
   // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
@@ -262,68 +262,68 @@
     return fields[fidx];
   }
 
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata fields.
+    return ret.slice(0, fields.size() - 1);
+  }
+
+  Type getMemRefElementType(SparseTensorFieldKind kind,
+                            Optional<unsigned> dim) const {
+    return getMemRefField(kind, dim)
+        .getType()
+        .template cast<MemRefType>()
+        .getElementType();
+  }
+
+  RankedTensorType getTensorType() const { return rType; }
+  ValueArrayRef getFields() const { return fields; }
+
+protected:
+  RankedTensorType rType;
+  ValueArrayRef fields;
+};
+
+/// Mutable sparse tensor descriptor. Its fields are held through a
+/// SmallVectorImpl<Value> &, so the setters below update them in place.
+class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+public:
+  MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
+      : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+
   ///
-  /// Setters: update the value for required field (only enabled for
-  /// MutSparseTensorDescriptor).
+  /// Setters: update the value for required field.
   ///
 
-  template <typename T = Value>
   void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
-                      std::enable_if_t<mut, T> v) {
+                      Value v) {
     fields[getMemRefFieldIndex(kind, dim)] = v;
   }
 
-  template <typename T = Value>
-  void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
+  void setMemRefField(unsigned fidx, Value v) {
     assert(fidx < fields.size() - 1);
     fields[fidx] = v;
   }
 
-  template <typename T = Value>
-  void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
+  void setField(unsigned fidx, Value v) {
     assert(fidx < fields.size());
     fields[fidx] = v;
   }
 
-  template <typename T = Value>
   void setSpecifierField(OpBuilder &builder, Location loc,
                          StorageSpecifierKind kind, Optional<unsigned> dim,
-                         std::enable_if_t<mut, T> v) {
+                         Value v) {
     SparseTensorSpecifier md(fields.back());
     md.setSpecifierField(builder, loc, v, kind, dim);
     fields.back() = md;
   }
 
-  template <typename T = Value>
-  void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
-                  std::enable_if_t<mut, T> v) {
+  void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
     setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
   }
-
-  ValueRange getMemRefFields() const {
-    ValueRange ret = fields;
-    // drop the last metadata fields
-    return ret.slice(0, fields.size() - 1);
-  }
-
-  Type getMemRefElementType(SparseTensorFieldKind kind,
-                            Optional<unsigned> dim) const {
-    return getMemRefField(kind, dim)
-        .getType()
-        .template cast<MemRefType>()
-        .getElementType();
-  }
-
-  RankedTensorType getTensorType() const { return rType; }
-  ValueArrayRef getFields() const { return fields; }
-
-private:
-  RankedTensorType rType;
-  ValueArrayRef fields;
 };
 
 using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
 
 /// Returns the "tuple" value of the adapted tensor.
 inline UnrealizedConversionCastOp getTuple(Value tensor) {