diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -0,0 +1,200 @@
+//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for the sparse memory layout.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
+#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+///===----------------------------------------------------------------------===//
+/// The sparse tensor storage scheme for a tensor is organized as a single
+/// compound type with the following fields. Note that every memref with ? size
+/// actually behaves as a "vector", i.e. the stored size is the capacity and the
+/// used size resides in the storage_specifier struct.
+///
+/// struct {
+///   ; per-level l:
+///   ; if dense:
+///        <nothing>
+///   ; if compressed (or compressed-hi):
+///   memref<? x pos>  positions-l   ; positions for sparse level l
+///   memref<? x crd>  coordinates-l ; coordinates for sparse level l
+///   ; if singleton:
+///   memref<? x crd>  coordinates-l ; coordinates for singleton level l
+///
+///   memref<? x eltType> values      ; values
+///
+///   struct sparse_tensor.storage_specifier {
+///     array<lvlRank x int> lvlSizes ; sizes/cardinalities for each level
+///     array<n x int> memSizes       ; sizes/lengths for each data memref
+///   }
+/// };
+///
+/// In addition, for a "trailing COO region", defined as a compressed level
+/// followed by one or more singleton levels, the default SOA storage that
+/// is inherent to the TACO format is optimized into an AOS storage where
+/// all coordinates of a stored element appear consecutively. In such cases,
+/// a special operation (sparse_tensor.coordinates_buffer) must be used to
+/// access the AOS coordinates array. The function `getCOOStart` is used to
+/// find the start of the "trailing COO region".
+///
+/// If the sparse tensor is a slice (produced by a `tensor.extract_slice`
+/// operation), instead of allocating a new sparse tensor for it, the slice
+/// reuses the same set of MemRefs and attaches an additional set of slicing
+/// metadata holding the per-dimension slice offsets and strides.
+///
+/// Examples.
+///
+/// #CSR storage of 2-dim matrix yields
+/// memref<?xindex>                            ; positions-1
+/// memref<?xindex>                            ; coordinates-1
+/// memref<?xf64>                              ; values
+/// struct<(array<2 x i64>, array<3 x i64>)>   ; lvl0, lvl1, 3xsizes
+///
+/// #COO storage of 2-dim matrix yields
+/// memref<?xindex>,                           ; positions-0, essentially [0,sz]
+/// memref<?xindex>                            ; AOS coordinates storage
+/// memref<?xf64>                              ; values
+/// struct<(array<2 x i64>, array<3 x i64>)>   ; lvl0, lvl1, 3xsizes
+///
+/// Slice on #COO storage of 2-dim matrix yields
+/// ;; Inherited from the original sparse tensors
+/// memref<?xindex>,                           ; positions-0, essentially [0,sz]
+/// memref<?xindex>                            ; AOS coordinates storage
+/// memref<?xf64>                              ; values
+/// struct<(array<2 x i64>, array<3 x i64>,    ; lvl0, lvl1, 3xsizes
+/// ;; Extra slicing-metadata
+///         array<2 x i64>, array<2 x i64>)>   ; dim offset, dim stride.
+///
+///===----------------------------------------------------------------------===//
+
+/// The kinds of fields in the sparse tensor storage. Each memref kind has the
+/// same numeric value as the `StorageSpecifierKind` entry that records the
+/// size of that memref, so the two enums can be converted with a plain cast.
+enum class SparseTensorFieldKind : uint32_t {
+  StorageSpec = 0,
+  PosMemRef = static_cast<uint32_t>(StorageSpecifierKind::PosMemSize),
+  CrdMemRef = static_cast<uint32_t>(StorageSpecifierKind::CrdMemSize),
+  ValMemRef = static_cast<uint32_t>(StorageSpecifierKind::ValMemSize)
+};
+
+/// Converts a memref field kind into the specifier kind that records the
+/// memref's size. `StorageSpec` has no such counterpart and must not be used.
+inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
+  assert(kind != SparseTensorFieldKind::StorageSpec);
+  return static_cast<StorageSpecifierKind>(kind);
+}
+
+/// Converts a memory-size specifier kind into the corresponding memref field
+/// kind. `LvlSize` has no field counterpart and must not be used.
+inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
+  assert(kind != StorageSpecifierKind::LvlSize);
+  return static_cast<SparseTensorFieldKind>(kind);
+}
+
+/// The type of field indices. This alias is to help code be more
+/// self-documenting; unfortunately it is not type-checked, so it only
+/// provides documentation rather than doing anything to prevent mixups.
+using FieldIndex = unsigned;
+
+/// Provides methods to access fields of a sparse tensor with the given
+/// encoding.
+class StorageLayout {
+public:
+  // TODO: Functions/methods marked with [NUMFIELDS] should perhaps use
+  // `FieldIndex` for their return type, via the same reasoning for why
+  // `Dimension`/`Level` are used both for identifiers and ranks.
+  explicit StorageLayout(const SparseTensorType &stt)
+      : enc(stt.getEncoding()) {
+    assert(enc);
+  }
+  explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {
+    assert(enc);
+  }
+
+  /// For each field that will be allocated for the given sparse tensor
+  /// encoding, calls the callback with the corresponding field index,
+  /// field kind, level, and level-type (the last two are only meaningful
+  /// for level memrefs). The field index always starts with zero and
+  /// increments by one between each callback invocation; the iteration
+  /// stops as soon as the callback returns false. Ideally, all other methods
+  /// should rely on this function to query the fields of a sparse tensor
+  /// instead of relying on ad-hoc index computation.
+  void foreachField(
+      llvm::function_ref<bool(
+          FieldIndex /*fieldIdx*/, SparseTensorFieldKind /*fieldKind*/,
+          Level /*lvl (if applicable)*/, DimLevelType /*DLT (if applicable)*/)>)
+      const;
+
+  /// Gets the index of the memref field of the given `kind`. The level `lvl`
+  /// (required for coordinates) selects among the position/coordinate memrefs
+  /// and is ignored for the values memref; the coordinates of all levels in
+  /// the trailing COO region map to the shared AOS coordinates memref.
+  FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
+                                 std::optional<Level> lvl) const {
+    return getFieldIndexAndStride(kind, lvl).first;
+  }
+
+  /// Gets the total number of fields for the given sparse tensor encoding.
+  unsigned getNumFields() const;
+
+  /// Gets the total number of data fields (coordinate arrays, position
+  /// arrays, and a value array) for the given sparse tensor encoding.
+  unsigned getNumDataFields() const;
+
+  /// Same as `getMemRefFieldIndex`, but also returns the stride between the
+  /// coordinates of consecutive elements in the selected memref: for the
+  /// coordinates of a level inside the trailing COO region (AOS storage) it
+  /// is the number of levels in that region, and 1 otherwise.
+  std::pair<FieldIndex, unsigned /*stride*/>
+  getFieldIndexAndStride(SparseTensorFieldKind kind,
+                         std::optional<Level> lvl) const;
+
+private:
+  const SparseTensorEncodingAttr enc;
+};
+
+// TODO: See note [NUMFIELDS].
+inline unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  return StorageLayout(enc).getNumFields();
+}
+
+// TODO: See note [NUMFIELDS].
+inline unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  return StorageLayout(enc).getNumDataFields();
+}
+
+/// Convenience wrapper for `StorageLayout(enc).foreachField(callback)`.
+inline void foreachFieldInSparseTensor(
+    SparseTensorEncodingAttr enc,
+    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>
+        callback) {
+  StorageLayout(enc).foreachField(callback);
+}
+
+/// Same as `foreachFieldInSparseTensor`, except that the callback also
+/// receives the type of each field: a 1-D dynamically sized memref for the
+/// position, coordinate, and value arrays, and the storage specifier type for
+/// the metadata field.
+void foreachFieldAndTypeInSparseTensor(
+    SparseTensorType,
+    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>);
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -9,6 +9,7 @@
 #include <utility>
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -41,6 +42,149 @@
   return getRankedTensorType(t).getRank();
 }
 
+//===----------------------------------------------------------------------===//
+// sparse_tensor::StorageLayout
+//===----------------------------------------------------------------------===//
+/// Level sentinel passed to callbacks for fields that belong to no level.
+static constexpr Level kInvalidLevel = -1u;
+/// Field-index sentinel meaning "no such field has been found".
+static constexpr FieldIndex kInvalidFieldIndex = -1u;
+/// Index that `foreachField` assigns to the first field.
+static constexpr FieldIndex kDataFieldStartingIdx = 0;
+
+void StorageLayout::foreachField(
+    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>
+        callback) const {
+  // Invokes the callback and stops the whole iteration once it returns false.
+#define RETURN_ON_FALSE(fidx, kind, lvl, dlt)                                  \
+  if (!(callback(fidx, kind, lvl, dlt)))                                       \
+    return;
+
+  const auto lvlTypes = enc.getLvlTypes();
+  const Level lvlRank = enc.getLvlRank();
+  const Level cooStart = getCOOStart(enc);
+  // Levels after the start of the trailing COO region get no fields of their
+  // own: their coordinates live in the AOS memref of the COO start level.
+  const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
+  FieldIndex fieldIdx = kDataFieldStartingIdx;
+  // Per-level storage.
+  for (Level l = 0; l < end; l++) {
+    const auto dlt = lvlTypes[l];
+    if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
+    } else if (isSingletonDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
+    } else {
+      assert(isDenseDLT(dlt)); // no fields
+    }
+  }
+  // The values array.
+  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
+                  DimLevelType::Undef);
+
+  // Put metadata at the end.
+  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
+                  DimLevelType::Undef);
+
+#undef RETURN_ON_FALSE
+}
+
+void sparse_tensor::foreachFieldAndTypeInSparseTensor(
+    SparseTensorType stt,
+    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType)>
+        callback) {
+  assert(stt.hasEncoding());
+  // Construct the basic types.
+  const Type crdType = stt.getCrdType();
+  const Type posType = stt.getPosType();
+  const Type eltType = stt.getElementType();
+
+  const Type specType = StorageSpecifierType::get(stt.getEncoding());
+  // memref<? x pos>  positions
+  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
+  // memref<? x crd>  coordinates
+  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
+  // memref<? x eltType> values
+  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+
+  StorageLayout(stt).foreachField(
+      [specType, posMemType, crdMemType, valMemType,
+       callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
+                 Level lvl, DimLevelType dlt) -> bool {
+        switch (fieldKind) {
+        case SparseTensorFieldKind::StorageSpec:
+          return callback(specType, fieldIdx, fieldKind, lvl, dlt);
+        case SparseTensorFieldKind::PosMemRef:
+          return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
+        case SparseTensorFieldKind::CrdMemRef:
+          return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
+        case SparseTensorFieldKind::ValMemRef:
+          return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
+        }
+        llvm_unreachable("unrecognized field kind");
+      });
+}
+
+unsigned StorageLayout::getNumFields() const {
+  unsigned numFields = 0;
+  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
+                            DimLevelType) -> bool {
+    numFields++;
+    return true;
+  });
+  return numFields;
+}
+
+unsigned StorageLayout::getNumDataFields() const {
+  // Every field is a data memref (positions, coordinates, or values) except
+  // for the single storage specifier, which `foreachField` emits last.
+  unsigned numDataFields = 0;
+  foreachField([&numDataFields](FieldIndex, SparseTensorFieldKind fKind,
+                                Level, DimLevelType) -> bool {
+    if (fKind != SparseTensorFieldKind::StorageSpec)
+      numDataFields++;
+    return true;
+  });
+  assert(numDataFields + 1 == getNumFields());
+  return numDataFields;
+}
+
+std::pair<FieldIndex, unsigned>
+StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
+                                      std::optional<Level> lvl) const {
+  FieldIndex fieldIdx = kInvalidFieldIndex;
+  unsigned stride = 1;
+  if (kind == SparseTensorFieldKind::CrdMemRef) {
+    assert(lvl.has_value());
+    const Level cooStart = getCOOStart(enc);
+    const Level lvlRank = enc.getLvlRank();
+    // All levels of the trailing COO region share the AOS coordinates memref
+    // of the COO start level, which holds one coordinate per COO level for
+    // every stored element.
+    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
+      lvl = cooStart;
+      stride = lvlRank - cooStart;
+    }
+  }
+  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
+                                      SparseTensorFieldKind fKind, Level fLvl,
+                                      DimLevelType) -> bool {
+    // There is a single values memref, so it matches regardless of `lvl`.
+    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
+        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
+      fieldIdx = fIdx;
+      // Returns false to break the iteration.
+      return false;
+    }
+    return true;
+  });
+  assert(fieldIdx != kInvalidFieldIndex);
+  return std::pair(fieldIdx, stride);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -10,7 +10,7 @@ SparseTensorConversion.cpp SparseTensorPasses.cpp SparseTensorRewriting.cpp - SparseTensorStorageLayout.cpp + SparseTensorDescriptor.cpp SparseVectorization.cpp Sparsification.cpp SparsificationAndBufferizationPass.cpp diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "CodegenUtils.h" -#include "SparseTensorStorageLayout.h" +#include "SparseTensorDescriptor.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "CodegenUtils.h" -#include "SparseTensorStorageLayout.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" + #include using namespace mlir; @@ -262,7 +264,8 @@ std::optional lvl; if (op.getLevel()) lvl = (*op.getLevel()); - unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), lvl); + unsigned 
idx = + layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl); Value v = Base::onMemSize(rewriter, op, spec, idx); rewriter.replaceOp(op, v); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -16,7 +16,7 @@ //===----------------------------------------------------------------------===// #include "CodegenUtils.h" -#include "SparseTensorStorageLayout.h" +#include "SparseTensorDescriptor.h" #include "llvm/Support/FormatVariadic.h" diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h rename from mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h @@ -13,8 +13,8 @@ #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ -#include "mlir/Conversion/LLVMCommon/StructBuilder.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Transforms/DialectConversion.h" @@ -27,199 +27,8 @@ // layout scheme during "direct code generation" (i.e. when sparsification // generates the buffers as part of actual IR, in constrast with the library // approach where data structures are hidden behind opaque pointers). 
-// -// The sparse tensor storage scheme for a rank-dimensional tensor is organized -// as a single compound type with the following fields. Note that every memref -// with ? size actually behaves as a "vector", i.e. the stored size is the -// capacity and the used size resides in the storage_specifier struct. -// -// struct { -// ; per-level l: -// ; if dense: -// -// ; if compresed: -// memref positions-l ; positions for sparse level l -// memref coordinates-l ; coordinates for sparse level l -// ; if singleton: -// memref coordinates-l ; coordinates for singleton level l -// -// memref values ; values -// -// struct sparse_tensor.storage_specifier { -// array lvlSizes ; sizes/cardinalities for each level -// array memSizes; ; sizes/lengths for each data memref -// } -// }; -// -// In addition, for a "trailing COO region", defined as a compressed level -// followed by one or more singleton levels, the default SOA storage that -// is inherent to the TACO format is optimized into an AOS storage where -// all coordinates of a stored element appear consecutively. In such cases, -// a special operation (sparse_tensor.coordinates_buffer) must be used to -// access the AOS coordinates array. In the code below, the method `getCOOStart` -// is used to find the start of the "trailing COO region". -// -// If the sparse tensor is a slice (produced by `tensor.extract_slice` -// operation), instead of allocating a new sparse tensor for it, it reuses the -// same sets of MemRefs but attaching a additional set of slicing-metadata for -// per-dimension slice offset and stride. -// -// Examples. 
-// -// #CSR storage of 2-dim matrix yields -// memref ; positions-1 -// memref ; coordinates-1 -// memref ; values -// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes -// -// #COO storage of 2-dim matrix yields -// memref, ; positions-0, essentially [0,sz] -// memref ; AOS coordinates storage -// memref ; values -// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes -// -// Slice on #COO storage of 2-dim matrix yields -// ;; Inherited from the original sparse tensors -// memref, ; positions-0, essentially [0,sz] -// memref ; AOS coordinates storage -// memref ; values -// struct<(array<2 x i64>, array<3 x i64>, ; lvl0, lvl1, 3xsizes -// ;; Extra slicing-metadata -// array<2 x i64>, array<2 x i64>)>) ; dim offset, dim stride. -// //===----------------------------------------------------------------------===// -enum class SparseTensorFieldKind : uint32_t { - StorageSpec = 0, - PosMemRef = 1, - CrdMemRef = 2, - ValMemRef = 3 -}; - -static_assert(static_cast(SparseTensorFieldKind::PosMemRef) == - static_cast(StorageSpecifierKind::PosMemSize)); -static_assert(static_cast(SparseTensorFieldKind::CrdMemRef) == - static_cast(StorageSpecifierKind::CrdMemSize)); -static_assert(static_cast(SparseTensorFieldKind::ValMemRef) == - static_cast(StorageSpecifierKind::ValMemSize)); - -/// The type of field indices. This alias is to help code be more -/// self-documenting; unfortunately it is not type-checked, so it only -/// provides documentation rather than doing anything to prevent mixups. -using FieldIndex = unsigned; - -// TODO: Functions/methods marked with [NUMFIELDS] might should use -// `FieldIndex` for their return type, via the same reasoning for why -// `Dimension`/`Level` are used both for identifiers and ranks. - -/// For each field that will be allocated for the given sparse tensor -/// encoding, calls the callback with the corresponding field index, -/// field kind, level, and level-type (the last two are only for level -/// memrefs). 
The field index always starts with zero and increments -/// by one between each callback invocation. Ideally, all other methods -/// should rely on this function to query a sparse tensor fields instead -/// of relying on ad-hoc index computation. -void foreachFieldInSparseTensor( - SparseTensorEncodingAttr, - llvm::function_ref); - -/// Same as above, except that it also builds the Type for the corresponding -/// field. -void foreachFieldAndTypeInSparseTensor( - SparseTensorType, - llvm::function_ref); - -/// Gets the total number of fields for the given sparse tensor encoding. -// TODO: See note [NUMFIELDS]. -unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); - -/// Gets the total number of data fields (coordinate arrays, position -/// arrays, and a value array) for the given sparse tensor encoding. -// TODO: See note [NUMFIELDS]. -unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); - -inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) { - assert(kind != SparseTensorFieldKind::StorageSpec); - return static_cast(kind); -} - -inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) { - assert(kind != StorageSpecifierKind::LvlSize); - return static_cast(kind); -} - -/// Provides methods to access fields of a sparse tensor with the given -/// encoding. -class StorageLayout { -public: - explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {} - - /// - /// Getters: get the field index for required field. - /// - - FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, - std::optional lvl) const { - return getFieldIndexAndStride(kind, lvl).first; - } - - FieldIndex getMemRefFieldIndex(StorageSpecifierKind kind, - std::optional lvl) const { - return getMemRefFieldIndex(toFieldKind(kind), lvl); - } - - // TODO: See note [NUMFIELDS]. 
- static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { - return sparse_tensor::getNumFieldsFromEncoding(enc); - } - - static void foreachFieldInSparseTensor( - const SparseTensorEncodingAttr enc, - llvm::function_ref - callback) { - return sparse_tensor::foreachFieldInSparseTensor(enc, callback); - } - - std::pair - getFieldIndexAndStride(SparseTensorFieldKind kind, - std::optional lvl) const { - FieldIndex fieldIdx = -1u; - unsigned stride = 1; - if (kind == SparseTensorFieldKind::CrdMemRef) { - assert(lvl.has_value()); - const Level cooStart = getCOOStart(enc); - const Level lvlRank = enc.getLvlRank(); - if (lvl.value() >= cooStart && lvl.value() < lvlRank) { - lvl = cooStart; - stride = lvlRank - cooStart; - } - } - foreachFieldInSparseTensor( - enc, - [lvl, kind, &fieldIdx](FieldIndex fIdx, SparseTensorFieldKind fKind, - Level fLvl, DimLevelType dlt) -> bool { - if ((lvl && fLvl == lvl.value() && kind == fKind) || - (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { - fieldIdx = fIdx; - // Returns false to break the iteration. - return false; - } - return true; - }); - assert(fieldIdx != -1u); - return std::pair(fieldIdx, stride); - } - -private: - SparseTensorEncodingAttr enc; -}; - class SparseTensorSpecifier { public: explicit SparseTensorSpecifier(Value specifier) @@ -249,10 +58,12 @@ template class SparseTensorDescriptorImpl { protected: + // TODO: Functions/methods marked with [NUMFIELDS] might should use + // `FieldIndex` for their return type, via the same reasoning for why + // `Dimension`/`Level` are used both for identifiers and ranks. 
SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields) - : rType(stt), fields(fields) { - assert(stt.hasEncoding() && - getNumFieldsFromEncoding(stt.getEncoding()) == getNumFields()); + : rType(stt), fields(fields), layout(stt) { + assert(layout.getNumFields() == getNumFields()); // We should make sure the class is trivially copyable (and should be small // enough) such that we can pass it by value. static_assert(std::is_trivially_copyable_v< @@ -263,7 +74,6 @@ FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional lvl) const { // Delegates to storage layout. - StorageLayout layout(rType.getEncoding()); return layout.getMemRefFieldIndex(kind, lvl); } @@ -336,7 +146,6 @@ } std::pair getCrdMemRefIndexAndStride(Level lvl) const { - StorageLayout layout(rType.getEncoding()); return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl); } @@ -352,6 +161,7 @@ protected: SparseTensorType rType; ValueArrayRef fields; + StorageLayout layout; }; /// Uses ValueRange for immutable descriptors. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp rename from mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "SparseTensorStorageLayout.h" +#include "SparseTensorDescriptor.h" #include "CodegenUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -116,117 +116,3 @@ /*size=*/ValueRange{size}, /*step=*/ValueRange{stride}); } - -//===----------------------------------------------------------------------===// -// Public methods. 
-//===----------------------------------------------------------------------===// - -constexpr FieldIndex kDataFieldStartingIdx = 0; - -void sparse_tensor::foreachFieldInSparseTensor( - const SparseTensorEncodingAttr enc, - llvm::function_ref - callback) { - assert(enc); - -#define RETURN_ON_FALSE(fidx, kind, dim, dlt) \ - if (!(callback(fidx, kind, dim, dlt))) \ - return; - - const auto lvlTypes = enc.getLvlTypes(); - const Level lvlRank = enc.getLvlRank(); - const Level cooStart = getCOOStart(enc); - const Level end = cooStart == lvlRank ? cooStart : cooStart + 1; - FieldIndex fieldIdx = kDataFieldStartingIdx; - // Per-dimension storage. - for (Level l = 0; l < end; l++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. - const auto dlt = lvlTypes[l]; - if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt); - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt); - } else if (isSingletonDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt); - } else { - assert(isDenseDLT(dlt)); // no fields - } - } - // The values array. - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, - DimLevelType::Undef); - - // Put metadata at the end. - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u, - DimLevelType::Undef); - -#undef RETURN_ON_FALSE -} - -void sparse_tensor::foreachFieldAndTypeInSparseTensor( - SparseTensorType stt, - llvm::function_ref - callback) { - assert(stt.hasEncoding()); - // Construct the basic types. 
- const Type crdType = stt.getCrdType(); - const Type posType = stt.getPosType(); - const Type eltType = stt.getElementType(); - - const Type metaDataType = StorageSpecifierType::get(stt.getEncoding()); - // memref positions - const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType); - // memref coordinates - const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType); - // memref values - const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); - - foreachFieldInSparseTensor( - stt.getEncoding(), - [metaDataType, posMemType, crdMemType, valMemType, - callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind, - Level lvl, DimLevelType dlt) -> bool { - switch (fieldKind) { - case SparseTensorFieldKind::StorageSpec: - return callback(metaDataType, fieldIdx, fieldKind, lvl, dlt); - case SparseTensorFieldKind::PosMemRef: - return callback(posMemType, fieldIdx, fieldKind, lvl, dlt); - case SparseTensorFieldKind::CrdMemRef: - return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt); - case SparseTensorFieldKind::ValMemRef: - return callback(valMemType, fieldIdx, fieldKind, lvl, dlt); - }; - llvm_unreachable("unrecognized field kind"); - }); -} - -unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; - foreachFieldInSparseTensor(enc, - [&numFields](FieldIndex, SparseTensorFieldKind, - Level, DimLevelType) -> bool { - numFields++; - return true; - }); - return numFields; -} - -unsigned -sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; // one value memref - foreachFieldInSparseTensor(enc, - [&numFields](FieldIndex fidx, - SparseTensorFieldKind, Level, - DimLevelType) -> bool { - if (fidx >= kDataFieldStartingIdx) - numFields++; - return true; - }); - numFields -= 1; // the last field is MetaData field - assert(numFields == - getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1); - return numFields; -} diff --git 
a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2403,6 +2403,7 @@ srcs = ["lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp"], hdrs = [ "include/mlir/Dialect/SparseTensor/IR/SparseTensor.h", + "include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h", "include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h", ], includes = ["include"],