diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -18,6 +18,49 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +//===----------------------------------------------------------------------===// +// +// Type aliases to help code be more self-documenting. Unfortunately +// these are not type-checked, so they only provide documentation rather +// than doing anything to prevent mixups. +// +// We must include these here (rather than in "SparseTensorType.h") +// because they are used by methods declared in the tablegen files. +// +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace sparse_tensor { + +/// The type of dimension identifiers, and dimension-ranks. We use the +/// same type for both identifiers and ranks because the latter are used +/// mainly for ordering-comparisons against the former (just like how the +/// one-past-the-end iterators are used). +using Dimension = uint64_t; + +/// The type of level identifiers, and level-ranks. We use the same +/// type for both identifiers and ranks because the latter are used +/// mainly for ordering-comparisons against the former (just like how +/// the one-past-the-end iterators are used). +using Level = uint64_t; + +/// The type for individual components of a compile-time shape. We avoid +/// calling this "size" because we use the term "sizes" to indicate the +/// actual run-time sizes, whereas this type also allows the value +/// `ShapedType::kDynamic`. +using DynSize = int64_t; + +/// The type for individual components of a compile-time shape which +/// are known not to be `ShapedType::kDynamic`. +using StaticSize = int64_t; + +} // namespace sparse_tensor +} // namespace mlir + +//===----------------------------------------------------------------------===// +// TableGen-defined classes +//===----------------------------------------------------------------------===// + // We must include Enums.h.inc before AttrDefs.h.inc due to dependency between // StorageSpecifierKindAttr and StorageSpeciferKind Enum. @@ -35,6 +78,10 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.h.inc" +//===----------------------------------------------------------------------===// +// Additional convenience methods. +//===----------------------------------------------------------------------===// + namespace mlir { namespace sparse_tensor { @@ -54,14 +101,14 @@ /// Returns null-attribute for any type without an encoding. SparseTensorEncodingAttr getSparseTensorEncoding(Type type); -/// Returns true iff the given type is a type for a COO tensor with the last -/// dimension level type being unique. +/// Returns true iff the given type is a COO type where the last level +/// is unique. bool isUniqueCOOType(TensorType tp); -/// Returns the starting dimension for a trailing COO region that spans across -/// at least two dimensions. If no such COO region is found, returns the rank -/// of the tensor. -unsigned getCOOStart(SparseTensorEncodingAttr enc); +/// Returns the starting level for a trailing COO region that spans +/// at least two levels. If no such COO region is found, then returns +/// the level-rank. +Level getCOOStart(SparseTensorEncodingAttr enc); /// Helpers to setup a COO type. 
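The new `Level` alias and `getCOOStart` compose naturally at call sites. The sketch below is illustrative only (the helper `hasTrailingCOO` is not part of this change) and assumes exactly the declarations above.

```cpp
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

/// Illustrative helper (not part of this patch): returns true iff the
/// encoding has a trailing COO region.  Because getCOOStart returns the
/// level-rank when no such region exists, a single comparison suffices.
static bool hasTrailingCOO(SparseTensorEncodingAttr enc) {
  if (!enc)
    return false; // unannotated (dense) tensors have no COO region
  const Level cooStart = getCOOStart(enc);
  return cooStart < enc.getLvlRank();
}
```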
RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src, @@ -69,88 +116,33 @@ RankedTensorType getCOOFromType(RankedTensorType src, bool ordered); -// -// Dimension level types. -// - -// MSVC does not allow this function to be constexpr, because -// `SparseTensorEncodingAttr::operator bool` isn't declared constexpr. -// And therefore all functions calling it cannot be constexpr either. -// TODO: since Clang does allow these to be constexpr, perhaps we should -// define a macro to abstract over `inline` vs `constexpr` annotations. -inline DimLevelType getDimLevelType(SparseTensorEncodingAttr enc, uint64_t d) { - if (enc) { - auto types = enc.getDimLevelType(); - assert(d < types.size() && "Dimension out of bounds"); - return types[d]; - } - return DimLevelType::Dense; // unannotated tensor is dense -} - -inline DimLevelType getDimLevelType(RankedTensorType type, uint64_t d) { - return getDimLevelType(getSparseTensorEncoding(type), d); -} - -/// Convenience function to test for dense dimension (0 <= d < rank). -inline bool isDenseDim(RankedTensorType type, uint64_t d) { - return isDenseDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for compressed dimension (0 <= d < rank). -inline bool isCompressedDim(RankedTensorType type, uint64_t d) { - return isCompressedDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for singleton dimension (0 <= d < rank). -inline bool isSingletonDim(RankedTensorType type, uint64_t d) { - return isSingletonDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for dense dimension (0 <= d < rank). -inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isDenseDLT(getDimLevelType(enc, d)); -} - -/// Convenience function to test for compressed dimension (0 <= d < rank). -inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isCompressedDLT(getDimLevelType(enc, d)); -} - -/// Convenience function to test for singleton dimension (0 <= d < rank). -inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isSingletonDLT(getDimLevelType(enc, d)); -} - -// -// Dimension level properties. -// - -/// Convenience function to test for ordered property in the -/// given dimension (0 <= d < rank). -inline bool isOrderedDim(RankedTensorType type, uint64_t d) { - return isOrderedDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for unique property in the -/// given dimension (0 <= d < rank). -inline bool isUniqueDim(RankedTensorType type, uint64_t d) { - return isUniqueDLT(getDimLevelType(type, d)); -} - // // Reordering. // -uint64_t toOrigDim(SparseTensorEncodingAttr enc, uint64_t d); -uint64_t toStoredDim(SparseTensorEncodingAttr enc, uint64_t d); +// This CPP guard is to disable deprecation warnings for the LLVM +// build-bot, while making it easy to re-enable it for local development. +#if 0 +#define DEPRECATED \ + LLVM_DEPRECATED("The toOrigDim/toStoredDim functions are deprecated " \ + "because they only work for permutations; therefore any " \ + "code using them cannot support non-permutations.", \ + "") +#else +#define DEPRECATED +#endif -/// Convenience method to translate the given stored dimension -/// to the original dimension (0 <= d < rank). -uint64_t toOrigDim(RankedTensorType type, uint64_t d); +/// [deprecated] Convenience method to translate the given level to the +/// corresponding dimension. Requires: `0 <= l < lvlRank`. 
+DEPRECATED Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l); +DEPRECATED Dimension toOrigDim(RankedTensorType type, Level l); -/// Convenience method to translate the given original dimension -/// to the stored dimension (0 <= d < rank). -uint64_t toStoredDim(RankedTensorType type, uint64_t d); +/// [deprecated] Convenience method to translate the given dimension to +/// the corresponding level. Requires: `0 <= d < dimRank`. +DEPRECATED Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d); +DEPRECATED Level toStoredDim(RankedTensorType type, Dimension d); + +#undef DEPRECATED } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -260,22 +260,45 @@ /// reset to the default/identity. SparseTensorEncodingAttr withoutOrdering() const; - /// Return true if every level is dense in the encoding. + /// Returns true if every level is dense. Also returns true for + /// the null encoding (since dense-tensors are always all-dense). bool isAllDense() const; - /// Return true if the encoding has an identity dimension ordering. + /// Returns true if every level is ordered. Also returns true for + /// the null encoding (since dense-tensors are always all-ordered). + bool isAllOrdered() const; + + /// Returns true if the encoding has an identity dimension ordering. + /// Also returns true for the null encoding (since dense-tensors + /// always have the identity ordering). bool hasIdDimOrdering() const; + /// Returns the number of storage levels. Asserts that the encoding + /// is non-null (since there is no fixed result that's valid for + /// every dense-tensor). + ::mlir::sparse_tensor::Level getLvlRank() const; + + /// Safely looks up the level-type for the requested level. (Returns + /// `DimLevelType::Dense` for the null encoding, since dense-tensors + /// are always all-dense.) 
+ ::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const; + + bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); } + bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); } + bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); } + bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); } + bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); } + bool isSlice() const { return !getDimSlices().empty(); } - std::optional getStaticDimSliceOffset(unsigned dim) const; - std::optional getStaticDimSliceSize(unsigned dim) const; - std::optional getStaticDimSliceStride(unsigned dim) const; - std::optional getStaticLvlSliceOffset(unsigned lvl) const; - std::optional getStaticLvlSliceSize(unsigned lvl) const; - std::optional getStaticLvlSliceStride(unsigned lvl) const; + std::optional getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const; + std::optional getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const; + std::optional getStaticDimSliceStride(::mlir::sparse_tensor::Dimension dim) const; + std::optional getStaticLvlSliceOffset(::mlir::sparse_tensor::Level lvl) const; + std::optional getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const; + std::optional getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const; }]; let genVerifyDecl = 1; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -0,0 +1,232 @@ +//===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines the `SparseTensorType` wrapper class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ +#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ + +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" + +namespace mlir { +namespace sparse_tensor { + +//===----------------------------------------------------------------------===// +/// A wrapper around `RankedTensorType`, which has three goals: +/// +/// (1) To provide a uniform API for querying aspects of sparse-tensor +/// types; in particular, to make the "dimension" vs "level" distinction +/// overt (i.e., explicit everywhere). Thus, throughout the sparse-compiler +/// this class should be preferred over using `RankedTensorType` or +/// `ShapedType` directly, since the methods of the latter do not make +/// the "dimension" vs "level" distinction overt. +/// +/// (2) To provide a uniform abstraction over both sparse-tensor +/// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`) +/// and dense-tensor types (i.e., `RankedTensorType` without an encoding). 
+/// That is, we want to manipulate dense-tensor types using the same API +/// that we use for manipulating sparse-tensor types; both to keep the +/// "dimension" vs "level" distinction overt, and to avoid needing to +/// handle certain cases specially in the sparse-compiler. +/// +/// (3) To provide uniform handling of "defaults". In particular +/// this means that dense-tensors should always return the same answers +/// as sparse-tensors with a default encoding. But it additionally means +/// that the answers should be normalized, so that there's no way to +/// distinguish between non-provided data (which is filled in by default) +/// vs explicitly-provided data which equals the defaults. +/// +class SparseTensorType { +public: + // We memoize `lvlRank` and `dim2lvl` to avoid repeating the + // conditionals throughout the rest of the class. + SparseTensorType(RankedTensorType rtp) + : rtp(rtp), enc(getSparseTensorEncoding(rtp)), + lvlRank(enc ? enc.getLvlRank() : getDimRank()), + dim2lvl(enc.hasIdDimOrdering() ? AffineMap() : enc.getDimOrdering()) { + assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch"); + } + + SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc) + : SparseTensorType( + RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {} + + /// Constructs a new `SparseTensorType` with the same dimension-shape + /// and element type, but with the encoding replaced by the given encoding. + SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const { + return SparseTensorType(rtp, newEnc); + } + + /// Constructs a new `SparseTensorType` with the same dimension-shape + /// and element type, but with the encoding replaced by + /// `getEncoding().withoutOrdering()`. + SparseTensorType withoutOrdering() const { + return withEncoding(enc.withoutOrdering()); + } + + /// Allow implicit conversion to `RankedTensorType`, `ShapedType`, + /// and `Type`. These are implicit to help alleviate the impedance + /// mismatch for code that has not been converted to use `SparseTensorType` + /// directly. Once more of the sparse compiler has been converted to + /// using `SparseTensorType`, we may want to make these explicit instead. + /// + /// WARNING: This user-defined-conversion method causes overload + /// ambiguity whenever passing a `SparseTensorType` directly to a + /// function which is overloaded to accept either `Type` or `TypeRange`. + /// In particular, this includes `RewriterBase::replaceOpWithNewOp` + /// and `OpBuilder::create` whenever the `OpTy::build` is overloaded + /// thus. This happens because the `TypeRange(T&&)` ctor is implicit + /// as well, and there's no SFINAE we can add to this method that would + /// block subsequent application of that ctor. The only way to fix the + /// overload ambiguity is to avoid *implicit* conversion at the callsite: + /// e.g., by using `static_cast` to make the conversion explicit, by + /// assigning the `SparseTensorType` to a temporary variable of the + /// desired type, etc. + // + // NOTE: We implement this as a single templated user-defined-conversion + // function to avoid ambiguity problems when the desired result is `Type` + // (since both `RankedTensorType` and `ShapedType` can be implicitly + // converted to `Type`). + template >> + /*implicit*/ operator T() const { + return rtp; + } + + /// Explicitly convert to `RankedTensorType`. This method is + /// a convenience for resolving overload-ambiguity issues with + /// implicit conversion. 
+ RankedTensorType getRankedTensorType() const { return rtp; } + + MLIRContext *getContext() const { return rtp.getContext(); } + + Type getElementType() const { return rtp.getElementType(); } + + /// Returns the encoding (or the null-attribute for dense-tensors). + SparseTensorEncodingAttr getEncoding() const { return enc; } + + /// Returns true for tensors which have an encoding, and false for + /// those which do not. Therefore tensors with an all-dense encoding + /// return true. + bool hasEncoding() const { return static_cast(enc); } + + /// Returns true for tensors where every level is dense. + /// (This is always true for dense-tensors.) + bool isAllDense() const { return enc.isAllDense(); } + + /// Returns true for tensors where every level is ordered. + /// (This is always true for dense-tensors.) + bool isAllOrdered() const { return enc.isAllOrdered(); } + + /// Returns true if the dimToLvl mapping is the identity. + /// (This is always true for dense-tensors.) + bool isIdentity() const { return !dim2lvl; } + + /// Returns the dimToLvl mapping (or the null-map for the identity). + AffineMap getDimToLvlMap() const { return dim2lvl; } + + /// Returns the dimToLvl mapping, where the identity map is expanded out + /// into a full `AffineMap`. This method is provided as a convenience, + /// but for most purposes other methods (`isIdentity`, `getDimToLvlMap`, + /// etc) will be more helpful. + AffineMap getExpandedDimToLvlMap() const { + return dim2lvl + ? dim2lvl + : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext()); + } + + /// Returns the dimension-rank. + Dimension getDimRank() const { return rtp.getRank(); } + + /// Returns the level-rank. + Level getLvlRank() const { return lvlRank; } + + /// Returns the dimension-shape. + ArrayRef getDimShape() const { return rtp.getShape(); } + + /// Safely looks up the requested dimension-DynSize. If you intend + /// to check the result with `ShapedType::isDynamic`, then see the + /// `getStaticDimSize` method instead. + DynSize getDynamicDimSize(Dimension d) const { + assert(d < getDimRank() && "Dimension is out of bounds"); + return getDimShape()[d]; + } + + /// Safely looks up the requested dimension-size, mapping dynamic + /// sizes to `std::nullopt`. + std::optional getStaticDimSize(Dimension d) const { + const DynSize sh = getDynamicDimSize(d); + return ShapedType::isDynamic(sh) ? std::nullopt + : std::optional(sh); + } + + /// Returns true if no dimension has dynamic size. + bool hasStaticDimShape() const { return rtp.hasStaticShape(); } + + /// Returns true if any dimension has dynamic size. + bool hasDynamicDimShape() const { return !hasStaticDimShape(); } + + /// Returns true if the given dimension has dynamic size. If you + /// intend to call `getDynamicDimSize` based on the result, then see + /// the `getStaticDimSize` method instead. + bool isDynamicDim(Dimension d) const { + // We don't use `rtp.isDynamicDim(d)` because we want the + // OOB error message to be consistent with `getDynamicDimSize`. + return ShapedType::isDynamic(getDynamicDimSize(d)); + } + + /// Returns the number of dimensions which have dynamic sizes. + /// The return type is `int64_t` to maintain consistency with + /// `ShapedType::Trait::getNumDynamicDims`. + int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); } + + DimLevelType getLvlType(Level l) const { + // This OOB check is for dense-tensors, since this class knows + // their lvlRank (whereas STEA::getLvlType will/can only check + // OOB for sparse-tensors). 
+ assert(l < lvlRank && "Level out of bounds"); + return enc.getLvlType(l); + } + + // We can't just delegate these, since we want to use this class's + // `getLvlType` method instead of STEA's. + bool isDenseLvl(Level l) const { return isDenseDLT(getLvlType(l)); } + bool isCompressedLvl(Level l) const { return isCompressedDLT(getLvlType(l)); } + bool isSingletonLvl(Level l) const { return isSingletonDLT(getLvlType(l)); } + bool isOrderedLvl(Level l) const { return isOrderedDLT(getLvlType(l)); } + bool isUniqueLvl(Level l) const { return isUniqueDLT(getLvlType(l)); } + + /// Returns the index-overhead bitwidth, defaulting to zero. + unsigned getIndexBitWidth() const { return enc ? enc.getIndexBitWidth() : 0; } + + /// Returns the pointer-overhead bitwidth, defaulting to zero. + unsigned getPointerBitWidth() const { + return enc ? enc.getPointerBitWidth() : 0; + } + +private: + // These two must be const, to ensure coherence of the memoized fields. + const RankedTensorType rtp; + const SparseTensorEncodingAttr enc; + // Memoized to avoid frequent redundant conditionals. + const Level lvlRank; + const AffineMap dim2lvl; +}; + +/// Convenience method to abbreviate wrapping `getRankedTensorType`. +template +inline SparseTensorType getSparseTensorType(T t) { + return SparseTensorType(getRankedTensorType(t)); +} + +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -70,13 +70,13 @@ } intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { - return unwrap(attr).cast().getDimLevelType().size(); + return unwrap(attr).cast().getLvlRank(); } MlirSparseTensorDimLevelType -mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) { +mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) { return static_cast( - unwrap(attr).cast().getDimLevelType()[pos]); + unwrap(attr).cast().getLvlType(lvl)); } int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) { diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -9,6 +9,7 @@ #include #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" @@ -33,8 +34,10 @@ // Additional convenience methods. //===----------------------------------------------------------------------===// +/// Gets the dimension-rank of the type of some `T`. (In particular +/// this is only used for `Value` and `TypedValue`.) 
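As a usage sketch of the wrapper (illustrative only; `countSparseLevels` is not part of this change, and `tensor` is assumed to be a value of ranked tensor type), the same query path serves dense and sparse tensors alike, and the explicit conversion at the end sidesteps the overload ambiguity described in the class comment.

```cpp
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

static unsigned countSparseLevels(Value tensor) {
  const SparseTensorType stt = getSparseTensorType(tensor);
  // Dimension-shape queries stay in the dimension world...
  unsigned numDynamic = 0;
  for (Dimension d = 0, dimRank = stt.getDimRank(); d < dimRank; d++)
    if (!stt.getStaticDimSize(d))
      numDynamic++; // dynamically-sized dimension
  (void)numDynamic;
  // ...while storage queries live in the level world.
  unsigned numSparse = 0;
  for (Level l = 0, lvlRank = stt.getLvlRank(); l < lvlRank; l++)
    if (!stt.isDenseLvl(l))
      numSparse++;
  // When passing the wrapper to an API overloaded on both Type and TypeRange,
  // convert explicitly to sidestep the ambiguity described in the class docs.
  RankedTensorType rtt = stt.getRankedTensorType();
  (void)rtt;
  return numSparse; // always zero for a tensor without an encoding
}
```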
template -static inline int64_t getTypeRank(T t) { +static inline Dimension getDimRank(T t) { return getRankedTensorType(t).getRank(); } @@ -132,40 +135,59 @@ } bool SparseTensorEncodingAttr::isAllDense() const { - return llvm::all_of(getDimLevelType(), isDenseDLT); + return !getImpl() || llvm::all_of(getDimLevelType(), isDenseDLT); +} + +bool SparseTensorEncodingAttr::isAllOrdered() const { + return !getImpl() || llvm::all_of(getDimLevelType(), isOrderedDLT); } bool SparseTensorEncodingAttr::hasIdDimOrdering() const { - return !getDimOrdering() || getDimOrdering().isIdentity(); + return !getImpl() || !getDimOrdering() || getDimOrdering().isIdentity(); +} + +Level SparseTensorEncodingAttr::getLvlRank() const { + assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); + return getDimLevelType().size(); +} + +DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const { + if (!getImpl()) + return DimLevelType::Dense; + assert(l < getLvlRank() && "Level is out of bounds"); + return getDimLevelType()[l]; } std::optional -SparseTensorEncodingAttr::getStaticDimSliceOffset(unsigned dim) const { +SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { return getDimSlices()[dim].getStaticOffset(); } std::optional -SparseTensorEncodingAttr::getStaticDimSliceSize(unsigned dim) const { +SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const { return getDimSlices()[dim].getStaticSize(); } std::optional -SparseTensorEncodingAttr::getStaticDimSliceStride(unsigned dim) const { +SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { return getDimSlices()[dim].getStaticStride(); } std::optional -SparseTensorEncodingAttr::getStaticLvlSliceOffset(unsigned lvl) const { +SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { + // FIXME: `toOrigDim` is deprecated. return getStaticDimSliceOffset(toOrigDim(*this, lvl)); } std::optional -SparseTensorEncodingAttr::getStaticLvlSliceSize(unsigned lvl) const { +SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const { + // FIXME: `toOrigDim` is deprecated. return getStaticDimSliceSize(toOrigDim(*this, lvl)); } std::optional -SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const { +SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { + // FIXME: `toOrigDim` is deprecated. return getStaticDimSliceStride(toOrigDim(*this, lvl)); } @@ -296,11 +318,9 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { // Print the struct-like storage in dictionary fashion. printer << "<{ dimLevelType = [ "; - for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { - printer << "\"" << toMLIRString(getDimLevelType()[i]) << "\""; - if (i != e - 1) - printer << ", "; - } + llvm::interleaveComma(getDimLevelType(), printer, [&](DimLevelType dlt) { + printer << "\"" << toMLIRString(dlt) << "\""; + }); printer << " ]"; // Print remaining members only for non-default values. if (!hasIdDimOrdering()) @@ -334,11 +354,19 @@ return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; if (!acceptBitWidth(indexBitWidth)) return emitError() << "unexpected index bitwidth: " << indexBitWidth; + // Before we can check that the level-rank is consistent/coherent + // across all fields, we need to define it. The source-of-truth for + // the `getLvlRank` method is the length of the level-types array, + // since it must always be provided and have full rank; therefore we + // use that same source-of-truth here. 
+ const Level lvlRank = dimLevelType.size(); + if (lvlRank == 0) + return emitError() << "expected a non-empty array for level types"; if (dimOrdering) { if (!dimOrdering.isPermutation()) return emitError() << "expected a permutation affine map for dimension ordering"; - if (dimOrdering.getNumResults() != dimLevelType.size()) + if (dimOrdering.getNumResults() != lvlRank) return emitError() << "unexpected mismatch in ordering and dimension " "level types size"; } @@ -347,11 +375,11 @@ return emitError() << "unexpected higher ordering mapping from " << higherOrdering.getNumDims() << " to " << higherOrdering.getNumResults(); - if (higherOrdering.getNumResults() != dimLevelType.size()) + if (higherOrdering.getNumResults() != lvlRank) return emitError() << "unexpected mismatch in higher ordering and " "dimension level types size"; } - if (!dimSlices.empty() && dimSlices.size() != dimLevelType.size()) { + if (!dimSlices.empty() && dimSlices.size() != lvlRank) { return emitError() << "unexpected mismatch in dimension slices and " "dimension level type size"; } @@ -364,32 +392,28 @@ } LogicalResult SparseTensorEncodingAttr::verifyEncoding( - ArrayRef shape, Type elementType, + ArrayRef dimShape, Type elementType, function_ref emitError) const { - // Check structural integrity. + // Check structural integrity. In particular, this ensures that the + // level-rank is coherent across all the fields. RETURN_FAILURE_IF_FAILED(verify( emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(), getPointerBitWidth(), getIndexBitWidth(), getDimSlices())) - // Check integrity with tensor type specifics. Dimension ordering is optional, - // but we always should have dimension level types for the full rank. - unsigned size = shape.size(); - if (size == 0) + // Check integrity with tensor type specifics. In particular, we + // need only check that the dimension-rank of the tensor agrees with + // the dimension-rank of the encoding. + const Dimension dimRank = dimShape.size(); + if (dimRank == 0) return emitError() << "expected non-scalar sparse tensor"; - if (getHigherOrdering()) { - if (getHigherOrdering().getNumDims() != size) - return emitError() << "expected an affine map of size " << size - << " for higher ordering"; - + if (const auto higherOrdering = getHigherOrdering()) { + if (higherOrdering.getNumDims() != dimRank) + return emitError() << "expected an affine map with " << dimRank + << " dimensions for higher ordering"; // TODO: verification of higher ordering contents - - size = getHigherOrdering().getNumResults(); // higher-order size! - } - if (getDimOrdering() && getDimOrdering().getNumResults() != size) - return emitError() << "expected an affine map of size " << size - << " for dimension ordering"; - if (getDimLevelType().size() != size) - return emitError() << "expected an array of size " << size + } else if (dimRank != getLvlRank()) { + return emitError() << "expected an array of size " << dimRank << " for dimension level types"; + } return success(); } @@ -407,69 +431,72 @@ } /// Returns true iff the given sparse tensor encoding attribute has a trailing -/// COO region starting at the given dimension. -static bool isCOOType(SparseTensorEncodingAttr enc, uint64_t s, bool isUnique) { - uint64_t rank = enc.getDimLevelType().size(); - assert(s < rank && "Dimension out of bounds"); - if (!isCompressedDim(enc, s)) +/// COO region starting at the given level. 
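Two hedged sketches of the behavior pinned down above (illustrative only, not part of this change): the defaults supplied by the null encoding, and the construction of an encoding whose fields are coherent with its level-rank. The `get()` overload used below is the same one this patch calls in `getCOOFromTypeWithOrdering`; the level types and bitwidths are arbitrary example values.

```cpp
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

using namespace mlir;
using namespace mlir::sparse_tensor;

/// Illustrative only: the defaults provided by the null encoding, i.e. an
/// unannotated (dense) tensor.
static void nullEncodingDefaults() {
  SparseTensorEncodingAttr enc; // null encoding
  assert(enc.isAllDense() && enc.isAllOrdered() && enc.hasIdDimOrdering());
  assert(enc.getLvlType(0) == DimLevelType::Dense);
  // In contrast, enc.getLvlRank() would assert: there is no fixed
  // level-rank that is valid for every dense tensor.
}

/// Illustrative only: an encoding whose fields are coherent with the
/// level-rank, mirroring what the verifier checks.
static SparseTensorEncodingAttr makeCSRLikeEncoding(MLIRContext *ctx) {
  // Two level types => lvlRank == 2, so the dimension ordering must be a
  // permutation with exactly two results.
  const SmallVector<DimLevelType> lvlTypes{DimLevelType::Dense,
                                           DimLevelType::Compressed};
  const AffineMap id = AffineMap::getMultiDimIdentityMap(2, ctx);
  return SparseTensorEncodingAttr::get(ctx, lvlTypes, id,
                                       /*higherOrdering=*/AffineMap(),
                                       /*pointerBitWidth=*/0,
                                       /*indexBitWidth=*/0);
}
```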
+static bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, + bool isUnique) { + if (!enc || !enc.isCompressedLvl(startLvl)) return false; - - for (uint64_t i = s + 1; i < rank; ++i) - if (!isSingletonDim(enc, i)) + const Level lvlRank = enc.getLvlRank(); + for (Level l = startLvl + 1; l < lvlRank; ++l) + if (!enc.isSingletonLvl(l)) return false; - - // If isUnique is true, then make sure that the last dimension level is - // unique, that is, rank == 1 (unique the only compressed) and rank > 1 + // If isUnique is true, then make sure that the last level is unique, + // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1 // (unique on the last singleton). - return !isUnique || isUniqueDLT(getDimLevelType(enc, rank - 1)); + return !isUnique || enc.isUniqueLvl(lvlRank - 1); } bool mlir::sparse_tensor::isUniqueCOOType(TensorType tp) { - SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp); - return enc && isCOOType(enc, 0, /*isUnique=*/true); + return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true); } -unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) { - const unsigned rank = enc.getDimLevelType().size(); - // We only consider COO region with at least two dimensions for the purpose +Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) { + // We only consider COO region with at least two levels for the purpose // of AOS storage optimization. - if (rank > 1) - for (unsigned r = 0; r < rank - 1; r++) - if (isCOOType(enc, r, /*isUnique=*/false)) - return r; - - return rank; + const Level lvlRank = enc.getLvlRank(); + if (lvlRank > 1) + for (Level l = 0; l < lvlRank - 1; l++) + if (isCOOType(enc, l, /*isUnique=*/false)) + return l; + return lvlRank; } // Helpers to setup a COO type. -RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType src, - AffineMap ordering, +RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt, + AffineMap lvlPerm, bool ordered) { - auto *ctx = src.getContext(); - auto rank = src.getRank(); - SmallVector dims; + const SparseTensorType src(rtt); + // The dim-rank of the source `RankedTensorType` is used as the lvl-rank + // of the result `RankedTensorType`. This follows from the fact that the + // result's encoding has the default higher-ordering (hence the result's + // lvl-rank equals its dim-rank). We don't need to assert that `lvlRank` + // agrees with the size of `lvlPerm` because that will be verified by + // `STEA::get`. + const Level lvlRank = src.getDimRank(); + SmallVector lvlTypes; - // An unordered and non-unique compressed dim at beginning. - // If this is also the last dimension, then it is unique. - dims.push_back(*getDimLevelType(LevelFormat::Compressed, ordered, rank == 1)); - if (rank > 1) { + // An unordered and non-unique compressed level at beginning. + // If this is also the last level, then it is unique. + lvlTypes.push_back( + *getDimLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); + if (lvlRank > 1) { // TODO: it is actually ordered at the level for ordered input. // Followed by unordered non-unique n-2 singleton levels. - std::fill_n(std::back_inserter(dims), rank - 2, + std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, *getDimLevelType(LevelFormat::Singleton, ordered, false)); - // Ends by a unique singleton level unless the tensor rank is 1. - dims.push_back(*getDimLevelType(LevelFormat::Singleton, ordered, true)); + // Ends by a unique singleton level unless the lvlRank is 1. 
+ lvlTypes.push_back(*getDimLevelType(LevelFormat::Singleton, ordered, true)); } - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src); // TODO: Maybe pick the bitwidth based on input/output tensors (probably the // largest one among them) in the original operation instead of using the // default value. - unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0; - unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0; - auto enc = SparseTensorEncodingAttr::get(ctx, dims, ordering, AffineMap(), - pointerBitWidth, indexBitWidth); - return RankedTensorType::get(src.getShape(), src.getElementType(), enc); + unsigned pointerBitWidth = src.getPointerBitWidth(); + unsigned indexBitWidth = src.getIndexBitWidth(); + auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm, + AffineMap(), pointerBitWidth, + indexBitWidth); + return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc); } RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src, @@ -479,20 +506,24 @@ ordered); } -uint64_t mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc, - uint64_t d) { +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. +Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc, + Level l) { if (enc) { auto order = enc.getDimOrdering(); if (order) { assert(order.isPermutation()); - return order.getDimPosition(d); + return order.getDimPosition(l); } } - return d; + return l; } -uint64_t mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc, - uint64_t d) { +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. +Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc, + Dimension d) { if (enc) { auto order = enc.getDimOrdering(); if (order) { @@ -506,13 +537,18 @@ return d; } -uint64_t mlir::sparse_tensor::toOrigDim(RankedTensorType type, uint64_t d) { - assert(d < static_cast(type.getRank())); - return toOrigDim(getSparseTensorEncoding(type), d); +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. +Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) { + const auto enc = getSparseTensorEncoding(type); + assert(l < enc.getLvlRank()); + return toOrigDim(enc, l); } -uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) { - assert(d < static_cast(type.getRank())); +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. +Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) { + assert(d < static_cast(type.getRank())); return toStoredDim(getSparseTensorEncoding(type), d); } @@ -554,6 +590,8 @@ return IntegerType::get(getContext(), std::max(idxBitWidth, ptrBitWidth)); } +// FIXME: see note [CLARIFY_DIM_LVL] in +// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h" Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, std::optional dim) const { if (kind != StorageSpecifierKind::ValMemSize) @@ -565,6 +603,8 @@ return getSizesType(); } +// FIXME: see note [CLARIFY_DIM_LVL] in +// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h" Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, std::optional dim) const { return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue()) @@ -575,8 +615,8 @@ // SparseTensorDialect Operations. 
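A small sketch (illustrative only) of what the deprecated translation helpers compute for a permutation ordering; the transposed 2-D map is an assumed example, not taken from this change.

```cpp
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include <cassert>

using namespace mlir;
using namespace mlir::sparse_tensor;

/// Illustrative only: assumes enc.getDimOrdering() is the transposed
/// (CSC-like) map affine_map<(d0, d1) -> (d1, d0)>.
static void permutationTranslation(SparseTensorEncodingAttr enc) {
  assert(enc && enc.getDimOrdering() && enc.getDimOrdering().isPermutation());
  // Level 0 stores original dimension 1 (columns outermost), and vice versa.
  assert(toOrigDim(enc, 0) == 1 && toOrigDim(enc, 1) == 0);
  assert(toStoredDim(enc, 0) == 1 && toStoredDim(enc, 1) == 0);
  // For a non-permutation dimToLvl mapping there is no such 1:1 translation,
  // which is why these helpers are deprecated.
}
```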
//===----------------------------------------------------------------------===// -static LogicalResult isInBounds(uint64_t dim, Value tensor) { - return success(dim < static_cast(getTypeRank(tensor))); +static LogicalResult dimIsInBounds(Dimension dim, Value tensor) { + return success(dim < getDimRank(tensor)); } static LogicalResult isMatchingWidth(Value result, unsigned width) { @@ -585,26 +625,25 @@ } static LogicalResult verifySparsifierGetterSetter( - StorageSpecifierKind mdKind, std::optional dim, + StorageSpecifierKind mdKind, std::optional lvl, TypedValue md, Operation *op) { - if (mdKind == StorageSpecifierKind::ValMemSize && dim) { + if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { return op->emitError( - "redundant dimension argument for querying value memory size"); + "redundant level argument for querying value memory size"); } - auto enc = md.getType().getEncoding(); - ArrayRef dlts = enc.getDimLevelType(); - unsigned rank = dlts.size(); + const auto enc = md.getType().getEncoding(); + const Level lvlRank = enc.getLvlRank(); if (mdKind != StorageSpecifierKind::ValMemSize) { - if (!dim) - return op->emitError("missing dimension argument"); + if (!lvl) + return op->emitError("missing level argument"); - unsigned d = dim.value().getZExtValue(); - if (d >= rank) - return op->emitError("requested dimension out of bound"); + const Level l = lvl.value().getZExtValue(); + if (l >= lvlRank) + return op->emitError("requested level out of bound"); - if (mdKind == StorageSpecifierKind::PtrMemSize && isSingletonDLT(dlts[d])) + if (mdKind == StorageSpecifierKind::PtrMemSize && enc.isSingletonLvl(l)) return op->emitError( "requested pointer memory size on a singleton level"); } @@ -612,7 +651,7 @@ } LogicalResult NewOp::verify() { - if (getExpandSymmetry() && getTypeRank(getResult()) != 2) + if (getExpandSymmetry() && getDimRank(getResult()) != 2) return emitOpError("expand_symmetry can only be used for 2D tensors"); return success(); } @@ -670,7 +709,7 @@ // Accept size matches between the source and the destination type // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). - for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) + for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) return emitError("unexpected conversion mismatch in dimension ") << d; return success(); @@ -692,7 +731,8 @@ LogicalResult ToPointersOp::verify() { auto e = getSparseTensorEncoding(getTensor().getType()); - if (failed(isInBounds(getDimension().getZExtValue(), getTensor()))) + // FIXME: there seems to be some dim/lvl confusion here. + if (failed(dimIsInBounds(getDimension().getZExtValue(), getTensor()))) return emitError("requested pointers dimension out of bounds"); if (failed(isMatchingWidth(getResult(), e.getPointerBitWidth()))) return emitError("unexpected type for pointers"); @@ -701,7 +741,8 @@ LogicalResult ToIndicesOp::verify() { auto e = getSparseTensorEncoding(getTensor().getType()); - if (failed(isInBounds(getDimension().getZExtValue(), getTensor()))) + // FIXME: there seems to be some dim/lvl confusion here. 
+ if (failed(dimIsInBounds(getDimension().getZExtValue(), getTensor()))) return emitError("requested indices dimension out of bounds"); if (failed(isMatchingWidth(getResult(), e.getIndexBitWidth()))) return emitError("unexpected type for indices"); @@ -710,7 +751,7 @@ LogicalResult ToIndicesBufferOp::verify() { auto e = getSparseTensorEncoding(getTensor().getType()); - if (getCOOStart(e) >= e.getDimLevelType().size()) + if (getCOOStart(e) >= e.getLvlRank()) return emitError("expected sparse tensor with a COO region"); return success(); } @@ -846,58 +887,56 @@ } LogicalResult ConcatenateOp::verify() { - auto dstTp = getRankedTensorType(*this); - uint64_t concatDim = getDimension().getZExtValue(); - unsigned rank = dstTp.getRank(); + const auto dstTp = getSparseTensorType(*this); + const Dimension concatDim = getDimension().getZExtValue(); + const Dimension dimRank = dstTp.getDimRank(); if (getInputs().size() <= 1) return emitError("Need at least two tensors to concatenate."); - for (auto type : getInputs().getTypes()) { - auto shape = type.cast().getShape(); - for (auto dim : shape) { - if (ShapedType::isDynamic(dim)) - return emitError("Only statically-sized input tensors are supported."); - } - } - - if (concatDim >= rank) + if (concatDim >= dimRank) return emitError(llvm::formatv( - "Failed to concatentate tensors with rank={0} on dimension={1}.", rank, - concatDim)); + "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", + concatDim, dimRank)); - for (size_t i = 0, e = getInputs().size(); i < e; i++) { - const auto inputRank = getTypeRank(getInputs()[i]); - if (inputRank != rank) + for (const auto &it : llvm::enumerate(getInputs())) { + const auto i = it.index(); + const auto srcTp = getSparseTensorType(it.value()); + if (srcTp.hasDynamicDimShape()) + return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); + const Dimension srcDimRank = srcTp.getDimRank(); + if (srcDimRank != dimRank) return emitError( - llvm::formatv("The input tensor ${0} has a different rank (rank={1}) " + llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " "from the output tensor (rank={2}).", - i, inputRank, rank)); + i, srcDimRank, dimRank)); } - for (unsigned i = 0; i < rank; i++) { - const auto dstDim = dstTp.getShape()[i]; - if (i == concatDim) { - if (!ShapedType::isDynamic(dstDim)) { - // If we reach here, all inputs should have static shapes. - unsigned sumDim = 0; - for (auto src : getInputs()) - sumDim += getRankedTensorType(src).getShape()[i]; + for (Dimension d = 0; d < dimRank; d++) { + const DynSize dstSh = dstTp.getDimShape()[d]; + if (d == concatDim) { + if (!ShapedType::isDynamic(dstSh)) { + // If we reach here, then all inputs have static shapes. So we + // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` + // to avoid redundant assertions in the loop. + StaticSize sumSz = 0; + for (const auto src : getInputs()) + sumSz += getSparseTensorType(src).getDimShape()[d]; // If all dimension are statically known, the sum of all the input // dimensions should be equal to the output dimension. 
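        // For example (illustrative only, not part of this patch): when
        // concatenating tensor<3x4xf64> and tensor<5x4xf64> along dimension 0,
        // the result must be tensor<8x4xf64>: the concat dimension sums up
        // (3 + 5 == 8) and every other dimension must agree (4 == 4).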
- if (sumDim != dstDim) + if (sumSz != dstSh) return emitError( "The concatenation dimension of the output tensor should be the " "sum of all the concatenation dimensions of the input tensors."); } } else { - int64_t prev = dstDim; - for (auto src : getInputs()) { - const auto d = getRankedTensorType(src).getShape()[i]; - if (!ShapedType::isDynamic(prev) && d != prev) + DynSize prev = dstSh; + for (const auto src : getInputs()) { + const auto sh = getSparseTensorType(src).getDimShape()[d]; + if (!ShapedType::isDynamic(prev) && sh != prev) return emitError("All dimensions (expect for the concatenating one) " "should be equal."); - prev = d; + prev = sh; } } } @@ -906,7 +945,7 @@ } LogicalResult InsertOp::verify() { - if (getTypeRank(getTensor()) != static_cast(getIndices().size())) + if (getDimRank(getTensor()) != static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); } @@ -926,7 +965,8 @@ } LogicalResult CompressOp::verify() { - if (getTypeRank(getTensor()) != 1 + static_cast(getIndices().size())) + if (getDimRank(getTensor()) != + 1 + static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); } @@ -947,37 +987,34 @@ // Builds foreach body. if (!bodyBuilder) return; - auto rtp = getRankedTensorType(tensor); - int64_t rank = rtp.getRank(); + const auto stt = getSparseTensorType(tensor); + const Dimension dimRank = stt.getDimRank(); - SmallVector blockArgTypes; - // Starts with n index. - std::fill_n(std::back_inserter(blockArgTypes), rank, builder.getIndexType()); + // Starts with `dimRank`-many indices. + SmallVector blockArgTypes(dimRank, builder.getIndexType()); // Followed by one value. - blockArgTypes.push_back(rtp.getElementType()); - // Followed by reduction variable. + blockArgTypes.push_back(stt.getElementType()); + // Followed by the reduction variables. blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); - SmallVector blockArgLocs; - std::fill_n(std::back_inserter(blockArgLocs), blockArgTypes.size(), - tensor.getLoc()); + SmallVector blockArgLocs(blockArgTypes.size(), tensor.getLoc()); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); bodyBuilder(builder, result.location, - bodyBlock->getArguments().slice(0, rank), - bodyBlock->getArguments()[rank], - bodyBlock->getArguments().drop_front(rank + 1)); + bodyBlock->getArguments().slice(0, dimRank), + bodyBlock->getArguments()[dimRank], + bodyBlock->getArguments().drop_front(dimRank + 1)); } LogicalResult ForeachOp::verify() { - auto t = getRankedTensorType(getTensor()); - auto args = getBody()->getArguments(); + const auto t = getSparseTensorType(getTensor()); + const Dimension dimRank = t.getDimRank(); + const auto args = getBody()->getArguments(); - if (static_cast(t.getRank()) + 1 + getInitArgs().size() != - args.size()) + if (static_cast(dimRank) + 1 + getInitArgs().size() != args.size()) return emitError("Unmatched number of arguments in the block"); if (getNumResults() != getInitArgs().size()) @@ -986,18 +1023,20 @@ if (getResultTypes() != getInitArgs().getTypes()) return emitError("Mismatch in types of init arguments and results"); + // Cannot mark this const, because the getters aren't. 
auto yield = cast(getBody()->getTerminator()); if (yield.getNumOperands() != getNumResults() || yield.getOperands().getTypes() != getResultTypes()) return emitError("Mismatch in types of yield values and results"); - for (int64_t i = 0, e = t.getRank(); i < e; i++) - if (args[i].getType() != IndexType::get(getContext())) + const auto iTp = IndexType::get(getContext()); + for (Dimension d = 0; d < dimRank; d++) + if (args[d].getType() != iTp) emitError( - llvm::formatv("Expecting Index type for argument at index {0}", i)); + llvm::formatv("Expecting Index type for argument at index {0}", d)); - auto elemTp = t.getElementType(); - auto valueTp = args[t.getRank()].getType(); + const auto elemTp = t.getElementType(); + const auto valueTp = args[dimRank].getType(); if (elemTp != valueTp) emitError(llvm::formatv("Unmatched element type between input tensor and " "block argument, expected:{0}, got: {1}", @@ -1036,13 +1075,13 @@ bool checkEleType = true) -> LogicalResult { for (Value opnd : operands) { auto mtp = getMemRefType(opnd); - int64_t dim = mtp.getShape()[0]; + const DynSize sh = mtp.getShape()[0]; // We can't check the size of dynamic dimension at compile-time, but all // xs and ys should have a dimension not less than n at runtime. - if (n && !ShapedType::isDynamic(dim) && dim < n.value()) + if (n && !ShapedType::isDynamic(sh) && sh < n.value()) return emitError(llvm::formatv("xs and ys need to have a dimension >= n" ": {0} < {1}", - dim, n.value())); + sh, n.value())); if (checkEleType && xtp != mtp.getElementType()) return emitError("mismatch xs element types"); @@ -1072,12 +1111,13 @@ ny = nyAttr.getInt(); } - auto checkDim = [&](Value v, uint64_t min, const char *message) { - auto tp = getMemRefType(v); - int64_t dim = tp.getShape()[0]; - if (!ShapedType::isDynamic(dim) && dim < (int64_t)min) { - emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min)); - } + // FIXME: update the types of variables used in expressions bassed as + // the `minSize` argument, to avoid implicit casting at the callsites + // of this lambda. + const auto checkDim = [&](Value v, StaticSize minSize, const char *message) { + const DynSize sh = getMemRefType(v).getShape()[0]; + if (!ShapedType::isDynamic(sh) && sh < minSize) + emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); }; checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -10,7 +10,9 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" + #include using namespace mlir; @@ -114,10 +116,9 @@ OpOperand *lhs = linalgOp.getDpsInitOperand(0); unsigned tensor = lhs->getOperandNumber(); - auto enc = getSparseTensorEncoding(lhs->get().getType()); // An non-annotated output tensor is assumed dense, and becomes a random // access n-dim memref. Admissible since insertions cannot occur. 
- if (!enc || enc.isAllDense()) + if (getSparseTensorType(lhs->get()).isAllDense()) return true; // A tensor expression with a sparse output tensor that changes its values diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -228,8 +228,7 @@ /// to match the shape of the corresponding dense tensor to support direct /// access of the buffer through indices. Value reshapeValuesToLevels(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, - const SmallVectorImpl &dimSizes, + SparseTensorEncodingAttr enc, ValueRange dimSizes, Value valuesBuffer, Value idxBuffer); //===----------------------------------------------------------------------===// @@ -345,13 +344,13 @@ } /// Infers the result type and generates ToPointersOp. -Value genToPointers(OpBuilder &builder, Location loc, Value tensor, uint64_t d); +Value genToPointers(OpBuilder &builder, Location loc, Value tensor, Level lvl); -/// Infers the result type and generates ToIndicesOp. If the dim is within a COO +/// Infers the result type and generates ToIndicesOp. If the lvl is within a COO /// region, the result type is a memref with unknown stride and offset. /// Otherwise, the result type is a memref without any specified layout. -Value genToIndices(OpBuilder &builder, Location loc, Value tensor, uint64_t d, - uint64_t cooStart); +Value genToIndices(OpBuilder &builder, Location loc, Value tensor, Level lvl, + Level cooStart); /// Infers the result type and generates ToValuesOp. Value genToValues(OpBuilder &builder, Location loc, Value tensor); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -236,7 +236,7 @@ void mlir::sparse_tensor::genReshapeDstShape( Location loc, PatternRewriter &rewriter, SmallVectorImpl &dstShape, - ArrayRef srcShape, ArrayRef staticDstShape, + ArrayRef srcShape, ArrayRef staticDstShape, ArrayRef reassociation) { // Collapse shape. if (reassociation.size() < srcShape.size()) { @@ -269,7 +269,7 @@ if (staticDstShape[j] == ShapedType::kDynamic) { // The expanded dimension has dynamic size. We compute the dimension // by dividing srcDim by the product of the static dimensions. 
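        // For example (illustrative only, not part of this patch): expanding a
        // source dimension of size 12 into a static destination shape (3, ?, 2)
        // yields the dynamic entry 12 / (3 * 2) == 2.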
- int64_t product = 1; + StaticSize product = 1; for (unsigned k = start; k < start + map.size(); k++) { if (staticDstShape[k] != ShapedType::kDynamic) { product *= staticDstShape[k]; @@ -483,9 +483,9 @@ void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder, SmallVectorImpl &sizes, Location loc, Value src) { - unsigned rank = src.getType().cast().getRank(); - for (unsigned i = 0; i < rank; i++) - sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i)); + const Dimension dimRank = getSparseTensorType(src).getDimRank(); + for (Dimension d = 0; d < dimRank; d++) + sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d)); } Operation *mlir::sparse_tensor::getTop(Operation *op) { @@ -532,9 +532,20 @@ } void sparse_tensor::storeIndices(OpBuilder &builder, Location loc, - unsigned rank, Value ind, ValueRange ivs, + unsigned size, Value ind, ValueRange ivs, unsigned offsetDim, Value offset) { - for (unsigned i = 0; i < rank; i++) { +#ifndef NDEBUG + const auto memTp = ind.getType().cast(); + (void)memTp; + assert(memTp.getRank() == 1); + const DynSize memSh = memTp.getDimSize(0); + (void)memSh; + assert(ShapedType::isDynamic(memSh) || memSh == static_cast(size)); + assert(ivs.size() == static_cast(size)); + assert(offsetDim < size); +#endif // NDEBUG + + for (unsigned i = 0; i < size; i++) { Value idx = ivs[i]; if (offsetDim == i && offset) idx = builder.create(loc, idx, offset); @@ -543,44 +554,47 @@ } } -Value sparse_tensor::reshapeValuesToLevels( - OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc, - const SmallVectorImpl &dimSizes, Value valuesBuffer, - Value idxBuffer) { - // Use the dstIdx to store the level sizes. - unsigned rank = enc.getDimLevelType().size(); +Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc, + SparseTensorEncodingAttr enc, + ValueRange dimSizes, + Value valuesBuffer, + Value idxBuffer) { + // Use the `idxBuffer` to store the level sizes. + const Level lvlRank = enc.getLvlRank(); SmallVector lvlSizes; - for (unsigned i = 0; i < dimSizes.size(); i++) - lvlSizes.push_back(dimSizes[toOrigDim(enc, i)]); - storeIndices(builder, loc, rank, idxBuffer, lvlSizes); + lvlSizes.reserve(lvlRank); + for (Level l = 0; l < lvlRank; l++) + // FIXME: `toOrigDim` is deprecated. + lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]); + storeIndices(builder, loc, lvlRank, idxBuffer, lvlSizes); // The memref ReshapeOp requires the sizes buffer to have a static // shape. 
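  // For example (illustrative only, not part of this patch): with lvlRank == 2
  // and f64 values, the cast below produces a sizes buffer of type
  // memref<2xindex>, and the reshape reinterprets the flat memref<?xf64>
  // values buffer as the level-shaped memref<?x?xf64>.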
- idxBuffer = builder.create( - loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer); - SmallVector shape(rank, ShapedType::kDynamic); - Type elemTp = getMemRefType(valuesBuffer).getElementType(); - return builder.create(loc, MemRefType::get(shape, elemTp), - valuesBuffer, idxBuffer); + const auto iTp = builder.getIndexType(); + const SmallVector idxBufferShape{static_cast(lvlRank)}; + const auto idxBufferTp = MemRefType::get(idxBufferShape, iTp); + idxBuffer = builder.create(loc, idxBufferTp, idxBuffer); + const SmallVector resShape(lvlRank, ShapedType::kDynamic); + const Type elemTp = getMemRefType(valuesBuffer).getElementType(); + const auto resTp = MemRefType::get(resShape, elemTp); + return builder.create(loc, resTp, valuesBuffer, idxBuffer); } Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc, - Value tensor, uint64_t d) { - RankedTensorType srcTp = getRankedTensorType(tensor); - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); - Type ptrTp = get1DMemRefType(getPointerOverheadType(builder, encSrc), - /*withLayout=*/false); - return builder.create(loc, ptrTp, tensor, - builder.getIndexAttr(d)); + Value tensor, Level lvl) { + const auto srcTp = getSparseTensorType(tensor); + const Type ptrTp = getPointerOverheadType(builder, srcTp.getEncoding()); + const Type memTp = get1DMemRefType(ptrTp, /*withLayout=*/false); + return builder.create(loc, memTp, tensor, + builder.getIndexAttr(lvl)); } Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc, - Value tensor, uint64_t d, uint64_t cooStart) { - RankedTensorType srcTp = getRankedTensorType(tensor); - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); - Type indTp = get1DMemRefType(getIndexOverheadType(builder, encSrc), - /*withLayout=*/d >= cooStart); - return builder.create(loc, indTp, tensor, - builder.getIndexAttr(d)); + Value tensor, Level lvl, Level cooStart) { + const auto srcTp = getSparseTensorType(tensor); + const Type idxTp = getIndexOverheadType(builder, srcTp.getEncoding()); + const Type memTp = get1DMemRefType(idxTp, /*withLayout=*/lvl >= cooStart); + return builder.create(loc, memTp, tensor, + builder.getIndexAttr(lvl)); } Value sparse_tensor::genToValues(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -166,37 +166,41 @@ // For every tensor, find lower and upper bound on dimensions, set the // same bounds on loop indices, and obtain dense or sparse buffer(s). for (size_t t = 0, e = tensors.size(); t < e; t++) { - auto tensor = tensors[t]; - auto rtp = tensor.getType().dyn_cast(); + const auto tensor = tensors[t]; + const auto rtp = tensor.getType().dyn_cast(); if (!rtp) // Skips only scalar, zero ranked tensor still need to be bufferized and // (probably) filled with zeros by users. continue; - auto rank = rtp.getRank(); - auto shape = rtp.getShape(); - auto enc = getSparseTensorEncoding(rtp); - uint64_t cooStart = enc ? getCOOStart(enc) : rank; - // Scan all dimensions of current tensor. - for (int64_t d = 0; d < rank; d++) { + // FIXME: the definition of `lvlRank` looks more like a dim-rank; + // but the variable is used as a level everywhere below, which + // suggests there may be some dim/lvl confusion going on here. 
+ const Level lvlRank = rtp.getRank(); + const auto shape = rtp.getShape(); + const auto enc = getSparseTensorEncoding(rtp); + const Level cooStart = enc ? getCOOStart(enc) : lvlRank; + // Scan all levels of current tensor. + for (Level l = 0; l < lvlRank; l++) { // This should be called only once at beginning. - assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !highs[t][d]); + assert(!ptrBuffer[t][l] && !idxBuffer[t][l] && !highs[t][l]); + const auto dlt = dimTypes[t][l]; // Handle sparse storage schemes. - if (isCompressedDLT(dimTypes[t][d])) { + if (isCompressedDLT(dlt)) { // Generate sparse primitives to obtain pointer and indices. - ptrBuffer[t][d] = genToPointers(builder, loc, tensor, d); - idxBuffer[t][d] = genToIndices(builder, loc, tensor, d, cooStart); - } else if (isSingletonDLT(dimTypes[t][d])) { + ptrBuffer[t][l] = genToPointers(builder, loc, tensor, l); + idxBuffer[t][l] = genToIndices(builder, loc, tensor, l, cooStart); + } else if (isSingletonDLT(dlt)) { // Singleton dimension, fetch indices. - idxBuffer[t][d] = genToIndices(builder, loc, tensor, d, cooStart); + idxBuffer[t][l] = genToIndices(builder, loc, tensor, l, cooStart); } else { // Dense dimension, nothing to fetch. - assert(isDenseDLT(dimTypes[t][d])); + assert(isDenseDLT(dlt)); } // Find upper bound in current dimension. - unsigned p = toOrigDim(enc, d); - Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, p); - highs[t][d] = up; + // FIXME: `toOrigDim` is deprecated + const Dimension d = toOrigDim(enc, l); + highs[t][l] = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d); } // Perform the required bufferization. Dense inputs materialize diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -24,11 +24,11 @@ static SmallVector<Type> getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); - unsigned rank = enc.getDimLevelType().size(); + const Level lvlRank = enc.getLvlRank(); SmallVector<Type> result; auto indexType = tp.getSizesType(); - auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank); + auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank); auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType, getNumDataFieldsFromEncoding(enc)); result.push_back(dimSizes); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -24,9 +24,12 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" + #include <optional> using namespace mlir; @@ -104,74 +107,77 @@ /// Gets the dimension size for the given sparse tensor at the given /// original dimension 'dim'.
static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc, unsigned dim) { - RankedTensorType rtp = desc.getTensorType(); + SparseTensorDescriptor desc, Dimension dim) { + const SparseTensorType stt(desc.getRankedTensorType()); // Access into static dimension can query original type directly. // Note that this is typically already done by DimOp's folding. - auto shape = rtp.getShape(); - if (!ShapedType::isDynamic(shape[dim])) - return constantIndex(builder, loc, shape[dim]); + if (auto sz = stt.getStaticDimSize(dim)) + return constantIndex(builder, loc, *sz); // Any other query can consult the dimSizes array at field DimSizesIdx, // accounting for the reordering applied to the sparse storage. - return desc.getDimSize(builder, loc, toStoredDim(rtp, dim)); + // FIXME: `toStoredDim` is deprecated. + const Level lvl = toStoredDim(stt, dim); + // FIXME: this method seems to get *level* sizes, but the name is confusing + return desc.getDimSize(builder, loc, lvl); } // Gets the dimension size at the given stored level 'lvl', either as a // constant for a static size, or otherwise dynamically through memSizes. static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc, unsigned lvl) { + SparseTensorDescriptor desc, Level lvl) { + // FIXME: `toOrigDim` is deprecated. return sizeFromTensorAtDim(builder, loc, desc, - toOrigDim(desc.getTensorType(), lvl)); + toOrigDim(desc.getRankedTensorType(), lvl)); } static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, - SparseTensorFieldKind kind, - std::optional dim, Value value, - Value repeat = Value()) { - Type etp = desc.getMemRefElementType(kind, dim); - Value field = desc.getMemRefField(kind, dim); + SparseTensorFieldKind kind, std::optional lvl, + Value value, Value repeat = Value()) { + Type etp = desc.getMemRefElementType(kind, lvl); + Value field = desc.getMemRefField(kind, lvl); StorageSpecifierKind specFieldKind = toSpecifierKind(kind); auto pushBackOp = builder.create( - loc, desc.getSpecifierField(builder, loc, specFieldKind, dim), field, + loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field, toType(builder, loc, value, etp), repeat); - desc.setMemRefField(kind, dim, pushBackOp.getOutBuffer()); - desc.setSpecifierField(builder, loc, specFieldKind, dim, + desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer()); + desc.setSpecifierField(builder, loc, specFieldKind, lvl, pushBackOp.getNewSize()); } /// Generates code that allocates a sparse storage scheme for given rank. static void allocSchemeForRank(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, unsigned r0) { - RankedTensorType rtp = desc.getTensorType(); - unsigned rank = rtp.getShape().size(); + MutSparseTensorDescriptor desc, Level startLvl) { + const SparseTensorType stt(desc.getRankedTensorType()); Value linear = constantIndex(builder, loc, 1); - for (unsigned r = r0; r < rank; r++) { - if (isCompressedDim(rtp, r)) { + const Level lvlRank = stt.getLvlRank(); + for (Level l = startLvl; l < lvlRank; l++) { + const auto dlt = stt.getLvlType(l); + if (isCompressedDLT(dlt)) { // Append linear x pointers, initialized to zero. Since each compressed // dimension initially already has a single zero entry, this maintains // the desired "linear + 1" length property at all times. 
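The "linear + 1" length property mentioned above is the usual compressed-storage invariant: if the levels before a compressed level expand to `linear` positions, its pointers array holds `linear + 1` entries so that `[pointers[p], pointers[p+1])` delimits the indices belonging to position `p`. A self-contained illustration with plain vectors (a CSR-like rank-2 layout is assumed; this models the storage shape, not the generated IR):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Three dense positions (the rows of a 3x4 CSR-like matrix) ahead of one
  // compressed level.
  const uint64_t linear = 3;

  // The pointers storage starts out with a single zero entry; appending
  // `linear` more zeros establishes the "linear + 1" length up front.
  std::vector<uint64_t> pointers{0};
  pointers.insert(pointers.end(), linear, 0);
  assert(pointers.size() == linear + 1);

  // After inserting column indices {1, 3} into row 0 and {2} into row 2,
  // the segment of row p is always [pointers[p], pointers[p + 1]).
  std::vector<uint64_t> indices{1, 3, 2};
  pointers = {0, 2, 2, 3};
  for (uint64_t p = 0; p < linear; ++p)
    assert(pointers[p] <= pointers[p + 1] && pointers[p + 1] <= indices.size());
  return 0;
}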
- Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); + Type ptrType = stt.getEncoding().getPointerType(); Value ptrZero = constantZero(builder, loc, ptrType); - createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l, ptrZero, linear); return; } - if (isSingletonDim(rtp, r)) { + if (isSingletonDLT(dlt)) { return; // nothing to do } // Keep compounding the size, but nothing needs to be initialized // at this level. We will eventually reach a compressed level or // otherwise the values array for the from-here "all-dense" case. - assert(isDenseDim(rtp, r)); - Value size = sizeFromTensorAtLvl(builder, loc, desc, r); + assert(isDenseDLT(dlt)); + Value size = sizeFromTensorAtLvl(builder, loc, desc, l); linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. - Value valZero = constantZero(builder, loc, rtp.getElementType()); + Value valZero = constantZero(builder, loc, stt.getElementType()); createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, std::nullopt, valZero, linear); } @@ -193,43 +199,40 @@ /// for all dynamic memrefs, the memory size is really the capacity of /// the "vector", while the actual size resides in the sizes array. /// -/// TODO: for efficiency, we will need heuristis to make educated guesses +/// TODO: for efficiency, we will need heuristics to make educated guesses /// on the required capacities (see heuristic variable). /// -static void createAllocFields(OpBuilder &builder, Location loc, Type type, - ValueRange dynSizes, bool enableInit, - SmallVectorImpl &fields, Value sizeHint) { - RankedTensorType rtp = type.cast(); - SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp); - +static void createAllocFields(OpBuilder &builder, Location loc, + SparseTensorType stt, ValueRange dynSizes, + bool enableInit, SmallVectorImpl &fields, + Value sizeHint) { // Build original sizes. - SmallVector sizes; - auto shape = rtp.getShape(); - unsigned rank = shape.size(); - for (unsigned r = 0, o = 0; r < rank; r++) { - if (ShapedType::isDynamic(shape[r])) - sizes.push_back(dynSizes[o++]); - else - sizes.push_back(constantIndex(builder, loc, shape[r])); - } + assert((dynSizes.size() == static_cast(stt.getNumDynamicDims())) && + "Got wrong number of dynamic sizes"); + const Dimension dimRank = stt.getDimRank(); + SmallVector dimSizes; + dimSizes.reserve(dimRank); + unsigned i = 0; // cumulative index into `dynSizes`. + for (const DynSize sh : stt.getDimShape()) + dimSizes.push_back(ShapedType::isDynamic(sh) + ? dynSizes[i++] + : constantIndex(builder, loc, sh)); // Set up some heuristic sizes. We try to set the initial // size based on available information. Otherwise we just // initialize a few elements to start the reallocation chain. 
// TODO: refine this Value ptrHeuristic, idxHeuristic, valHeuristic; - if (enc.isAllDense()) { - Value linear = sizes[0]; - for (unsigned r = 1; r < rank; r++) { - linear = builder.create(loc, linear, sizes[r]); - } - valHeuristic = linear; + if (stt.isAllDense()) { + valHeuristic = dimSizes[0]; + for (const Value sz : ArrayRef{dimSizes}.drop_front()) + valHeuristic = builder.create(loc, valHeuristic, sz); } else if (sizeHint) { - if (getCOOStart(enc) == 0) { + if (getCOOStart(stt.getEncoding()) == 0) { ptrHeuristic = constantIndex(builder, loc, 2); idxHeuristic = builder.create( - loc, constantIndex(builder, loc, rank), sizeHint); // AOS - } else if (rank == 2 && isDenseDim(rtp, 0) && isCompressedDim(rtp, 1)) { + loc, constantIndex(builder, loc, dimRank), sizeHint); // AOS + } else if (dimRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) { ptrHeuristic = builder.create( loc, sizeHint, constantIndex(builder, loc, 1)); idxHeuristic = sizeHint; @@ -243,15 +246,15 @@ } foreachFieldAndTypeInSparseTensor( - rtp, - [&builder, &fields, rtp, loc, ptrHeuristic, idxHeuristic, valHeuristic, - enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, - unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { + stt, + [&builder, &fields, stt, loc, ptrHeuristic, idxHeuristic, valHeuristic, + enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, + Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); Value field; switch (fKind) { case SparseTensorFieldKind::StorageSpec: - field = SparseTensorSpecifier::getInitValue(builder, loc, rtp); + field = SparseTensorSpecifier::getInitValue(builder, loc, stt); break; case SparseTensorFieldKind::PtrMemRef: case SparseTensorFieldKind::IdxMemRef: @@ -270,65 +273,66 @@ return true; }); - MutSparseTensorDescriptor desc(rtp, fields); + MutSparseTensorDescriptor desc(stt, fields); // Initialize the storage scheme to an empty tensor. Initialized memSizes // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. Value ptrZero = - constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType()); - for (unsigned r = 0; r < rank; r++) { - unsigned ro = toOrigDim(rtp, r); + constantZero(builder, loc, stt.getEncoding().getPointerType()); + for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) { // Fills dim sizes array. - desc.setDimSize(builder, loc, r, sizes[ro]); - + // FIXME: this method seems to set *level* sizes, but the name is confusing + // FIXME: `toOrigDim` is deprecated. + desc.setDimSize(builder, loc, l, dimSizes[toOrigDim(stt, l)]); // Pushes a leading zero to pointers memref. 
- if (isCompressedDim(rtp, r)) { - createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + if (stt.isCompressedLvl(l)) + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l, ptrZero); - } } allocSchemeForRank(builder, loc, desc, /*rank=*/0); } /// Helper method that generates block specific to compressed case: /// -/// plo = pointers[d][pos[d-1]] -/// phi = pointers[d][pos[d-1]+1] -/// msz = indices[d].size() +/// plo = pointers[l][pos[l-1]] +/// phi = pointers[l][pos[l-1]+1] +/// msz = indices[l].size() /// if (plo < phi) { -/// present = indices[d][phi-1] == i[d] +/// present = indices[l][phi-1] == i[l] /// } else { // first insertion /// present = false -/// pointers[d][pos[d-1]] = msz +/// pointers[l][pos[l-1]] = msz /// } /// if (present) { // index already present /// next = phi-1 /// } else { -/// indices[d].push_back(i[d]) -/// pointers[d][pos[d-1]+1] = msz+1 +/// indices[l].push_back(i[l]) +/// pointers[l][pos[l-1]+1] = msz+1 /// next = msz -/// +/// /// } -/// pos[d] = next +/// pos[l] = next static Value genCompressed(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, - SmallVectorImpl &indices, Value value, - Value pos, unsigned d) { - RankedTensorType rtp = desc.getTensorType(); - unsigned rank = rtp.getShape().size(); + MutSparseTensorDescriptor desc, ValueRange indices, + Value value, Value pos, Level lvl) { + const SparseTensorType stt(desc.getRankedTensorType()); + const Level lvlRank = stt.getLvlRank(); + assert(lvl < lvlRank && "Level is out of bounds"); + assert(indices.size() == static_cast(lvlRank) && + "Level-rank mismatch"); SmallVector types; Type indexType = builder.getIndexType(); Type boolType = builder.getIntegerType(1); unsigned idxIndex; unsigned idxStride; - std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d); + std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(lvl); Value one = constantIndex(builder, loc, 1); Value pp1 = builder.create(loc, pos, one); - Value plo = genLoad(builder, loc, desc.getPtrMemRef(d), pos); - Value phi = genLoad(builder, loc, desc.getPtrMemRef(d), pp1); - Value msz = desc.getIdxMemSize(builder, loc, d); + Value plo = genLoad(builder, loc, desc.getPtrMemRef(lvl), pos); + Value phi = genLoad(builder, loc, desc.getPtrMemRef(lvl), pp1); + Value msz = desc.getIdxMemSize(builder, loc, lvl); Value idxStrideC; if (idxStride > 1) { idxStrideC = constantIndex(builder, loc, idxStride); @@ -349,14 +353,13 @@ : phim1); Value eq = builder.create(loc, arith::CmpIPredicate::eq, toType(builder, loc, crd, indexType), - indices[d]); + indices[lvl]); builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); - if (d > 0) - genStore(builder, loc, msz, desc.getPtrMemRef(d), pos); + if (lvl > 0) + genStore(builder, loc, msz, desc.getPtrMemRef(lvl), pos); builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); - Value p = ifOp1.getResult(0); // If present construct. Note that for a non-unique dimension level, we // simply set the condition to false and rely on CSE/DCE to clean up the IR. // @@ -365,8 +368,8 @@ for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) types.push_back(desc.getField(i).getType()); types.push_back(indexType); - if (!isUniqueDim(rtp, d)) - p = constantI1(builder, loc, false); + const Value p = stt.isUniqueLvl(lvl) ? 
ifOp1.getResult(0) + : constantI1(builder, loc, false); scf::IfOp ifOp2 = builder.create(loc, types, p, /*else*/ true); // If present (fields unaffected, update next to phim1). builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); @@ -380,12 +383,12 @@ // If !present (changes fields, update next). builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); Value mszp1 = builder.create(loc, msz, one); - genStore(builder, loc, mszp1, desc.getPtrMemRef(d), pp1); - createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, - indices[d]); + genStore(builder, loc, mszp1, desc.getPtrMemRef(lvl), pp1); + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, lvl, + indices[lvl]); // Prepare the next dimension "as needed". - if ((d + 1) < rank) - allocSchemeForRank(builder, loc, desc, d + 1); + if ((lvl + 1) < lvlRank) + allocSchemeForRank(builder, loc, desc, lvl + 1); desc.getFields().push_back(msz); builder.create(loc, desc.getFields()); @@ -412,52 +415,52 @@ /// static void genInsertBody(OpBuilder &builder, ModuleOp module, func::FuncOp func, RankedTensorType rtp) { - OpBuilder::InsertionGuard insertionGuard(builder); - Block *entryBlock = func.addEntryBlock(); + const OpBuilder::InsertionGuard insertionGuard(builder); + Block *const entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); - - Location loc = func.getLoc(); - ValueRange args = entryBlock->getArguments(); - unsigned rank = rtp.getShape().size(); + const ValueRange args = entryBlock->getArguments(); + const Location loc = func.getLoc(); + const SparseTensorType stt(rtp); + const Level lvlRank = stt.getLvlRank(); // Construct fields and indices arrays from parameters. - ValueRange tmp = args.drop_back(rank + 1); - SmallVector fields(tmp.begin(), tmp.end()); + SmallVector fields = llvm::to_vector(args.drop_back(lvlRank + 1)); MutSparseTensorDescriptor desc(rtp, fields); - tmp = args.take_back(rank + 1).drop_back(); - SmallVector indices(tmp.begin(), tmp.end()); + const SmallVector indices = + llvm::to_vector(args.take_back(lvlRank + 1).drop_back()); Value value = args.back(); Value pos = constantZero(builder, loc, builder.getIndexType()); - // Generate code for every dimension. - for (unsigned d = 0; d < rank; d++) { - if (isCompressedDim(rtp, d)) { + // Generate code for every level. 
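The level loop that follows maintains a running insertion position: compressed levels go through genCompressed above, singleton levels simply append their index, and dense levels fold the index into the position by row-major linearization (the pos[l] = size * pos[l-1] + i[l] comment below). A tiny standalone sketch of that dense-level arithmetic (plain C++, illustrative names):

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major linearization across dense levels: the running position is
// rebuilt as pos = lvlSize * pos + idx at every level.
uint64_t linearize(const std::vector<uint64_t> &lvlSizes,
                   const std::vector<uint64_t> &ivs) {
  assert(lvlSizes.size() == ivs.size());
  uint64_t pos = 0;
  for (size_t l = 0; l < ivs.size(); ++l) {
    assert(ivs[l] < lvlSizes[l]);
    pos = lvlSizes[l] * pos + ivs[l];
  }
  return pos;
}

int main() {
  // Element (1, 2) of a dense 3x4 storage lands at offset 1 * 4 + 2 == 6.
  assert(linearize({3, 4}, {1, 2}) == 6);
  return 0;
}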
+ for (Level l = 0; l < lvlRank; l++) { + const auto dlt = stt.getLvlType(l); + if (isCompressedDLT(dlt)) { // Create: // if (!present) { - // indices[d].push_back(i[d]) - // + // indices[l].push_back(i[l]) + // // } - // pos[d] = indices.size() - 1 - // - pos = genCompressed(builder, loc, desc, indices, value, pos, d); - } else if (isSingletonDim(rtp, d)) { + // pos[l] = indices.size() - 1 + // + pos = genCompressed(builder, loc, desc, indices, value, pos, l); + } else if (isSingletonDLT(dlt)) { // Create: - // indices[d].push_back(i[d]) - // pos[d] = pos[d-1] - // - createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, - indices[d]); + // indices[l].push_back(i[l]) + // pos[l] = pos[l-1] + // + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, l, + indices[l]); } else { - assert(isDenseDim(rtp, d)); + assert(isDenseDLT(dlt)); // Construct the new position as: - // pos[d] = size * pos[d-1] + i[d] - // - Value size = sizeFromTensorAtLvl(builder, loc, desc, d); + // pos[l] = size * pos[l-1] + i[l] + // + Value size = sizeFromTensorAtLvl(builder, loc, desc, l); Value mult = builder.create(loc, size, pos); - pos = builder.create(loc, mult, indices[d]); + pos = builder.create(loc, mult, indices[l]); } } // Reached the actual value append/insert. - if (!isDenseDim(rtp, rank - 1)) + if (!stt.isDenseLvl(lvlRank - 1)) createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, std::nullopt, value); else @@ -476,26 +479,24 @@ // The mangled name of the function has this format: // ____ // __ - RankedTensorType rtp = desc.getTensorType(); + const SparseTensorType stt(desc.getRankedTensorType()); SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); nameOstream << namePrefix; - unsigned rank = rtp.getShape().size(); - assert(rank == indices.size()); - for (unsigned d = 0; d < rank; d++) { - nameOstream << toMLIRString(getDimLevelType(rtp, d)) << "_"; - } + assert(static_cast(stt.getLvlRank()) == indices.size()); + const Level lvlRank = stt.getLvlRank(); + for (Level l = 0; l < lvlRank; l++) + nameOstream << toMLIRString(stt.getLvlType(l)) << "_"; // Static dim sizes are used in the generated code while dynamic sizes are // loaded from the dimSizes buffer. This is the reason for adding the shape // to the function name. - for (auto d : rtp.getShape()) - nameOstream << d << "_"; - SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp); + for (const auto sh : stt.getDimShape()) + nameOstream << sh << "_"; // Permutation information is also used in generating insertion. - if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity()) - nameOstream << enc.getDimOrdering() << "_"; - nameOstream << rtp.getElementType() << "_"; - nameOstream << enc.getIndexBitWidth() << "_" << enc.getPointerBitWidth(); + if (!stt.isIdentity()) + nameOstream << stt.getDimToLvlMap() << "_"; + nameOstream << stt.getElementType() << "_"; + nameOstream << stt.getIndexBitWidth() << "_" << stt.getPointerBitWidth(); // Look up the function. ModuleOp module = insertPoint->getParentOfType(); @@ -504,8 +505,8 @@ auto func = module.lookupSymbol(result.getAttr()); // Construct parameters for fields and indices. 
- SmallVector operands(desc.getFields().begin(), desc.getFields().end()); - operands.append(indices.begin(), indices.end()); + SmallVector operands = llvm::to_vector(desc.getFields()); + operands.append(indices); operands.push_back(value); Location loc = insertPoint.getLoc(); @@ -519,7 +520,7 @@ FunctionType::get(context, ValueRange(operands).getTypes(), ValueRange(desc.getFields()).getTypes())); func.setPrivate(); - createFunc(builder, module, func, rtp); + createFunc(builder, module, func, stt); } // Generate a call to perform the insertion and update `fields` with values @@ -533,20 +534,21 @@ /// Generations insertion finalization code. static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc) { - RankedTensorType rtp = desc.getTensorType(); - unsigned rank = rtp.getShape().size(); - for (unsigned d = 0; d < rank; d++) { - if (isCompressedDim(rtp, d)) { + const SparseTensorType stt(desc.getRankedTensorType()); + const Level lvlRank = stt.getLvlRank(); + for (Level l = 0; l < lvlRank; l++) { + const auto dlt = stt.getLvlType(l); + if (isCompressedDLT(dlt)) { // Compressed dimensions need a pointer cleanup for all entries // that were not visited during the insertion pass. // // TODO: avoid cleanup and keep compressed scheme consistent at all // times? // - if (d > 0) { - Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); - Value ptrMemRef = desc.getPtrMemRef(d); - Value hi = desc.getPtrMemSize(builder, loc, d); + if (l > 0) { + Type ptrType = stt.getEncoding().getPointerType(); + Value ptrMemRef = desc.getPtrMemRef(l); + Value hi = desc.getPtrMemSize(builder, loc, l); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); // Vector of only one, but needed by createFor's prototype. @@ -570,7 +572,7 @@ builder.setInsertionPointAfter(loop); } } else { - assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d)); + assert(isDenseDLT(dlt) || isSingletonDLT(dlt)); } } } @@ -704,18 +706,25 @@ LogicalResult matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType resType = op.getType(); - auto enc = getSparseTensorEncoding(resType); - if (!enc) + const auto resType = getSparseTensorType(op); + if (!resType.hasEncoding()) return failure(); if (op.getCopy()) return rewriter.notifyMatchFailure(op, "tensor copy not implemented"); // Construct allocation for each field. - Location loc = op.getLoc(); - Value sizeHint = op.getSizeHint(); + const Location loc = op.getLoc(); + const Value sizeHint = op.getSizeHint(); + const ValueRange dynSizes = adaptor.getDynamicSizes(); + const size_t found = dynSizes.size(); + const int64_t expected = resType.getNumDynamicDims(); + if (found != static_cast(expected)) + return rewriter.notifyMatchFailure( + op, llvm::formatv( + "Got wrong number of dynamic sizes: Found={0}, Expected={1}", + found, expected)); SmallVector fields; - createAllocFields(rewriter, loc, resType, adaptor.getOperands(), + createAllocFields(rewriter, loc, resType, dynSizes, enableBufferInitialization, fields, sizeHint); // Replace operation with resulting memrefs. 
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); @@ -779,7 +788,7 @@ return failure(); Location loc = op->getLoc(); auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); - auto srcType = getRankedTensorType(op.getTensor()); + const auto srcType = getSparseTensorType(op.getTensor()); Type eltType = srcType.getElementType(); Type boolType = rewriter.getIntegerType(1); Type idxType = rewriter.getIndexType(); @@ -788,11 +797,12 @@ // Determine the size for access expansion (always the innermost stored // dimension size, translated back to original dimension). Note that we // recursively rewrite the new DimOp on the **original** tensor. - unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1); - auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); + // FIXME: `toOrigDim` is deprecated. + const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1); + const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); // Generate a memref for `sz` elements of type `t`. - auto genAlloc = [&](Type t) { - auto memTp = MemRefType::get({ShapedType::kDynamic}, t); + const auto genAlloc = [&](Type t) { + const auto memTp = MemRefType::get({ShapedType::kDynamic}, t); return rewriter.create(loc, memTp, ValueRange{sz}); }; // Allocate temporary buffers for values/filled-switch and added. @@ -834,14 +844,13 @@ Value filled = adaptor.getFilled(); Value added = adaptor.getAdded(); Value count = adaptor.getCount(); - RankedTensorType dstType = desc.getTensorType(); + const SparseTensorType dstType(desc.getRankedTensorType()); Type eltType = dstType.getElementType(); // Prepare indices. SmallVector indices(adaptor.getIndices()); - // If the innermost dimension is ordered, we need to sort the indices + // If the innermost level is ordered, we need to sort the indices // in the "added" array prior to applying the compression. 
- unsigned rank = dstType.getShape().size(); - if (isOrderedDim(dstType, rank - 1)) + if (dstType.isOrderedLvl(dstType.getLvlRank() - 1)) rewriter.create(loc, count, ValueRange{added}, ValueRange{}, SparseTensorSortKind::HybridQuickSort); // While performing the insertions, we also need to reset the elements @@ -1065,7 +1074,7 @@ matchAndRewrite(PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rtp = op.getResult().getType().cast(); + const auto rtp = getRankedTensorType(op.getResult()); assert(isUniqueCOOType(rtp)); SmallVector fields; @@ -1074,8 +1083,8 @@ foreachFieldAndTypeInSparseTensor( rtp, [&rewriter, &fields, &op, rtp, - loc](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, - unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { + loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, + Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); auto enc = getSparseTensorEncoding(rtp); Value field; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" @@ -84,21 +85,20 @@ /// Looks up a level-size by returning a statically-computed constant /// (when possible), or by calling `genLvlSizeCall` (when dynamic). static Value createOrFoldLvlCall(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, ShapedType stp, - Value tensor, unsigned lvl) { + SparseTensorType stt, Value tensor, + Level lvl) { // Only sparse tensors have "levels" to query. - assert(enc); - auto dimOrder = enc.getDimOrdering(); + assert(stt.hasEncoding()); // TODO: The following implementation only handles permutations; // we'll need to generalize this to handle arbitrary AffineExpr. // // There's no need to assert `isPermutation` here: because // `getDimPosition` checks that the expr isa `AffineDimExpr`, // which is all we care about (for supporting permutations). - unsigned dim = dimOrder ? dimOrder.getDimPosition(lvl) : lvl; - auto s = stp.getShape()[dim]; - if (s != ShapedType::kDynamic) - return constantIndex(builder, loc, s); + const Dimension dim = + stt.isIdentity() ? lvl : stt.getDimToLvlMap().getDimPosition(lvl); + if (const auto sz = stt.getStaticDimSize(dim)) + return constantIndex(builder, loc, *sz); // If we cannot statically compute the size from the shape, then we // must dynamically query it. (In principle we could also dynamically // compute it, but since we already did so to construct the `tensor` @@ -111,89 +111,82 @@ /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes /// of dense tensors). 
static Value createOrFoldDimCall(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, ShapedType stp, - Value tensor, unsigned dim) { - auto s = stp.getShape()[dim]; - if (s != ShapedType::kDynamic) - return constantIndex(builder, loc, s); - if (enc) + SparseTensorType stt, Value tensor, + Dimension dim) { + if (const auto sz = stt.getStaticDimSize(dim)) + return constantIndex(builder, loc, *sz); + if (stt.hasEncoding()) return genDimSizeCall(builder, loc, tensor, dim); return linalg::createOrFoldDimOp(builder, loc, tensor, dim); } /// Populates the array with the dimension-sizes of the given tensor. -static void fillDimSizes(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, ShapedType stp, +static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl &out) { - unsigned dimRank = stp.getRank(); + const Dimension dimRank = stt.getDimRank(); + out.clear(); out.reserve(dimRank); - for (unsigned d = 0; d < dimRank; d++) - out.push_back(createOrFoldDimCall(builder, loc, enc, stp, tensor, d)); + for (Dimension d = 0; d < dimRank; d++) + out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d)); } /// Returns an array with the dimension-sizes of the given tensor. static SmallVector getDimSizes(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, - ShapedType stp, Value tensor) { + SparseTensorType stt, Value tensor) { SmallVector out; - fillDimSizes(builder, loc, enc, stp, tensor, out); + fillDimSizes(builder, loc, stt, tensor, out); return out; } -/// Populates the array with the dimension-shape of the given `ShapedType`, -/// where dynamic sizes are represented by zero. -static void fillDimShape(OpBuilder &builder, Location loc, ShapedType stp, +/// Populates the array with the dimension-shape of the given +/// `SparseTensorType`, where dynamic sizes are represented by zero. +static void fillDimShape(OpBuilder &builder, Location loc, SparseTensorType stt, SmallVectorImpl &out) { - auto shape = stp.getShape(); - unsigned dimRank = stp.getRank(); - out.reserve(dimRank); - for (unsigned d = 0; d < dimRank; d++) { - auto s = shape[d] == ShapedType::kDynamic ? 0 : shape[d]; + out.clear(); + out.reserve(stt.getDimRank()); + for (const DynSize sh : stt.getDimShape()) { + const auto s = ShapedType::isDynamic(sh) ? 0 : sh; out.push_back(constantIndex(builder, loc, s)); } } -/// Returns an array with the dimension-shape of the given `ShapedType`, +/// Returns an array with the dimension-shape of the given `SparseTensorType`, /// where dynamic sizes are represented by zero. static SmallVector getDimShape(OpBuilder &builder, Location loc, - ShapedType stp) { + SparseTensorType stt) { SmallVector out; - fillDimShape(builder, loc, stp, out); + fillDimShape(builder, loc, stt, out); return out; } /// Populates the given sizes array for concatenation from type (for static /// sizes) and from an already-converted opaque pointer source (for dynamic /// sizes). 
-static void concatSizesFromInputs(OpBuilder &builder, - SmallVectorImpl<Value> &sizes, Location loc, - ShapedType dstTp, ValueRange srcs, - unsigned dim) { - auto dstShape = dstTp.getShape(); +static void concatDimSizesFromInputs(OpBuilder &builder, Location loc, + SparseTensorType dstTp, ValueRange srcs, + Dimension dim, + SmallVectorImpl<Value> &dimSizes) { + assert(dim < dstTp.getDimRank() && "Dimension is out of bounds"); + dimSizes.clear(); - auto srcTp = srcs[0].getType().cast(); - auto srcEnc = getSparseTensorEncoding(srcTp); // We first fill the sizes from an input tensor, and then // compute the size of the concatenation dimension if necessary. - if (srcEnc) + const auto srcTp = getSparseTensorType(srcs[0]); + if (srcTp.hasEncoding()) // Reusing sizes from an arbitrary input tensor is fine. - fillDimSizes(builder, loc, srcEnc, srcTp, srcs[0], sizes); + fillDimSizes(builder, loc, srcTp, srcs[0], dimSizes); else - sizesFromSrc(builder, sizes, loc, srcs[0]); + sizesFromSrc(builder, dimSizes, loc, srcs[0]); - // Sum up on the `dim` if the dimension is dynamic. - if (dstShape[dim] != ShapedType::kDynamic) { + if (const auto sz = dstTp.getStaticDimSize(dim)) { // Faithfully take the static size. - sizes[dim] = constantIndex(builder, loc, dstShape[dim]); + dimSizes[dim] = constantIndex(builder, loc, *sz); } else { - // Else, compute the shape dynamically. - for (size_t i = 1, sz = srcs.size(); i < sz; i++) { - auto srcTp = srcs[i].getType().cast(); - auto encSrc = getSparseTensorEncoding(srcTp); - Value srcSz = - createOrFoldDimCall(builder, loc, encSrc, srcTp, srcs[i], dim); - // Sum up all the sizes. - sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz); + // Else, dynamically compute the size. + for (const auto src : srcs.drop_front()) { + const auto srcTp = getSparseTensorType(src); + Value srcSz = createOrFoldDimCall(builder, loc, srcTp, src, dim); + dimSizes[dim] = builder.create<arith::AddIOp>(loc, dimSizes[dim], srcSz); } } } @@ -209,11 +202,10 @@ /// Generates a temporary buffer for the level-types of the given encoding. static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc) { + SparseTensorType stt) { SmallVector<Value> lvlTypes; - auto dlts = enc.getDimLevelType(); - lvlTypes.reserve(dlts.size()); - for (auto dlt : dlts) + lvlTypes.reserve(stt.getLvlRank()); + for (const auto dlt : stt.getEncoding().getDimLevelType()) lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt)); return allocaBuffer(builder, loc, lvlTypes); } @@ -235,8 +227,7 @@ /// MLIR buffers as needed, and returning `this` for method chaining. /// This method does not set the action and pointer arguments, since /// those are handled by `genNewCall` instead. - NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes, - ShapedType stp); + NewCallParams &genBuffers(SparseTensorType stt, ValueRange dimSizes); /// (Re)sets the C++ template type parameters, and returns `this` /// for method chaining. This is already done as part of `genBuffers`, @@ -246,12 +237,12 @@ // // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`; // is there a better way to handle that than this one-off setter method?
- NewCallParams &setTemplateTypes(SparseTensorEncodingAttr enc, - ShapedType stp) { + NewCallParams &setTemplateTypes(SparseTensorType stt) { + const auto enc = stt.getEncoding(); params[kParamPtrTp] = constantPointerTypeEncoding(builder, loc, enc); params[kParamIndTp] = constantIndexTypeEncoding(builder, loc, enc); params[kParamValTp] = - constantPrimaryTypeEncoding(builder, loc, stp.getElementType()); + constantPrimaryTypeEncoding(builder, loc, stt.getElementType()); return *this; } @@ -308,15 +299,16 @@ // TODO: see the note at `_mlir_ciface_newSparseTensor` about how // the meaning of the various arguments (e.g., "sizes" vs "shapes") // is inconsistent between the different actions. -NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc, - ValueRange dimSizes, ShapedType stp) { - const unsigned lvlRank = enc.getDimLevelType().size(); - const unsigned dimRank = stp.getRank(); +NewCallParams &NewCallParams::genBuffers(SparseTensorType stt, + ValueRange dimSizes) { + const Level lvlRank = stt.getLvlRank(); + const Dimension dimRank = stt.getDimRank(); // Sparsity annotations. - params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, enc); + params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt); // Dimension-sizes array of the enveloping tensor. Useful for either // verification of external data, or for construction of internal data. - assert(dimSizes.size() == dimRank && "Dimension-rank mismatch"); + assert(dimSizes.size() == static_cast(dimRank) && + "Dimension-rank mismatch"); params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes); // The level-sizes array must be passed as well, since for arbitrary // dim2lvl mappings it cannot be trivially reconstructed at runtime. @@ -329,29 +321,31 @@ // `operator[]` assignment. We preinitialize `lvl2dim` for code symmetry. SmallVector dim2lvl(dimRank); SmallVector lvl2dim(lvlRank); - auto dimOrder = enc.getDimOrdering(); - if (dimOrder) { + if (!stt.isIdentity()) { + const auto dimOrder = stt.getDimToLvlMap(); assert(dimOrder.isPermutation()); - for (unsigned l = 0; l < lvlRank; l++) { + for (Level l = 0; l < lvlRank; l++) { // The `d`th source variable occurs in the `l`th result position. - uint64_t d = dimOrder.getDimPosition(l); + const Dimension d = dimOrder.getDimPosition(l); dim2lvl[d] = constantIndex(builder, loc, l); lvl2dim[l] = constantIndex(builder, loc, d); lvlSizes[l] = dimSizes[d]; } } else { - assert(dimRank == lvlRank && "Rank mismatch"); - for (unsigned i = 0; i < lvlRank; i++) { - dim2lvl[i] = lvl2dim[i] = constantIndex(builder, loc, i); - lvlSizes[i] = dimSizes[i]; + // The `SparseTensorType` ctor already ensures `dimRank == lvlRank` + // when `isIdentity`; so no need to re-assert it here. + for (Level l = 0; l < lvlRank; l++) { + dim2lvl[l] = lvl2dim[l] = constantIndex(builder, loc, l); + lvlSizes[l] = dimSizes[l]; } } params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes); params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim); - params[kParamDim2Lvl] = - dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim]; + params[kParamDim2Lvl] = stt.isIdentity() + ? params[kParamLvl2Dim] + : allocaBuffer(builder, loc, dim2lvl); // Secondary and primary types encoding. - setTemplateTypes(enc, stp); + setTemplateTypes(stt); // Finally, make note that initialization is complete. assert(isInitialized() && "Initialization failed"); // And return `this` for method chaining. @@ -441,8 +435,7 @@ /// given target `dimTypes`. 
static bool canUseDirectConversion(ArrayRef dimTypes) { bool alreadyCompressed = false; - for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) { - const DimLevelType dlt = dimTypes[r]; + for (const auto dlt : dimTypes) { if (isCompressedDLT(dlt)) { if (alreadyCompressed) return false; // Multiple compressed dimensions not yet supported. @@ -467,13 +460,14 @@ TensorType dstTp, TensorType srcTp, Value dstIdx, Value srcIdx, ArrayRef dstShape, ArrayRef srcShape) { - unsigned dstRank = dstTp.getRank(); - unsigned srcRank = srcTp.getRank(); + const Dimension dstRank = dstTp.getRank(); + const Dimension srcRank = srcTp.getRank(); SmallVector srcIndices; - for (unsigned i = 0; i < srcRank; i++) { + srcIndices.reserve(srcRank); + for (Dimension d = 0; d < srcRank; d++) { Value idx = rewriter.create( - loc, srcIdx, constantIndex(rewriter, loc, i)); + loc, srcIdx, constantIndex(rewriter, loc, d)); srcIndices.push_back(idx); } @@ -481,9 +475,9 @@ translateIndicesArray(rewriter, loc, reassociation, srcIndices, srcShape, dstShape, dstIndices); - for (unsigned i = 0; i < dstRank; i++) - rewriter.create(loc, dstIndices[i], dstIdx, - constantIndex(rewriter, loc, i)); + for (Dimension d = 0; d < dstRank; d++) + rewriter.create(loc, dstIndices[d], dstIdx, + constantIndex(rewriter, loc, d)); } /// Generate code for a general sparse to sparse reshaping operation. @@ -505,37 +499,34 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) { Location loc = op.getLoc(); - auto srcTp = getRankedTensorType(op.getSrc()); - auto dstTp = getRankedTensorType(op.getResult()); - auto encSrc = getSparseTensorEncoding(srcTp); - auto encDst = getSparseTensorEncoding(dstTp); - if (!encDst || !encSrc) + const auto srcTp = getSparseTensorType(op.getSrc()); + const auto dstTp = getSparseTensorType(op.getResult()); + if (!srcTp.hasEncoding() || !dstTp.hasEncoding()) return failure(); Type elemTp = srcTp.getElementType(); assert(elemTp == dstTp.getElementType() && "reshape should not change element type"); // Start an iterator over the source tensor (in original index order). - const auto noPerm = encSrc.withoutOrdering(); SmallVector srcDimSizes = - getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc()); + getDimSizes(rewriter, loc, srcTp, adaptor.getSrc()); NewCallParams params(rewriter, loc); - Value iter = params.genBuffers(noPerm, srcDimSizes, srcTp) + Value iter = params.genBuffers(srcTp.withoutOrdering(), srcDimSizes) .genNewCall(Action::kToIterator, adaptor.getSrc()); // Start a new COO for the destination tensor. SmallVector dstDimSizes; - if (dstTp.hasStaticShape()) + if (dstTp.hasStaticDimShape()) // Static "shapes" are in fact "sizes". fillDimShape(rewriter, loc, dstTp, dstDimSizes); else genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes, - dstTp.getShape(), op.getReassociationIndices()); - Value coo = params.genBuffers(encDst, dstDimSizes, dstTp) - .genNewCall(Action::kEmptyCOO); + dstTp.getDimShape(), op.getReassociationIndices()); + Value coo = + params.genBuffers(dstTp, dstDimSizes).genNewCall(Action::kEmptyCOO); Value dstPerm = params.getDim2LvlMap(); // Construct a while loop over the iterator. 
Type iTp = rewriter.getIndexType(); - Value srcIdx = genAlloca(rewriter, loc, srcTp.getRank(), iTp); - Value dstIdx = genAlloca(rewriter, loc, dstTp.getRank(), iTp); + Value srcIdx = genAlloca(rewriter, loc, srcTp.getDimRank(), iTp); + Value dstIdx = genAlloca(rewriter, loc, dstTp.getDimRank(), iTp); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); SmallVector noArgs; SmallVector noTypes; @@ -571,23 +562,22 @@ // TODO: rename to `genSparseIterationLoop`? static void genSparseCOOIterationLoop( ConversionPatternRewriter &rewriter, Location loc, Value t, - RankedTensorType tensorTp, + SparseTensorType stt, function_ref bodyBuilder) { - auto enc = getSparseTensorEncoding(tensorTp); - assert(enc && "Generating Sparse Tensor COO Loop on a Dense Tensor!"); - - unsigned rank = tensorTp.getRank(); - Type elemTp = tensorTp.getElementType(); + assert(stt.hasEncoding() && + "Generating Sparse Tensor COO Loop on a Dense Tensor!"); + const Dimension dimRank = stt.getDimRank(); + const Type elemTp = stt.getElementType(); // Start an iterator over the tensor (in original index order). - const auto noPerm = enc.withoutOrdering(); - SmallVector dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t); + const auto noPerm = stt.withoutOrdering(); + SmallVector dimSizes = getDimSizes(rewriter, loc, noPerm, t); Value iter = NewCallParams(rewriter, loc) - .genBuffers(noPerm, dimSizes, tensorTp) + .genBuffers(noPerm, dimSizes) .genNewCall(Action::kToIterator, t); // Construct a while loop over the iterator. - Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + Value srcIdx = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); SmallVector noArgs; SmallVector noTypes; @@ -599,8 +589,8 @@ Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); - bool hasDenseDim = llvm::any_of( - enc.getDimLevelType(), [](DimLevelType dlt) { return isDenseDLT(dlt); }); + const bool hasDenseDim = + llvm::any_of(stt.getEncoding().getDimLevelType(), isDenseDLT); if (hasDenseDim) { Value elemV = rewriter.create(loc, elemPtr); Value isZero = genIsNonzero(rewriter, loc, elemV); @@ -633,12 +623,12 @@ // reduce code repetition! static void genDenseTensorIterationLoop( ConversionPatternRewriter &rewriter, Location loc, Value t, - RankedTensorType tensorTp, + SparseTensorType stt, function_ref bodyBuilder) { - assert(!getSparseTensorEncoding(tensorTp) && + assert(!stt.hasEncoding() && "Generating Dense Tensor Loop on a Sparse Tensor!"); - unsigned rank = tensorTp.getRank(); + const Dimension dimRank = stt.getDimRank(); Value zero = constantIndex(rewriter, loc, 0); Value one = constantIndex(rewriter, loc, 1); @@ -647,9 +637,9 @@ SmallVector st; // Fill out loop iteration information. - for (unsigned i = 0; i < rank; i++) { + for (Dimension d = 0; d < dimRank; d++) { lo.push_back(zero); - hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, i)); + hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, d)); st.push_back(one); } @@ -686,10 +676,9 @@ LogicalResult matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto stp = op.getSource().getType().cast(); + const auto stt = getSparseTensorType(op.getSource()); // Only rewrite sparse DimOp. - auto enc = getSparseTensorEncoding(stp); - if (!enc) + if (!stt.hasEncoding()) return failure(); // Only rewrite DimOp with constant index. 
std::optional dim = op.getConstantIndex(); @@ -698,7 +687,7 @@ // Generate the call. Value src = adaptor.getOperands()[0]; rewriter.replaceOp( - op, createOrFoldDimCall(rewriter, op->getLoc(), enc, stp, src, *dim)); + op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim)); return success(); } }; @@ -741,21 +730,19 @@ matchAndRewrite(NewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto stp = op.getType().cast(); - auto enc = getSparseTensorEncoding(stp); - if (!enc) + const auto stt = getSparseTensorType(op); + if (!stt.hasEncoding()) return failure(); - const unsigned dimRank = stp.getRank(); - const unsigned lvlRank = enc.getDimLevelType().size(); + const Dimension dimRank = stt.getDimRank(); + const Level lvlRank = stt.getLvlRank(); // Construct the dimShape. - const auto dimShape = stp.getShape(); - SmallVector dimShapeValues = getDimShape(rewriter, loc, stp); + SmallVector dimShapeValues = getDimShape(rewriter, loc, stt); Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues); // Allocate `SparseTensorReader` and perform all initial setup that // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc). Type opaqueTp = getOpaquePointerType(rewriter); Value valTp = - constantPrimaryTypeEncoding(rewriter, loc, stp.getElementType()); + constantPrimaryTypeEncoding(rewriter, loc, stt.getElementType()); Value reader = createFuncCall(rewriter, loc, "createCheckedSparseTensorReader", opaqueTp, @@ -773,7 +760,7 @@ // // FIXME: reduce redundancy vs `NewCallParams::genBuffers`. Value dimSizesBuffer; - if (!stp.hasStaticShape()) { + if (stt.hasDynamicDimShape()) { Type indexTp = rewriter.getIndexType(); auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp); dimSizesBuffer = @@ -784,22 +771,23 @@ Value lvlSizesBuffer; Value lvl2dimBuffer; Value dim2lvlBuffer; - if (auto dimOrder = enc.getDimOrdering()) { + if (!stt.isIdentity()) { + const auto dimOrder = stt.getDimToLvlMap(); assert(dimOrder.isPermutation() && "Got non-permutation"); // We preinitialize `dim2lvlValues` since we need random-access writing. // And we preinitialize the others for stylistic consistency. SmallVector lvlSizeValues(lvlRank); SmallVector lvl2dimValues(lvlRank); SmallVector dim2lvlValues(dimRank); - for (unsigned l = 0; l < lvlRank; l++) { + for (Level l = 0; l < lvlRank; l++) { // The `d`th source variable occurs in the `l`th result position. - uint64_t d = dimOrder.getDimPosition(l); + Dimension d = dimOrder.getDimPosition(l); Value lvl = constantIndex(rewriter, loc, l); Value dim = constantIndex(rewriter, loc, d); dim2lvlValues[d] = lvl; lvl2dimValues[l] = dim; lvlSizeValues[l] = - (dimShape[d] == ShapedType::kDynamic) + stt.isDynamicDim(d) ? rewriter.create(loc, dimSizesBuffer, dim) : dimShapeValues[d]; } @@ -807,11 +795,12 @@ lvl2dimBuffer = allocaBuffer(rewriter, loc, lvl2dimValues); dim2lvlBuffer = allocaBuffer(rewriter, loc, dim2lvlValues); } else { - assert(dimRank == lvlRank && "Rank mismatch"); + // The `SparseTensorType` ctor already ensures `dimRank == lvlRank` + // when `isIdentity`; so no need to re-assert it here. SmallVector iotaValues; iotaValues.reserve(lvlRank); - for (unsigned i = 0; i < lvlRank; i++) - iotaValues.push_back(constantIndex(rewriter, loc, i)); + for (Level l = 0; l < lvlRank; l++) + iotaValues.push_back(constantIndex(rewriter, loc, l)); lvlSizesBuffer = dimSizesBuffer ? 
dimSizesBuffer : dimShapeBuffer; dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(rewriter, loc, iotaValues); } @@ -819,11 +808,11 @@ SmallVector params{ reader, lvlSizesBuffer, - genLvlTypesBuffer(rewriter, loc, enc), + genLvlTypesBuffer(rewriter, loc, stt), lvl2dimBuffer, dim2lvlBuffer, - constantPointerTypeEncoding(rewriter, loc, enc), - constantIndexTypeEncoding(rewriter, loc, enc), + constantPointerTypeEncoding(rewriter, loc, stt.getEncoding()), + constantIndexTypeEncoding(rewriter, loc, stt.getEncoding()), valTp}; Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader", opaqueTp, params, EmitCInterface::On) @@ -848,27 +837,25 @@ return rewriter.notifyMatchFailure(op, "sparse tensor copy not implemented"); Location loc = op.getLoc(); - RankedTensorType resType = op.getType(); - auto enc = getSparseTensorEncoding(resType); - if (!enc) + const auto stt = getSparseTensorType(op); + if (!stt.hasEncoding()) return failure(); // Gather all dimension sizes as SSA values. - SmallVector sizes; - unsigned int operandCtr = 0; - for (int64_t i = 0; i < resType.getRank(); ++i) { - if (resType.isDynamicDim(i)) { - sizes.push_back(adaptor.getOperands()[operandCtr++]); - } else { - sizes.push_back( - rewriter.create(loc, op.getStaticSize(i))); - } + const Dimension dimRank = stt.getDimRank(); + SmallVector dimSizes; + dimSizes.reserve(dimRank); + unsigned operandCtr = 0; + for (Dimension d = 0; d < dimRank; ++d) { + dimSizes.push_back( + stt.isDynamicDim(d) + ? adaptor.getOperands()[operandCtr++] + : constantIndex(rewriter, loc, op.getStaticSize(d))); } // Generate the call to construct empty tensor. The sizes are // explicitly defined by the arguments to the alloc operator. - rewriter.replaceOp(op, - NewCallParams(rewriter, loc) - .genBuffers(enc, sizes, resType.cast()) - .genNewCall(Action::kEmpty)); + rewriter.replaceOp(op, NewCallParams(rewriter, loc) + .genBuffers(stt, dimSizes) + .genNewCall(Action::kEmpty)); return success(); } }; @@ -887,27 +874,30 @@ LogicalResult matchAndRewrite(ConvertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - auto resType = getRankedTensorType(op); - auto srcType = getRankedTensorType(op.getSource()); - auto encDst = getSparseTensorEncoding(resType); - auto encSrc = getSparseTensorEncoding(srcType); - Value src = adaptor.getOperands()[0]; - if (encDst && encSrc) { + const Location loc = op->getLoc(); + const auto srcTp = getSparseTensorType(op.getSource()); + const auto dstTp = getSparseTensorType(op); + if (!srcTp.hasEncoding() && !dstTp.hasEncoding()) + return failure(); + + const Dimension dimRank = srcTp.getDimRank(); + const Type elemTp = srcTp.getElementType(); + const Value src = adaptor.getOperands()[0]; + if (srcTp.hasEncoding() && dstTp.hasEncoding()) { + const auto srcEnc = srcTp.getEncoding(); + const auto dstEnc = dstTp.getEncoding(); // This is a sparse => sparse conversion, which is handled as follows: // t = src->toCOO(); ; src to COO in dst order // dst = newSparseTensor(t) // Using the coordinate scheme as an intermediate does not always // yield the fastest conversion but avoids the need for a full // O(N^2) conversion matrix. 
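To make the remark about the O(N^2) conversion matrix concrete: with a COO pivot, every format only needs an exporter to COO and an importer from COO, and re-sorting the COO stream in the destination's level order is what "src to COO in dst order" refers to. A self-contained sketch of that pivot for a rank-2 tensor (plain C++, not the runtime library):

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <tuple>
#include <vector>

// One COO element of a rank-2 tensor: (row, col, value).
using Elem = std::array<uint64_t, 3>;

int main() {
  // Source tensor flattened to COO triples in its own (row-major) order.
  std::vector<Elem> coo{{0, 2, 10}, {1, 0, 20}, {1, 3, 30}};

  // Re-sort the triples in the destination's level order (column-major
  // here); the destination storage is then rebuilt from the sorted stream.
  std::sort(coo.begin(), coo.end(), [](const Elem &a, const Elem &b) {
    return std::tie(a[1], a[0]) < std::tie(b[1], b[0]);
  });
  assert((coo.front() == Elem{1, 0, 20}));
  return 0;
}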
- if (encDst == encSrc) { + if (dstEnc == srcEnc) { rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast return success(); } NewCallParams params(rewriter, loc); - ShapedType stp = srcType.cast(); - SmallVector dimSizes = - getDimSizes(rewriter, loc, encSrc, stp, src); + SmallVector dimSizes = getDimSizes(rewriter, loc, srcTp, src); bool useDirectConversion; switch (options.sparseToSparseStrategy) { case SparseToSparseConversionStrategy::kViaCOO: @@ -915,37 +905,39 @@ break; case SparseToSparseConversionStrategy::kDirect: useDirectConversion = true; - assert(canUseDirectConversion(encDst.getDimLevelType()) && + assert(canUseDirectConversion(dstEnc.getDimLevelType()) && "Unsupported target for direct sparse-to-sparse conversion"); break; case SparseToSparseConversionStrategy::kAuto: - useDirectConversion = canUseDirectConversion(encDst.getDimLevelType()); + useDirectConversion = canUseDirectConversion(dstEnc.getDimLevelType()); break; } if (useDirectConversion) { - rewriter.replaceOp(op, params.genBuffers(encDst, dimSizes, stp) - .genNewCall(Action::kSparseToSparse, src)); + rewriter.replaceOp( + op, params.genBuffers(srcTp.withEncoding(dstEnc), dimSizes) + .genNewCall(Action::kSparseToSparse, src)); } else { // use via-COO conversion. // Set up encoding with right mix of src and dst so that the two // method calls can share most parameters, while still providing // the correct sparsity information to either of them. - auto enc = SparseTensorEncodingAttr::get( - op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(), - encDst.getHigherOrdering(), encSrc.getPointerBitWidth(), - encSrc.getIndexBitWidth()); + const auto mixedEnc = SparseTensorEncodingAttr::get( + op->getContext(), dstEnc.getDimLevelType(), dstEnc.getDimOrdering(), + dstEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(), + srcEnc.getIndexBitWidth()); // TODO: This is the only place where `kToCOO` (or `kToIterator`) // is called with a non-identity permutation. Is there any clean // way to push the permutation over to the `kFromCOO` side instead? - Value coo = params.genBuffers(enc, dimSizes, stp) + Value coo = params.genBuffers(srcTp.withEncoding(mixedEnc), dimSizes) .genNewCall(Action::kToCOO, src); - Value dst = params.setTemplateTypes(encDst, stp) + Value dst = params.setTemplateTypes(srcTp.withEncoding(dstEnc)) .genNewCall(Action::kFromCOO, coo); - genDelCOOCall(rewriter, loc, stp.getElementType(), coo); + genDelCOOCall(rewriter, loc, elemTp, coo); rewriter.replaceOp(op, dst); } return success(); } - if (!encDst && encSrc) { + if (srcTp.hasEncoding() && !dstTp.hasEncoding()) { + const auto srcEnc = srcTp.getEncoding(); // This is sparse => dense conversion, which is handled as follows: // dst = new Tensor(0); // iter = new SparseTensorIterator(src); @@ -953,26 +945,24 @@ // dst[elem.indices] = elem.value; // } // delete iter; - const unsigned rank = resType.getRank(); - const Type elemTp = resType.getElementType(); + // // Fabricate a no-permutation encoding for NewCallParams // The pointer/index types must be those of `src`. // The dimLevelTypes aren't actually used by Action::kToIterator. 
- encDst = SparseTensorEncodingAttr::get( + const auto dstEnc = SparseTensorEncodingAttr::get( op->getContext(), - SmallVector(rank, DimLevelType::Dense), AffineMap(), - AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); - SmallVector dimSizes = - getDimSizes(rewriter, loc, encSrc, srcType, src); + SmallVector(dimRank, DimLevelType::Dense), AffineMap(), + AffineMap(), srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth()); + SmallVector dimSizes = getDimSizes(rewriter, loc, srcTp, src); Value iter = NewCallParams(rewriter, loc) - .genBuffers(encDst, dimSizes, resType) + .genBuffers(dstTp.withEncoding(dstEnc), dimSizes) .genNewCall(Action::kToIterator, src); - Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + Value ind = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); Block *insertionBlock = rewriter.getInsertionBlock(); // TODO: Dense buffers should be allocated/deallocated via the callback // in BufferizationOptions. - Value dst = allocDenseTensor(rewriter, loc, resType, dimSizes); + Value dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes); SmallVector noArgs; SmallVector noTypes; auto whileOp = rewriter.create(loc, noTypes, noArgs); @@ -982,12 +972,13 @@ rewriter.create(loc, cond, before->getArguments()); Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); - SmallVector ivs = loadIndices(rewriter, loc, rank, ind); + SmallVector ivs = loadIndices(rewriter, loc, dimRank, ind); insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, ivs); rewriter.create(loc); rewriter.setInsertionPointAfter(whileOp); genDelIteratorCall(rewriter, loc, elemTp, iter); - rewriter.replaceOpWithNewOp(op, resType, dst); + rewriter.replaceOpWithNewOp( + op, dstTp.getRankedTensorType(), dst); // Deallocate the buffer. if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) { rewriter.setInsertionPoint(insertionBlock->getTerminator()); @@ -995,10 +986,7 @@ } return success(); } - if (!encDst && !encSrc) { - // dense => dense - return failure(); - } + assert(!srcTp.hasEncoding() && dstTp.hasEncoding()); // This is a dense => sparse conversion or a sparse constant in COO => // sparse conversion, which is handled as follows: // t = newSparseCOO() @@ -1025,30 +1013,27 @@ // Also note that the code below only generates the "new" ops and // the loop-nest per se; whereas the entire body of the innermost // loop is generated by genAddElt(). 
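Conceptually the generated code walks every element of the dense source (or every stored element of the COO constant), records its coordinates in the index buffer, and hands the value to the addElt call, after which the sparse storage is built from the accumulated COO (Action::kFromCOO). A plain-C++ sketch of that flow, with the zero-filtering folded into the loop for brevity (illustrative only; the real division of labor between the loop nest and genAddElt differs):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // A dense 2x3 source in row-major layout.
  const std::vector<double> dense{0, 5, 0, 7, 0, 0};
  const uint64_t rows = 2, cols = 3;

  // Accumulate the nonzeros as COO (coordinates plus value); a sparse
  // tensor in the requested format is then constructed from this COO.
  std::vector<uint64_t> cooRow, cooCol;
  std::vector<double> cooVal;
  for (uint64_t i = 0; i < rows; ++i)
    for (uint64_t j = 0; j < cols; ++j)
      if (dense[i * cols + j] != 0) {
        cooRow.push_back(i);
        cooCol.push_back(j);
        cooVal.push_back(dense[i * cols + j]);
      }
  assert(cooVal.size() == 2 && cooVal[0] == 5 && cooVal[1] == 7);
  return 0;
}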
- ShapedType stp = resType.cast(); - unsigned rank = stp.getRank(); - SmallVector sizes; - sizesFromSrc(rewriter, sizes, loc, src); + SmallVector dimSizes; + sizesFromSrc(rewriter, dimSizes, loc, src); NewCallParams params(rewriter, loc); Value coo = - params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO); - Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + params.genBuffers(dstTp, dimSizes).genNewCall(Action::kEmptyCOO); + Value ind = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); Value perm = params.getDim2LvlMap(); - Type eltType = stp.getElementType(); - Value elemPtr = genAllocaScalar(rewriter, loc, eltType); + Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); genDenseTensorOrSparseConstantIterLoop( - rewriter, loc, src, rank, - [&](OpBuilder &builder, Location loc, Value val, ValueRange indices) { - for (unsigned i = 0; i < rank; i++) { - Value idx = constantIndex(builder, loc, i); - builder.create(loc, indices[i], ind, idx); + rewriter, loc, src, dimRank, + [&](OpBuilder &builder, Location loc, Value val, ValueRange ivs) { + for (Dimension d = 0; d < dimRank; d++) { + Value dim = constantIndex(builder, loc, d); + builder.create(loc, ivs[d], ind, dim); } builder.create(loc, val, elemPtr); - genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm); + genAddEltCall(builder, loc, elemTp, coo, elemPtr, ind, perm); }); // Final call to construct sparse tensor storage. Value dst = params.genNewCall(Action::kFromCOO, coo); - genDelCOOCall(rewriter, loc, eltType, coo); + genDelCOOCall(rewriter, loc, elemTp, coo); rewriter.replaceOp(op, dst); return success(); } @@ -1066,8 +1051,7 @@ LogicalResult matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto enc = getSparseTensorEncoding(op.getTensor().getType()); - if (!enc) + if (!getSparseTensorType(op.getTensor()).hasEncoding()) return failure(); StringRef name = "delSparseTensor"; createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(), @@ -1190,14 +1174,14 @@ // index order. All values are passed by reference through stack // allocated memrefs. 
Location loc = op->getLoc(); - auto tp = getRankedTensorType(op.getTensor()); - auto elemTp = tp.getElementType(); - unsigned rank = tp.getRank(); - auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + const auto stt = getSparseTensorType(op.getTensor()); + const auto elemTp = stt.getElementType(); + const Dimension dimRank = stt.getDimRank(); + auto mref = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); auto vref = genAllocaScalar(rewriter, loc, elemTp); - for (unsigned i = 0; i < rank; i++) - rewriter.create(loc, adaptor.getIndices()[i], mref, - constantIndex(rewriter, loc, i)); + for (Dimension d = 0; d < dimRank; d++) + rewriter.create(loc, adaptor.getIndices()[d], mref, + constantIndex(rewriter, loc, d)); rewriter.create(loc, adaptor.getValue(), vref); SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)}; createFuncCall(rewriter, loc, name, {}, {adaptor.getTensor(), mref, vref}, @@ -1215,18 +1199,15 @@ matchAndRewrite(ExpandOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - auto srcType = getRankedTensorType(op.getTensor()); - Type eltType = srcType.getElementType(); + const auto srcTp = getSparseTensorType(op.getTensor()); + Type eltType = srcTp.getElementType(); Type boolType = rewriter.getIntegerType(1); Type idxType = rewriter.getIndexType(); // All initialization should be done on entry of the loop nest. rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); // Get the cardinality of valid coordinates for the innermost level. - auto srcEnc = getSparseTensorEncoding(srcType); - unsigned lvlRank = - srcEnc ? srcEnc.getDimLevelType().size() : srcType.getRank(); - Value sz = createOrFoldLvlCall(rewriter, loc, srcEnc, srcType, - adaptor.getTensor(), lvlRank - 1); + Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(), + srcTp.getLvlRank() - 1); // Allocate temporary buffers for values, filled-switch, and indices. // We do not use stack buffers for this, since the expanded size may // be rather large (as it envelops a single expanded dense dimension). 
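The expand lowering above queries the cardinality of the innermost level and then allocates three working buffers for the expanded access pattern: a dense scratch array of values, a boolean "filled" switch, and a list of added positions (plus a count). A minimal standalone sketch of that idiom in plain C++ (illustrative names only, not the runtime API that expInsert later calls):

#include <cstdint>
#include <vector>

// Working state for one expanded row: dense scratch values, a filled
// switch, and the list of positions touched so far.
struct ExpandedRow {
  std::vector<double> values;   // size == cardinality of innermost level
  std::vector<bool> filled;     // same size, all false initially
  std::vector<uint64_t> added;  // positions that became nonzero
};

// Accumulate a contribution into the expanded row, recording the position
// the first time it is touched; the compress step later visits only `added`.
void addToExpandedRow(ExpandedRow &row, uint64_t pos, double contribution) {
  if (!row.filled[pos]) {
    row.filled[pos] = true;
    row.added.push_back(pos);
    row.values[pos] = 0.0;
  }
  row.values[pos] += contribution;
}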
@@ -1269,13 +1250,13 @@ Value added = adaptor.getAdded(); Value count = adaptor.getCount(); Value tensor = adaptor.getTensor(); - auto tp = getRankedTensorType(op.getTensor()); - Type elemTp = tp.getElementType(); - unsigned rank = tp.getRank(); - auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); - for (unsigned i = 0; i < rank - 1; i++) - rewriter.create(loc, adaptor.getIndices()[i], mref, - constantIndex(rewriter, loc, i)); + const auto stt = getSparseTensorType(op.getTensor()); + const Type elemTp = stt.getElementType(); + const Dimension dimRank = stt.getDimRank(); + auto mref = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); + for (Dimension d = 0; d < dimRank - 1; d++) + rewriter.create(loc, adaptor.getIndices()[d], mref, + constantIndex(rewriter, loc, d)); SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)}; createFuncCall(rewriter, loc, name, {}, {tensor, mref, values, filled, added, count}, @@ -1323,34 +1304,33 @@ // a[ adjustForOffset(elem.indices) ] = elem.value // return a Location loc = op.getLoc(); - auto dstTp = getRankedTensorType(op); - auto encDst = getSparseTensorEncoding(dstTp); - Type elemTp = dstTp.getElementType(); - uint64_t concatDim = op.getDimension().getZExtValue(); - unsigned rank = dstTp.getRank(); + const auto dstTp = getSparseTensorType(op); + const auto dstEnc = dstTp.getEncoding(); + const Type elemTp = dstTp.getElementType(); + const Dimension concatDim = op.getDimension().getZExtValue(); + const Dimension dimRank = dstTp.getDimRank(); Value dst; // destination tensor Value dstPerm; // destination tensor permutation (if sparse out) // A pointer to the value being inserted (if dense => sparse) Value elemPtr; - // Memory that holds the COO for destination tensor (if sparse out) - Value dstIdx; + // Memory that holds the dim-indices for destination tensor (if sparse out) + Value dstInd; // The offset applied to the dimenstion to be concated (starting from 0) Value offset = constantIndex(rewriter, loc, 0); - SmallVector sizes; - NewCallParams params(rewriter, loc); - concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), - concatDim); + SmallVector dimSizes; + concatDimSizesFromInputs(rewriter, loc, dstTp, op.getInputs(), concatDim, + dimSizes); - bool allDense = false; + NewCallParams params(rewriter, loc); + const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense(); Value dstTensor; - if (encDst) { - allDense = encDst.isAllDense(); + if (dstTp.hasEncoding()) { // Start a new COO or an initialized annotated all dense sparse tensor. - dst = params.genBuffers(encDst, sizes, dstTp) + dst = params.genBuffers(dstTp, dimSizes) .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO); - dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + dstInd = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); if (allDense) { dstTensor = dst; // Get the values buffer for the sparse tensor and reshape it to the @@ -1358,8 +1338,9 @@ dst = genValuesCall(rewriter, loc, MemRefType::get({ShapedType::kDynamic}, elemTp), {dst}); - // Use the dstIdx to store the level sizes. - dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, dst, dstIdx); + // Use the dstInd to store the level sizes. + dst = + reshapeValuesToLevels(rewriter, loc, dstEnc, dimSizes, dst, dstInd); } else { dstPerm = params.getDim2LvlMap(); elemPtr = genAllocaScalar(rewriter, loc, elemTp); @@ -1367,74 +1348,73 @@ } else { // TODO: Dense buffers should be allocated/deallocated via the callback // in BufferizationOptions. 
- dst = allocDenseTensor(rewriter, loc, dstTp, sizes); + dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes); } - auto dimIdx2LvlIdx = [&](ValueRange dIdx) -> SmallVector { - SmallVector lIdx; - for (unsigned i = 0; i < dIdx.size(); i++) - lIdx.push_back(dIdx[toOrigDim(encDst, i)]); - return lIdx; + const Level lvlRank = dstTp.getLvlRank(); + const auto dimIvs2LvlIvs = [&](ValueRange dimIvs) -> SmallVector { + SmallVector lvlIvs; + lvlIvs.reserve(lvlRank); + for (Level l = 0; l < lvlRank; l++) + // FIXME: `toOrigDim` is deprecated + lvlIvs.push_back(dimIvs[toOrigDim(dstEnc, l)]); + return lvlIvs; }; - for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) { + for (const auto &it : llvm::zip(op.getInputs(), adaptor.getInputs())) { Value orignalOp = std::get<0>(it); // Input (with encoding) from Op Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor - auto srcTp = getRankedTensorType(orignalOp); - auto encSrc = getSparseTensorEncoding(srcTp); - if (encSrc) { + const auto srcTp = getSparseTensorType(orignalOp); + if (srcTp.hasEncoding()) { genSparseCOOIterationLoop( rewriter, loc, adaptedOp, srcTp, [&](OpBuilder &builder, Location loc, Value idx, Value elemPtr) -> void { - SmallVector dimInd = - loadIndices(builder, loc, rank, idx, concatDim, offset); - if (encDst && !allDense) { + SmallVector dimIvs = + loadIndices(builder, loc, dimRank, idx, concatDim, offset); + if (dstTp.hasEncoding() && !allDense) { // Case: sparse => sparse, except for annotated all dense. - storeIndices(builder, loc, rank, dstIdx, dimInd); - genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx, + storeIndices(builder, loc, dimRank, dstInd, dimIvs); + genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstInd, dstPerm); } else { // Case: sparse => dense, or annotated all dense. - SmallVector lvlInd; - if (allDense) - lvlInd = dimIdx2LvlIdx(dimInd); - else - lvlInd = dimInd; - insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lvlInd); + const auto lvlIvs = allDense ? dimIvs2LvlIvs(dimIvs) : dimIvs; + insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lvlIvs); } }); } else { genDenseTensorIterationLoop( rewriter, loc, adaptedOp, srcTp, - [&](OpBuilder &builder, Location loc, ValueRange idx) -> void { - if (encDst && !allDense) { + [&](OpBuilder &builder, Location loc, ValueRange dimIvs) -> void { + if (dstTp.hasEncoding() && !allDense) { // Case: dense => sparse, except for annotated all dense. - storeIndices(builder, loc, rank, dstIdx, idx, concatDim, + storeIndices(builder, loc, dimRank, dstInd, dimIvs, concatDim, offset); - Value val = genValueForDense(builder, loc, adaptedOp, idx); + Value val = genValueForDense(builder, loc, adaptedOp, dimIvs); builder.create(loc, val, elemPtr); - genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx, + genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstInd, dstPerm); } else { // Case: dense => dense, or annotated all dense. - Value val = genValueForDense(builder, loc, adaptedOp, idx); - SmallVector lvlInd(idx); + Value val = genValueForDense(builder, loc, adaptedOp, dimIvs); + // Despite the name, this isn't actually level-ivs until + // after the `dimIvs2LvlIvs` call. + SmallVector lvlIvs(dimIvs); // Apply offset. 
- lvlInd[concatDim] = builder.create( - loc, lvlInd[concatDim], offset); + lvlIvs[concatDim] = builder.create( + loc, lvlIvs[concatDim], offset); if (allDense) - lvlInd = dimIdx2LvlIdx(lvlInd); - builder.create(loc, val, dst, lvlInd); + lvlIvs = dimIvs2LvlIvs(lvlIvs); + builder.create(loc, val, dst, lvlIvs); } }); } // Accumulate offset. // TODO: avoid calling sparseDimSize multiple times by caching the result! - Value curDim = createOrFoldDimCall(rewriter, loc, encSrc, srcTp, - adaptedOp, concatDim); - + Value curDim = + createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim); offset = rewriter.create(loc, offset, curDim); } - if (encDst) { + if (dstTp.hasEncoding()) { if (!allDense) { // In sparse output case, the destination holds the COO. Value coo = dst; @@ -1446,7 +1426,8 @@ } rewriter.replaceOp(op, dst); } else { - rewriter.replaceOpWithNewOp(op, dstTp, dst); + rewriter.replaceOpWithNewOp( + op, dstTp.getRankedTensorType(), dst); } return success(); } @@ -1459,30 +1440,25 @@ LogicalResult matchAndRewrite(OutOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - ShapedType srcType = op.getTensor().getType().cast(); + const Location loc = op->getLoc(); + const auto srcTp = getSparseTensorType(op.getTensor()); // Convert to default permuted COO. Value src = adaptor.getOperands()[0]; - auto encSrc = getSparseTensorEncoding(srcType); - SmallVector dimSizes = - getDimSizes(rewriter, loc, encSrc, srcType, src); - const auto enc = encSrc.withoutOrdering(); + SmallVector dimSizes = getDimSizes(rewriter, loc, srcTp, src); Value coo = NewCallParams(rewriter, loc) - .genBuffers(enc, dimSizes, srcType) + .genBuffers(srcTp.withoutOrdering(), dimSizes) .genNewCall(Action::kToCOO, src); // Then output the tensor to external file with indices in the externally // visible lexicographic index order. A sort is required if the source was // not in that order yet (note that the sort can be dropped altogether if // external format does not care about the order at all, but here we assume // it does). - Value sort = constantI1(rewriter, loc, - encSrc.getDimOrdering() && - !encSrc.getDimOrdering().isIdentity()); + const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity()); SmallVector outParams{coo, adaptor.getOperands()[1], sort}; - Type eltType = srcType.getElementType(); - SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)}; + const Type elemTp = srcTp.getElementType(); + SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)}; createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off); - genDelCOOCall(rewriter, loc, eltType, coo); + genDelCOOCall(rewriter, loc, elemTp, coo); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" @@ -42,32 +43,15 @@ // Helper to detect a sparse tensor type operand. 
static bool isSparseTensor(OpOperand *op) { - if (auto enc = getSparseTensorEncoding(op->get().getType())) { - if (llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed)) - return true; - } - return false; -} - -static bool isAllDimOrdered(RankedTensorType rtp) { - if (auto enc = getSparseTensorEncoding(rtp)) - return llvm::all_of(enc.getDimLevelType(), isOrderedDLT); - - return true; + auto enc = getSparseTensorEncoding(op->get().getType()); + return enc && + llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed); } static bool hasSameDimOrdering(RankedTensorType rtp1, RankedTensorType rtp2) { assert(rtp1.getRank() == rtp2.getRank()); - AffineMap idMap = - AffineMap::getMultiDimIdentityMap(rtp1.getRank(), rtp1.getContext()); - - auto enc1 = getSparseTensorEncoding(rtp1); - auto enc2 = getSparseTensorEncoding(rtp2); - - auto order1 = (enc1 && enc1.getDimOrdering()) ? enc1.getDimOrdering() : idMap; - auto order2 = (enc2 && enc2.getDimOrdering()) ? enc2.getDimOrdering() : idMap; - - return order1 == order2; + return SparseTensorType(rtp1).getDimToLvlMap() == + SparseTensorType(rtp2).getDimToLvlMap(); } // Helper method to find zero/uninitialized allocation. @@ -424,9 +408,10 @@ ValueRange reduc) { SmallVector srcIndices; SmallVector dstIndices; - for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { - uint64_t dim = toStoredDim(encSrc, i); - srcIndices.push_back(args[dim]); + for (Dimension d = 0, dimRank = srcTp.getRank(); d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level lvl = toStoredDim(encSrc, d); + srcIndices.push_back(args[lvl]); } translateIndicesArray(builder, loc, op.getReassociationIndices(), srcIndices, srcSizes, dstSizes, dstIndices); @@ -486,9 +471,10 @@ using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto dstTp = getRankedTensorType(op); - uint64_t conDim = op.getDimension().getZExtValue(); + const Location loc = op.getLoc(); + const auto dstTp = getSparseTensorType(op); + const Dimension dimRank = dstTp.getDimRank(); + const Dimension conDim = op.getDimension().getZExtValue(); SmallVector sizes; concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim); @@ -505,14 +491,16 @@ // foreach in %s2 : insert d0, d1 + size(s1), %tmp // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp // %t = convert_to_dest_tensor(%tmp) - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); + // + // NOTE: this cannot be `const` because it will be changed when + // `needTmpCOO`, but that's buried in the conditional below and + // thus not easily extracted. + auto encDst = dstTp.getEncoding(); Value dst; // Destination tensor for inserting source tensor values. bool needTmpCOO = true; - bool allDense = false; + const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense(); Value annotatedDenseDst; - int64_t rank = dstTp.getRank(); - if (encDst) { - allDense = encDst.isAllDense(); + if (dstTp.hasEncoding()) { bool allOrdered = false; // When concatenating on dimension 0, and all inputs are sorted and have // an identity dimOrdering, the concatenate will generate coords in @@ -521,16 +509,12 @@ // in all input/output buffers, and all input/output buffers have the same // dimOrdering, the tmp COO buffer is still unnecessary (e.g, concatenate // CSC matrices along column). 
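The `allOrdered` check above rests on a simple observation: when every input is ordered, the destination uses an identity dim-ordering, and the concatenation dimension is the outermost one, then visiting the inputs in order while shifting their outermost coordinate by a running offset already yields lexicographically sorted coordinates, so the temporary COO buffer and sort can be skipped. A small standalone sketch of that argument in plain C++ (2-D case only, hypothetical helper name):

#include <array>
#include <cstdint>
#include <vector>

using Coords = std::array<uint64_t, 2>;

// Concatenate already-sorted coordinate lists along dimension 0.
// Because each input is sorted and every coordinate of input k is shifted
// by the total dim-0 size of inputs 0..k-1, the concatenated list is still
// sorted, so it can be inserted directly without an intermediate COO + sort.
std::vector<Coords> concatAlongDim0(const std::vector<std::vector<Coords>> &inputs,
                                    const std::vector<uint64_t> &dim0Sizes) {
  std::vector<Coords> out;
  uint64_t offset = 0;
  for (size_t k = 0; k < inputs.size(); ++k) {
    for (Coords c : inputs[k]) {
      c[0] += offset;            // adjustForOffset(elem.indices)
      out.push_back(c);
    }
    offset += dim0Sizes[k];      // accumulate the offset, as in the rewrite
  }
  return out;
}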
- if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) { + if (!allDense && conDim == 0 && dstTp.isIdentity()) { for (auto i : op.getInputs()) { - auto rtp = getRankedTensorType(i); - auto srcEnc = getSparseTensorEncoding(rtp); - if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) { - allOrdered = true; - continue; - } - allOrdered = false; - break; + const auto stt = getSparseTensorType(i); + allOrdered = stt.isAllOrdered() && stt.isIdentity(); + if (!allOrdered) + break; } } @@ -547,8 +531,9 @@ // Create a view of the values buffer to match the unannotated dense // tensor. Value valuesBuffer = genToValues(rewriter, loc, dst); - Value idxBuffer = genAlloca( - rewriter, loc, rank, rewriter.getIndexType(), /*staticShape=*/true); + Value idxBuffer = + genAlloca(rewriter, loc, dimRank, rewriter.getIndexType(), + /*staticShape=*/true); annotatedDenseDst = dst; dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, valuesBuffer, idxBuffer); @@ -571,13 +556,14 @@ loc, input, initArgs, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - SmallVector indices(rank, Value()); - for (int64_t i = 0; i < rank; i++) { - Value idx = args[i]; - if (i == static_cast(conDim)) + SmallVector indices(dstTp.getLvlRank()); + for (Dimension d = 0; d < dimRank; d++) { + Value idx = args[d]; + if (d == conDim) // Transform coordinates for the concatenating dim. idx = builder.create(loc, idx, offset); - indices[toStoredDim(encDst, i)] = idx; + // FIXME: `toStoredDim` is deprecated + indices[toStoredDim(encDst, d)] = idx; } if (encDst && !allDense) { Value cond = genIsNonzero(rewriter, loc, v); @@ -599,31 +585,34 @@ // Accumulates the offset. Note that only static-shaped inputs are allowed // by concatenate op verifier, which saves us from computing the offset // dynamically. - int64_t d = getRankedTensorType(input).getShape()[conDim]; - assert(!ShapedType::isDynamic(d)); - offset = rewriter.create(loc, offset, - constantIndex(rewriter, loc, d)); + const auto sh = getSparseTensorType(input).getStaticDimSize(conDim); + assert(sh.has_value()); + offset = rewriter.create( + loc, offset, constantIndex(rewriter, loc, *sh)); if (encDst && !allDense) { dst = foreachOp.getResult(0); initArgs[0] = dst; } } + // Temp variable to avoid needing to call `getRankedTensorType` + // in the three use-sites below. + const RankedTensorType dstRTT = dstTp; if (encDst) { if (!allDense) { dst = rewriter.create(loc, dst, true); if (needTmpCOO) { Value tmpCoo = dst; - dst = rewriter.create(loc, dstTp, tmpCoo).getResult(); + dst = rewriter.create(loc, dstRTT, tmpCoo).getResult(); rewriter.create(loc, tmpCoo); } } else { - dst = rewriter.create(loc, dstTp, annotatedDenseDst) + dst = rewriter.create(loc, dstRTT, annotatedDenseDst) .getResult(); } rewriter.replaceOp(op, dst); } else { - rewriter.replaceOpWithNewOp(op, dstTp, dst); + rewriter.replaceOpWithNewOp(op, dstRTT, dst); } return success(); } @@ -675,7 +664,7 @@ PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value src = op.getSource(); - auto dstTp = getRankedTensorType(op); + const auto dstTp = getSparseTensorType(op); SmallVector sizes; sizesFromSrc(rewriter, sizes, loc, src); SmallVector dynSizes; @@ -688,16 +677,16 @@ } } - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); + const auto encDst = dstTp.getEncoding(); // We don't need a temporary COO tensor if the destination has an identity // ordering. Otherwise, we use the destination ordering for the temporary // COO tensor. 
// TODO: enhance foreachOp to take ordering to remove the need of a // temporary COO tensor here. - RankedTensorType bufferTp = encDst.hasIdDimOrdering() - ? dstTp - : getUnorderedCOOFromTypeWithOrdering( - dstTp, encDst.getDimOrdering()); + const RankedTensorType bufferTp = dstTp.isIdentity() + ? dstTp.getRankedTensorType() + : getUnorderedCOOFromTypeWithOrdering( + dstTp, dstTp.getDimToLvlMap()); auto buffer = rewriter.create(loc, bufferTp, dynSizes).getResult(); auto foreachOp = rewriter.create( @@ -705,10 +694,11 @@ [&](OpBuilder &builder, Location loc, ValueRange indices, Value v, ValueRange reduc) { Value input = reduc.front(); - uint64_t rank = dstTp.getRank(); - SmallVector indicesArray(rank, Value()); - for (uint64_t i = 0; i < rank; i++) - indicesArray[toStoredDim(encDst, i)] = indices[i]; + const Dimension dimRank = dstTp.getDimRank(); + SmallVector indicesArray(dimRank); + for (Dimension d = 0; d < dimRank; d++) + // FIXME: `toStoredDim` is deprecated + indicesArray[toStoredDim(encDst, d)] = indices[d]; if (fromSparseConst) { input = builder.create(loc, v, input, indicesArray); } else { @@ -729,7 +719,8 @@ rewriter.setInsertionPointAfter(op); src = rewriter.create(loc, foreachOp.getResult(0), true); if (bufferTp != dstTp) { - rewriter.replaceOpWithNewOp(op, dstTp, src); + rewriter.replaceOpWithNewOp(op, dstTp.getRankedTensorType(), + src); rewriter.create(loc, src); } else { rewriter.replaceOp(op, src); @@ -782,15 +773,22 @@ // insert element to dst LogicalResult sparse2SparseRewrite(ConvertOp op, PatternRewriter &rewriter) const { - Location loc = op->getLoc(); + const Location loc = op->getLoc(); + // These two variables cannot be `const` because they're conditionally + // changed below. Ideally we'd use `SparseTensorType` for `srcRTT`; + // however that class's copy-ctor is implicitly deleted. Value src = op.getSource(); - RankedTensorType srcTp = getRankedTensorType(src); - RankedTensorType dstTp = getRankedTensorType(op); - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); - int64_t rank = dstTp.getRank(); + auto srcRTT = getRankedTensorType(src); + const auto dstTp = getSparseTensorType(op); + const auto encDst = dstTp.getEncoding(); + const Level dstLvlRank = dstTp.getLvlRank(); + const Dimension dimRank = dstTp.getDimRank(); + // This assertion should be guaranteed by validity of the op, + // but just for paranoia's sake. + assert(srcRTT.getRank() == dimRank); SmallVector srcSizes; - sizesForTensor(rewriter, srcSizes, loc, srcTp, src); + sizesForTensor(rewriter, srcSizes, loc, srcRTT, src); Value tmpCoo = Value(); Value nnz = rewriter.create(loc, src); // We need a tmp COO buffer if and only if @@ -798,28 +796,31 @@ // 2. the src tensor is not ordered in the same way as the target // tensor (e.g., src tensor is not ordered or src tensor haves a different // dimOrdering). - if (!isUniqueCOOType(srcTp) && - !(isAllDimOrdered(srcTp) && hasSameDimOrdering(srcTp, dstTp))) { + if (!isUniqueCOOType(srcRTT) && !(SparseTensorType(srcRTT).isAllOrdered() && + hasSameDimOrdering(srcRTT, dstTp))) { // Construct a COO tensor from the src tensor. // TODO: there may be cases for which more efficiently without // going through an intermediate COO, such as cases that only change // the overhead types. 
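The rewrites above and below repeatedly build the destination's level-coordinates by scattering dimension-coordinates through the dim-to-lvl permutation, i.e. the recurring `indices[toStoredDim(encDst, d)] = args[d]` idiom that the FIXMEs flag as permutation-only. A minimal standalone illustration of that remapping in plain C++, assuming the encoding really is a permutation:

#include <cassert>
#include <cstdint>
#include <vector>

// dimToLvl[d] answers "at which level is dimension d stored?", which is the
// same question toStoredDim() answers for a permutation encoding.
std::vector<uint64_t> dimCoordsToLvlCoords(const std::vector<uint64_t> &dimCoords,
                                           const std::vector<uint64_t> &dimToLvl) {
  assert(dimCoords.size() == dimToLvl.size());
  std::vector<uint64_t> lvlCoords(dimCoords.size());
  for (uint64_t d = 0; d < dimCoords.size(); ++d)
    lvlCoords[dimToLvl[d]] = dimCoords[d];   // scatter through the permutation
  return lvlCoords;
}

// Example: a CSC-like ordering stores dimension 0 at level 1 and dimension 1
// at level 0, so dim-coords (i, j) become lvl-coords (j, i).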
SmallVector dynSrcSizes; - getDynamicSizes(srcTp, srcSizes, dynSrcSizes); - srcTp = - getUnorderedCOOFromTypeWithOrdering(srcTp, encDst.getDimOrdering()); + getDynamicSizes(srcRTT, srcSizes, dynSrcSizes); + srcRTT = + getUnorderedCOOFromTypeWithOrdering(srcRTT, dstTp.getDimToLvlMap()); + // Ensure that mutating `srcRTT` didn't invalidate `dimRank`. + assert(srcRTT.getRank() == dimRank); tmpCoo = rewriter - .create(loc, srcTp, dynSrcSizes, Value(), + .create(loc, srcRTT, dynSrcSizes, Value(), /*sizeHint=*/nnz, Attribute()) .getResult(); auto foreachOp = rewriter.create( loc, src, tmpCoo, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - SmallVector dstIndices(srcTp.getRank(), Value()); - for (int64_t i = 0; i < rank; i++) { - uint64_t dim = toStoredDim(encDst, i); - dstIndices[dim] = args[i]; + SmallVector dstIndices(dstLvlRank); + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level l = toStoredDim(encDst, d); + dstIndices[l] = args[d]; } auto t = builder.create(loc, v, reduc.front(), dstIndices); @@ -828,30 +829,36 @@ src = rewriter.create(loc, foreachOp.getResult(0), true); } + // Now that the conditional is done, we can use `SparseTensorType`. + const SparseTensorType srcTp(srcRTT); + // Only need to sort if the srcTp is not already sorted (we faithfully take // the guarantee from the sparse tensor encoding). - if (!isAllDimOrdered(srcTp)) { + if (!srcTp.isAllOrdered()) { // Retrieve the values-array. Value y = genToValues(rewriter, loc, src); - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); + const auto encSrc = srcTp.getEncoding(); // Sort the COO tensor so that its elements are ordered via increasing // indices for the storage ordering of the dst tensor. Use SortCoo if the // COO tensor has the same dim ordering as the dst tensor. - if (rank > 1 && hasSameDimOrdering(srcTp, dstTp)) { + if (dimRank > 1 && hasSameDimOrdering(srcTp, dstTp)) { MemRefType indTp = get1DMemRefType(getIndexOverheadType(rewriter, encSrc), /*withLayout=*/false); Value xs = rewriter.create(loc, indTp, src); rewriter.create( - loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(rank), + loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank), rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort); } else { // Gather the indices-arrays in the dst tensor storage order. 
- SmallVector xs(rank, Value()); - for (int64_t i = 0; i < rank; i++) { - uint64_t orgDim = toOrigDim(encSrc, i); - xs[toStoredDim(encDst, orgDim)] = - genToIndices(rewriter, loc, src, i, /*cooStart=*/0); + SmallVector xs(dstLvlRank); + const Level srcLvlRank = srcTp.getLvlRank(); + for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) { + // FIXME: `toOrigDim` is deprecated + Dimension dim = toOrigDim(encSrc, srcLvl); + // FIXME: `toStoredDim` is deprecated + Level dstLvl = toStoredDim(encDst, dim); + xs[dstLvl] = genToIndices(rewriter, loc, src, srcLvl, /*cooStart=*/0); } rewriter.create(loc, nnz, xs, ValueRange{y}, SparseTensorSortKind::HybridQuickSort); @@ -862,17 +869,19 @@ SmallVector dynDstSizes; getDynamicSizes(dstTp, srcSizes, dynDstSizes); Value dst = rewriter - .create(loc, dstTp, dynDstSizes, Value(), + .create(loc, dstTp.getRankedTensorType(), + dynDstSizes, Value(), /*sizeHint=*/nnz, Attribute()) .getResult(); - SmallVector indices(srcTp.getRank(), Value()); + SmallVector indices(dstLvlRank); auto foreachOp = rewriter.create( loc, src, dst, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { - uint64_t dim = toStoredDim(encDst, i); - indices[dim] = args[i]; + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level l = toStoredDim(encDst, d); + indices[l] = args[d]; } auto t = builder.create(loc, v, reduc.front(), indices); builder.create(loc, t); @@ -889,7 +898,7 @@ // codegen. rewriter.setInsertionPointAfter(op); auto t = rewriter.create(loc, foreachOp.getResult(0), true); - rewriter.replaceOpWithNewOp(op, dstTp, t); + rewriter.replaceOpWithNewOp(op, dstTp.getRankedTensorType(), t); return success(); } }; @@ -905,8 +914,8 @@ auto loc = op.getLoc(); Value input = op.getTensor(); SmallVector reduc = op.getInitArgs(); - auto rtp = getRankedTensorType(input); - int64_t rank = rtp.getRank(); + const auto stt = getSparseTensorType(input); + const Dimension dimRank = stt.getDimRank(); // Special-case: for each over a sparse constant uses its own rewriting // rule. @@ -917,24 +926,24 @@ } // Otherwise, use loop emitter to generate loops. - auto enc = getSparseTensorEncoding(rtp); + const auto enc = stt.getEncoding(); // 1. Generates loop for the sparse input. LoopEmitter loopEmitter( ValueRange{input}, StringAttr::get(getContext(), ForeachOp::getOperationName())); loopEmitter.initializeLoopEmit(rewriter, loc); - for (int64_t i = 0; i < rank; i++) { + for (Dimension d = 0; d < dimRank; d++) { // TODO: provide utility function for loop sequences that only contains // one for loop? - loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast(i)); + loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast(d)); // Note that reduc will be taken care of by loop emitter and get updated // in place. - loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i, reduc); + loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, d, reduc); } SmallVector coords; - coords.reserve(rank); + coords.reserve(dimRank); loopEmitter.getCoordinateArray(coords); Value vals = loopEmitter.getValBuffer()[0]; @@ -949,8 +958,9 @@ // Remap coordinates. SmallVector args; - for (int64_t i = 0; i < rank; i++) { - Value actual = coords[toStoredDim(enc, i)]; + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Value actual = coords[toStoredDim(enc, d)]; args.push_back(actual); } // Remap value. 
@@ -972,7 +982,7 @@ rewriter.mergeBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), args); } - for (int64_t i = 0; i < rank; i++) { + for (Dimension d = 0; d < dimRank; d++) { // Link the reduction chain. Note that loop emitter update the reducValue // in place. loopEmitter.exitCurrentLoop(rewriter, loc, reducValue); @@ -992,9 +1002,9 @@ LogicalResult matchAndRewrite(NewOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto dstTp = getRankedTensorType(op.getResult()); - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); - if (!encDst) + const auto dstTp = getSparseTensorType(op.getResult()); + const auto encDst = dstTp.getEncoding(); + if (!dstTp.hasEncoding()) return failure(); // Create a sparse tensor reader. @@ -1006,17 +1016,17 @@ // Allocate a temporary buffer for storing dimension sizes and indices. Type indexTp = rewriter.getIndexType(); - uint64_t rank = dstTp.getRank(); - Value dimSizes = genAlloca(rewriter, loc, rank, indexTp); + const Dimension dimRank = dstTp.getDimRank(); + Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp); // If the result tensor has dynamic dimensions, get the dynamic sizes from // the sparse tensor reader. SmallVector dynSizesArray; - if (!dstTp.hasStaticShape()) { + if (dstTp.hasDynamicDimShape()) { createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {}, {reader, dimSizes}, EmitCInterface::On) .getResult(0); - ArrayRef dstShape = dstTp.getShape(); + ArrayRef dstShape = dstTp.getRankedTensorType().getShape(); for (auto &d : llvm::enumerate(dstShape)) { if (d.value() == ShapedType::kDynamic) { dynSizesArray.push_back(rewriter.create( @@ -1038,7 +1048,7 @@ {indexTp}, {reader}, EmitCInterface::Off) .getResult(0); RankedTensorType cooTp = - getUnorderedCOOFromTypeWithOrdering(dstTp, encDst.getDimOrdering()); + getUnorderedCOOFromTypeWithOrdering(dstTp, dstTp.getDimToLvlMap()); Value cooBuffer = rewriter .create(loc, cooTp, dynSizesArray, Value(), @@ -1047,7 +1057,7 @@ // The verifier ensures only 2D tensors can have the expandSymmetry flag. Value symmetric; - if (rank == 2 && op.getExpandSymmetry()) { + if (dimRank == 2 && op.getExpandSymmetry()) { symmetric = createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric", {rewriter.getI1Type()}, {reader}, EmitCInterface::Off) @@ -1066,10 +1076,11 @@ Value indices = dimSizes; // Reuse the indices memref to store indices. createFuncCall(rewriter, loc, getNextFuncName, {}, {reader, indices, value}, EmitCInterface::On); - SmallVector indicesArray(rank, Value()); - for (uint64_t i = 0; i < rank; i++) { - indicesArray[toStoredDim(encDst, i)] = rewriter.create( - loc, indices, constantIndex(rewriter, loc, i)); + SmallVector indicesArray(dimRank); + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + indicesArray[toStoredDim(encDst, d)] = rewriter.create( + loc, indices, constantIndex(rewriter, loc, d)); } Value v = rewriter.create(loc, value); Value t = rewriter.create(loc, v, forOp.getRegionIterArg(0), @@ -1098,7 +1109,8 @@ createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader}, EmitCInterface::Off); cooBuffer = rewriter.create(loc, cooBuffer, true); - Value newOp = rewriter.replaceOpWithNewOp(op, dstTp, cooBuffer); + Value newOp = rewriter.replaceOpWithNewOp( + op, dstTp.getRankedTensorType(), cooBuffer); // Release the unordered COO tensor buffer. 
rewriter.setInsertionPointAfterValue(newOp); @@ -1118,18 +1130,18 @@ Value nnz = rewriter.create(loc, src); // Allocate a temporary buffer for storing dimension sizes and indices. - auto srcTp = getRankedTensorType(src); - uint64_t rank = srcTp.getRank(); + const auto srcTp = getSparseTensorType(src); + const Dimension dimRank = srcTp.getDimRank(); Type indexTp = rewriter.getIndexType(); - Value dimSizes = genAlloca(rewriter, loc, rank, indexTp); + Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp); // Generate code to calculate dimension size values and store the values to // the buffer. SmallVector dims; sizesForTensor(rewriter, dims, loc, srcTp, src); - for (uint64_t i = 0; i < rank; i++) { - rewriter.create(loc, dims[i], dimSizes, - constantIndex(rewriter, loc, i)); + for (Dimension d = 0; d < dimRank; d++) { + rewriter.create(loc, dims[d], dimSizes, + constantIndex(rewriter, loc, d)); } // Create a sparse tensor writer and output meta data. @@ -1138,7 +1150,7 @@ createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp}, {op.getDest()}, EmitCInterface::Off) .getResult(0); - Value rankValue = constantIndex(rewriter, loc, rank); + Value rankValue = constantIndex(rewriter, loc, dimRank); createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {}, {writer, rankValue, nnz, dimSizes}, EmitCInterface::On); @@ -1153,9 +1165,9 @@ loc, src, std::nullopt, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - for (uint64_t i = 0; i < rank; i++) { - rewriter.create(loc, args[i], indices, - constantIndex(builder, loc, i)); + for (Dimension d = 0; d < dimRank; d++) { + rewriter.create(loc, args[d], indices, + constantIndex(builder, loc, d)); } rewriter.create(loc, v, value); SmallVector operands{writer, rankValue, indices, value}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -15,6 +15,7 @@ #include "mlir/Conversion/LLVMCommon/StructBuilder.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Transforms/DialectConversion.h" @@ -33,14 +34,14 @@ // capacity and the used size resides in the storage_specifier struct. // // struct { -// ; per-dimension d: +// ; per-level l: // ; if dense: // // ; if compresed: -// memref pointers-d ; pointers for sparse dim d -// memref indices-d ; indices for sparse dim d +// memref pointers-l ; pointers for sparse level l +// memref indices-l ; indices for sparse level l // ; if singleton: -// memref indices-d ; indices for singleton dim d +// memref indices-l ; indices for singleton level l // // memref values ; values // @@ -50,13 +51,13 @@ // } // }; // -// In addition, for a "trailing COO region", defined as a compressed -// dimension followed by one ore more singleton dimensions, the default -// SOA storage that is inherent to the TACO format is optimized into an -// AOS storage where all indices of a stored element appear consecutively. -// In such cases, a special operation (sparse_tensor.indices_buffer) must -// be used to access the AOS index array. In the code below, the method -// `getCOOStart` is used to find the start of the "trailing COO region". 
+// In addition, for a "trailing COO region", defined as a compressed level +// followed by one ore more singleton levels, the default SOA storage that +// is inherent to the TACO format is optimized into an AOS storage where +// all indices of a stored element appear consecutively. In such cases, +// a special operation (sparse_tensor.indices_buffer) must be used to +// access the AOS index array. In the code below, the method `getCOOStart` +// is used to find the start of the "trailing COO region". // // Examples. // @@ -64,13 +65,13 @@ // memref ; pointers-1 // memref ; indices-1 // memref ; values -// struct<(array<2 x i64>, array<3 x i64>)>) ; dim0, dim1, 3xsizes +// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes // // #COO storage of 2-dim matrix yields // memref, ; pointers-0, essentially [0,sz] // memref ; AOS index storage // memref ; values -// struct<(array<2 x i64>, array<3 x i64>)>) ; dim0, dim1, 3xsizes +// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes // //===----------------------------------------------------------------------===// @@ -88,6 +89,15 @@ static_assert(static_cast(SparseTensorFieldKind::ValMemRef) == static_cast(StorageSpecifierKind::ValMemSize)); +/// The type of field indices. This alias is to help code be more +/// self-documenting; unfortunately it is not type-checked, so it only +/// provides documentation rather than doing anything to prevent mixups. +using FieldIndex = unsigned; + +// TODO: Functions/methods marked with [NUMFIELDS] might should use +// `FieldIndex` for their return type, via the same reasoning for why +// `Dimension`/`Level` are used both for identifiers and ranks. + /// For each field that will be allocated for the given sparse tensor encoding, /// calls the callback with the corresponding field index, field kind, dimension /// (for sparse tensor level memrefs) and dimlevelType. @@ -97,25 +107,26 @@ /// tensor fields instead of relying on ad-hoc index computation. void foreachFieldInSparseTensor( SparseTensorEncodingAttr, - llvm::function_ref); + llvm::function_ref); /// Same as above, except that it also builds the Type for the corresponding /// field. void foreachFieldAndTypeInSparseTensor( - RankedTensorType, - llvm::function_ref); /// Gets the total number of fields for the given sparse tensor encoding. +// TODO: See note [NUMFIELDS]. unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); /// Gets the total number of data fields (index arrays, pointer arrays, and a /// value array) for the given sparse tensor encoding. +// TODO: See note [NUMFIELDS]. unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) { @@ -138,47 +149,48 @@ /// Getters: get the field index for required field. /// - unsigned getMemRefFieldIndex(SparseTensorFieldKind kind, - std::optional dim) const { - return getFieldIndexAndStride(kind, dim).first; + FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, + std::optional lvl) const { + return getFieldIndexAndStride(kind, lvl).first; } - unsigned getMemRefFieldIndex(StorageSpecifierKind kind, - std::optional dim) const { - return getMemRefFieldIndex(toFieldKind(kind), dim); + FieldIndex getMemRefFieldIndex(StorageSpecifierKind kind, + std::optional lvl) const { + return getMemRefFieldIndex(toFieldKind(kind), lvl); } + // TODO: See note [NUMFIELDS]. 
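To make the layout comment above concrete, the same two encodings can be written out as plain C++ arrays for a small example matrix. The numbers below are illustrative only (a 4x4 matrix with nonzeros A(0,1)=1, A(2,0)=2, A(2,3)=3) and are not part of the patch.

#include <cstdint>
#include <vector>

// Example matrix (4x4, 3 nonzeros): A(0,1)=1, A(2,0)=2, A(2,3)=3.

// #CSR: level 0 dense, level 1 compressed.
std::vector<uint64_t> csrPointers1 = {0, 1, 1, 3, 3}; // pointers for level 1
std::vector<uint64_t> csrIndices1  = {1, 0, 3};       // indices for level 1
std::vector<double>   csrValues    = {1.0, 2.0, 3.0}; // values

// #COO: the trailing COO region spans levels 0..1 and is stored AOS, so the
// indices of one stored element appear consecutively: (i0,j0, i1,j1, ...).
std::vector<uint64_t> cooPointers0  = {0, 3};                // essentially [0, nnz]
std::vector<uint64_t> cooIndicesAOS = {0, 1,  2, 0,  2, 3};  // AOS index storage
std::vector<double>   cooValues     = {1.0, 2.0, 3.0};       // values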
static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { return sparse_tensor::getNumFieldsFromEncoding(enc); } static void foreachFieldInSparseTensor( const SparseTensorEncodingAttr enc, - llvm::function_ref callback) { return sparse_tensor::foreachFieldInSparseTensor(enc, callback); } - std::pair + std::pair getFieldIndexAndStride(SparseTensorFieldKind kind, - std::optional dim) const { - unsigned fieldIdx = -1u; + std::optional lvl) const { + FieldIndex fieldIdx = -1u; unsigned stride = 1; if (kind == SparseTensorFieldKind::IdxMemRef) { - assert(dim.has_value()); - unsigned cooStart = getCOOStart(enc); - unsigned rank = enc.getDimLevelType().size(); - if (dim.value() >= cooStart && dim.value() < rank) { - dim = cooStart; - stride = rank - cooStart; + assert(lvl.has_value()); + const Level cooStart = getCOOStart(enc); + const Level lvlRank = enc.getLvlRank(); + if (lvl.value() >= cooStart && lvl.value() < lvlRank) { + lvl = cooStart; + stride = lvlRank - cooStart; } } foreachFieldInSparseTensor( enc, - [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, - unsigned fDim, DimLevelType dlt) -> bool { - if ((dim && fDim == dim.value() && kind == fKind) || + [lvl, kind, &fieldIdx](FieldIndex fIdx, SparseTensorFieldKind fKind, + Level fLvl, DimLevelType dlt) -> bool { + if ((lvl && fLvl == lvl.value() && kind == fKind) || (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { fieldIdx = fIdx; // Returns false to break the iteration. @@ -187,13 +199,27 @@ return true; }); assert(fieldIdx != -1u); - return std::pair(fieldIdx, stride); + return std::pair(fieldIdx, stride); } private: SparseTensorEncodingAttr enc; }; +// FIXME: Functions/methods marked with [CLARIFY_DIM_LVL] require +// clarification on whether their "dim" argument should actually +// be `Level` or `Dimension`. In particular, it's unclear whether +// `StorageSpecifierKind::DimSize` actually means to refer to dimension-sizes +// vs level-sizes. If it's the latter (which seems unlikely), then all the +// noted functions should use the `Level` type alias. If it's the former, +// then the functions which specifically use `DimSize` should be changed +// to use the `Dimension` type alias; however, the functions which take +// an unknown `StorageSpecifierKind` must be adjusted to ensure that they +// correctly interpret the "dim" argument since the interpretation depends +// on the `StorageSpecifierKind` value. Since wrengr couldn't figure this +// out from context, Peiming or Bixia should review these functions and +// update them as appropriate. + class SparseTensorSpecifier { public: explicit SparseTensorSpecifier(Value specifier) @@ -201,18 +227,21 @@ // Undef value for dimension sizes, all zero value for memory sizes. static Value getInitValue(OpBuilder &builder, Location loc, - RankedTensorType rtp); + SparseTensorType stt); /*implicit*/ operator Value() { return specifier; } + // FIXME: see note [CLARIFY_DIM_LVL]. Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim); + // FIXME: see note [CLARIFY_DIM_LVL]. void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional dim); + // FIXME: see note [CLARIFY_DIM_LVL]. 
Type getFieldType(StorageSpecifierKind kind, std::optional dim) { return specifier.getType().getFieldType(kind, dim); } @@ -229,11 +258,10 @@ template class SparseTensorDescriptorImpl { protected: - SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields) - : rType(tp.cast()), fields(fields) { - assert(getSparseTensorEncoding(tp) && - getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == - fields.size()); + SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields) + : rType(stt), fields(fields) { + assert(stt.hasEncoding() && + getNumFieldsFromEncoding(stt.getEncoding()) == getNumFields()); // We should make sure the class is trivially copyable (and should be small // enough) such that we can pass it by value. static_assert(std::is_trivially_copyable_v< @@ -241,19 +269,21 @@ } public: - unsigned getMemRefFieldIndex(SparseTensorFieldKind kind, - std::optional dim) const { + FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, + std::optional lvl) const { // Delegates to storage layout. - StorageLayout layout(getSparseTensorEncoding(rType)); - return layout.getMemRefFieldIndex(kind, dim); + StorageLayout layout(rType.getEncoding()); + return layout.getMemRefFieldIndex(kind, lvl); } + // TODO: See note [NUMFIELDS]. unsigned getNumFields() const { return fields.size(); } /// /// Getters: get the value for required field. /// + // FIXME: see note [CLARIFY_DIM_LVL]. Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim) const { @@ -261,12 +291,13 @@ return md.getSpecifierField(builder, loc, kind, dim); } + // FIXME: see note [CLARIFY_DIM_LVL]. Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const { return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim); } - Value getPtrMemRef(unsigned ptrDim) const { - return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim); + Value getPtrMemRef(Level lvl) const { + return getMemRefField(SparseTensorFieldKind::PtrMemRef, lvl); } Value getValMemRef() const { @@ -274,23 +305,23 @@ } Value getMemRefField(SparseTensorFieldKind kind, - std::optional dim) const { - return getField(getMemRefFieldIndex(kind, dim)); + std::optional lvl) const { + return getField(getMemRefFieldIndex(kind, lvl)); } - Value getMemRefField(unsigned fidx) const { + Value getMemRefField(FieldIndex fidx) const { assert(fidx < fields.size() - 1); return getField(fidx); } - Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + Value getPtrMemSize(OpBuilder &builder, Location loc, Level lvl) const { return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, - dim); + lvl); } - Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + Value getIdxMemSize(OpBuilder &builder, Location loc, Level lvl) const { return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, - dim); + lvl); } Value getValMemSize(OpBuilder &builder, Location loc) const { @@ -299,54 +330,46 @@ } Type getMemRefElementType(SparseTensorFieldKind kind, - std::optional dim) const { - return getMemRefField(kind, dim) - .getType() - .template cast() - .getElementType(); + std::optional lvl) const { + return getMemRefType(getMemRefField(kind, lvl)).getElementType(); } - Value getField(unsigned fidx) const { + Value getField(FieldIndex fidx) const { assert(fidx < fields.size()); return fields[fidx]; } ValueRange getMemRefFields() const { - ValueRange ret = fields; // Drop the last metadata fields. 
- return ret.slice(0, fields.size() - 1); + return fields.drop_back(); } - std::pair - getIdxMemRefIndexAndStride(unsigned idxDim) const { - StorageLayout layout(getSparseTensorEncoding(rType)); - return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef, - idxDim); + std::pair getIdxMemRefIndexAndStride(Level lvl) const { + StorageLayout layout(rType.getEncoding()); + return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef, lvl); } Value getAOSMemRef() const { - auto enc = getSparseTensorEncoding(rType); - unsigned cooStart = getCOOStart(enc); - assert(cooStart < enc.getDimLevelType().size()); + const Level cooStart = getCOOStart(rType.getEncoding()); + assert(cooStart < rType.getLvlRank()); return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart); } - RankedTensorType getTensorType() const { return rType; } + RankedTensorType getRankedTensorType() const { return rType; } ValueArrayRef getFields() const { return fields; } protected: - RankedTensorType rType; + SparseTensorType rType; ValueArrayRef fields; }; /// Uses ValueRange for immutable descriptors. class SparseTensorDescriptor : public SparseTensorDescriptorImpl { public: - SparseTensorDescriptor(Type tp, ValueRange buffers) - : SparseTensorDescriptorImpl(tp, buffers) {} + SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers) + : SparseTensorDescriptorImpl(stt, buffers) {} - Value getIdxMemRefOrView(OpBuilder &builder, Location loc, - unsigned idxDim) const; + Value getIdxMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const; }; /// Uses SmallVectorImpl & for mutable descriptors. @@ -359,8 +382,9 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl &> { public: - MutSparseTensorDescriptor(Type tp, SmallVectorImpl &buffers) - : SparseTensorDescriptorImpl &>(tp, buffers) {} + MutSparseTensorDescriptor(SparseTensorType stt, + SmallVectorImpl &buffers) + : SparseTensorDescriptorImpl &>(stt, buffers) {} // Allow implicit type conversion from mutable descriptors to immutable ones // (but not vice versa). @@ -373,21 +397,22 @@ /// required field. /// - void setMemRefField(SparseTensorFieldKind kind, std::optional dim, + void setMemRefField(SparseTensorFieldKind kind, std::optional lvl, Value v) { - fields[getMemRefFieldIndex(kind, dim)] = v; + fields[getMemRefFieldIndex(kind, lvl)] = v; } - void setMemRefField(unsigned fidx, Value v) { + void setMemRefField(FieldIndex fidx, Value v) { assert(fidx < fields.size() - 1); fields[fidx] = v; } - void setField(unsigned fidx, Value v) { + void setField(FieldIndex fidx, Value v) { assert(fidx < fields.size()); fields[fidx] = v; } + // FIXME: see note [CLARIFY_DIM_LVL]. void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim, Value v) { @@ -401,14 +426,15 @@ std::nullopt, v); } - void setIdxMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) { - setSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, dim, v); + void setIdxMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { + setSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, lvl, v); } - void setPtrMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) { - setSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, dim, v); + void setPtrMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { + setSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, lvl, v); } + // FIXME: see note [CLARIFY_DIM_LVL]. 
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) { setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v); } @@ -428,19 +454,21 @@ inline Value genTuple(OpBuilder &builder, Location loc, SparseTensorDescriptor desc) { - return genTuple(builder, loc, desc.getTensorType(), desc.getFields()); + return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields()); } inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { auto tuple = getTuple(tensor); - return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); + SparseTensorType stt(tuple.getResultTypes()[0].cast()); + return SparseTensorDescriptor(stt, tuple.getInputs()); } inline MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { auto tuple = getTuple(tensor); fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); - return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); + SparseTensorType stt(tuple.getResultTypes()[0].cast()); + return MutSparseTensorDescriptor(stt, fields); } } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -36,16 +36,19 @@ return IntegerAttr::get(IndexType::get(ctx), dim.value()); } +// This is only ever called from `SparseTensorTypeToBufferConverter`, +// which is why the first argument is `RankedTensorType` rather than +// `SparseTensorType`. static std::optional convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(rtp); - if (!enc) + const SparseTensorType stt(rtp); + if (!stt.hasEncoding()) return std::nullopt; foreachFieldAndTypeInSparseTensor( - rtp, - [&fields](Type fieldType, unsigned fieldIdx, - SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + stt, + [&fields](Type fieldType, FieldIndex fieldIdx, + SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fieldIdx == fields.size()); fields.push_back(fieldType); @@ -60,9 +63,7 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { addConversion([](Type type) { return type; }); - addConversion([&](RankedTensorType rtp, SmallVectorImpl &fields) { - return convertSparseTensorType(rtp, fields); - }); + addConversion(convertSparseTensorType); // Required by scf.for 1:N type conversion. addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, @@ -81,9 +82,9 @@ //===----------------------------------------------------------------------===// Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, - RankedTensorType rtp) { + SparseTensorType stt) { return builder.create( - loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp))); + loc, StorageSpecifierType::get(stt.getEncoding())); } Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc, @@ -110,34 +111,30 @@ //===----------------------------------------------------------------------===// Value sparse_tensor::SparseTensorDescriptor::getIdxMemRefOrView( - OpBuilder &builder, Location loc, unsigned idxDim) const { - auto enc = getSparseTensorEncoding(rType); - unsigned cooStart = getCOOStart(enc); - unsigned idx = idxDim >= cooStart ? 
cooStart : idxDim; - Value buffer = getMemRefField(SparseTensorFieldKind::IdxMemRef, idx); - if (idxDim >= cooStart) { - unsigned rank = enc.getDimLevelType().size(); - Value stride = constantIndex(builder, loc, rank - cooStart); - Value size = getIdxMemSize(builder, loc, cooStart); - size = builder.create(loc, size, stride); - buffer = builder.create( - loc, buffer, - /*offset=*/ValueRange{constantIndex(builder, loc, idxDim - cooStart)}, - /*size=*/ValueRange{size}, - /*step=*/ValueRange{stride}); - } - return buffer; + OpBuilder &builder, Location loc, Level idxLvl) const { + const Level cooStart = getCOOStart(rType.getEncoding()); + if (idxLvl < cooStart) + return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxLvl); + + Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart); + Value size = getIdxMemSize(builder, loc, cooStart); + size = builder.create(loc, size, stride); + return builder.create( + loc, getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart), + /*offset=*/ValueRange{constantIndex(builder, loc, idxLvl - cooStart)}, + /*size=*/ValueRange{size}, + /*step=*/ValueRange{stride}); } //===----------------------------------------------------------------------===// // Public methods. //===----------------------------------------------------------------------===// -constexpr uint64_t kDataFieldStartingIdx = 0; +constexpr FieldIndex kDataFieldStartingIdx = 0; void sparse_tensor::foreachFieldInSparseTensor( const SparseTensorEncodingAttr enc, - llvm::function_ref callback) { assert(enc); @@ -146,23 +143,22 @@ if (!(callback(idx, kind, dim, dlt))) \ return; - unsigned rank = enc.getDimLevelType().size(); - unsigned end = getCOOStart(enc); - if (end != rank) - end += 1; - static_assert(kDataFieldStartingIdx == 0); - unsigned fieldIdx = kDataFieldStartingIdx; + const auto lvlTypes = enc.getDimLevelType(); + const Level lvlRank = enc.getLvlRank(); + const Level cooStart = getCOOStart(enc); + const Level end = cooStart == lvlRank ? cooStart : cooStart + 1; + FieldIndex fieldIdx = kDataFieldStartingIdx; // Per-dimension storage. - for (unsigned r = 0; r < end; r++) { + for (Level l = 0; l < end; l++) { // Dimension level types apply in order to the reordered dimension. // As a result, the compound type can be constructed directly in the given // order. - auto dlt = getDimLevelType(enc, r); + const auto dlt = lvlTypes[l]; if (isCompressedDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, l, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, l, dlt); } else if (isSingletonDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, l, dlt); } else { assert(isDenseDLT(dlt)); // no fields } @@ -179,16 +175,16 @@ } void sparse_tensor::foreachFieldAndTypeInSparseTensor( - RankedTensorType rType, - llvm::function_ref callback) { - auto enc = getSparseTensorEncoding(rType); + const auto enc = stt.getEncoding(); assert(enc); // Construct the basic types. 
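The rewritten getIdxMemRefOrView above returns, for a level inside the trailing COO region, a strided view into the AOS index buffer: offset `idxLvl - cooStart`, stride `lvlRank - cooStart`, and (presumably, given the elided op name) a size equal to the index-memory size divided by that stride. A standalone sketch of the same index arithmetic in plain C++, gathering into a fresh vector instead of creating a memref subview:

#include <cassert>
#include <cstdint>
#include <vector>

// Extract the coordinates of one level from an AOS index buffer that packs
// the levels [cooStart, lvlRank) of each stored element consecutively.
std::vector<uint64_t> idxView(const std::vector<uint64_t> &aos, uint64_t lvlRank,
                              uint64_t cooStart, uint64_t idxLvl) {
  assert(idxLvl >= cooStart && idxLvl < lvlRank);
  const uint64_t stride = lvlRank - cooStart;   // #levels packed per element
  const uint64_t offset = idxLvl - cooStart;    // position within one pack
  assert(aos.size() % stride == 0);
  std::vector<uint64_t> out;
  for (uint64_t i = offset; i < aos.size(); i += stride)
    out.push_back(aos[i]);                      // size == aos.size() / stride
  return out;
}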
   Type idxType = enc.getIndexType();
   Type ptrType = enc.getPointerType();
-  Type eltType = rType.getElementType();
+  Type eltType = stt.getElementType();
   Type metaDataType = StorageSpecifierType::get(enc);
   // memref pointers
@@ -201,17 +197,17 @@
   foreachFieldInSparseTensor(
       enc,
       [metaDataType, ptrMemType, idxMemType, valMemType,
-       callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
-                 unsigned dim, DimLevelType dlt) -> bool {
+       callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
+                 Level lvl, DimLevelType dlt) -> bool {
         switch (fieldKind) {
         case SparseTensorFieldKind::StorageSpec:
-          return callback(metaDataType, fieldIdx, fieldKind, dim, dlt);
+          return callback(metaDataType, fieldIdx, fieldKind, lvl, dlt);
         case SparseTensorFieldKind::PtrMemRef:
-          return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
+          return callback(ptrMemType, fieldIdx, fieldKind, lvl, dlt);
         case SparseTensorFieldKind::IdxMemRef:
-          return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
+          return callback(idxMemType, fieldIdx, fieldKind, lvl, dlt);
         case SparseTensorFieldKind::ValMemRef:
-          return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
+          return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
         };
         llvm_unreachable("unrecognized field kind");
       });
@@ -220,8 +216,8 @@
 unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
   unsigned numFields = 0;
   foreachFieldInSparseTensor(enc,
-                             [&numFields](unsigned, SparseTensorFieldKind,
-                                          unsigned, DimLevelType) -> bool {
+                             [&numFields](FieldIndex, SparseTensorFieldKind,
+                                          Level, DimLevelType) -> bool {
                                numFields++;
                                return true;
                              });
@@ -232,8 +228,9 @@
 sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
   unsigned numFields = 0; // one value memref
   foreachFieldInSparseTensor(enc,
-                             [&numFields](unsigned fidx, SparseTensorFieldKind,
-                                          unsigned, DimLevelType) -> bool {
+                             [&numFields](FieldIndex fidx,
+                                          SparseTensorFieldKind, Level,
+                                          DimLevelType) -> bool {
                                if (fidx >= kDataFieldStartingIdx)
                                  numFields++;
                                return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -225,18 +226,20 @@
 static unsigned getNumCompoundAffineOnSparseDims(AffineMap affineMap,
                                                  Value tensor) {
   unsigned num = 0;
-  auto enc = getSparseTensorEncoding(tensor.getType());
+  const auto enc = getSparseTensorEncoding(tensor.getType());
   if (enc) {
-    ArrayRef<AffineExpr> exps = affineMap.getResults();
-    for (unsigned rank = 0; rank < exps.size(); rank++) {
-      auto aidx = toOrigDim(enc, rank);
-      auto affine = exps[aidx];
-      if (!affine.isa<AffineDimExpr>())
-        if (!isDenseDLT(getDimLevelType(enc, rank)))
-          num++;
+    const ArrayRef<AffineExpr> exps = affineMap.getResults();
+    const Level lvlRank = enc.getLvlRank();
+    assert(static_cast<Level>(exps.size()) == lvlRank);
+    for (Level l = 0; l < lvlRank; l++) {
+      // FIXME: `toOrigDim` is deprecated.
+      const Dimension d = toOrigDim(enc, l);
+      // FIXME: there's some dim/lvl confusion here, since `d` isn't
+      // guaranteed to be in bounds (for non-permutations).
+      if (!exps[d].isa<AffineDimExpr>() && !enc.isDenseLvl(l))
+        num++;
     }
   }
-
   return num;
 }
@@ -252,10 +255,8 @@
 static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
   OpOperand *out = op.getDpsInitOperand(0);
-  auto enc = getSparseTensorEncoding(out->get().getType());
-  if (!enc || enc.isAllDense())
+  if (getSparseTensorType(out->get()).isAllDense())
     return false;
-
   return getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(out),
                                           out->get());
 }
@@ -269,16 +270,18 @@
   bool annotated = false;
   unsigned filterLdx = env.merger().getFilterLoopStartingIdx();
   for (OpOperand &t : env.op()->getOpOperands()) {
-    auto map = env.op().getMatchingIndexingMap(&t);
-    auto enc = getSparseTensorEncoding(t.get().getType());
+    const auto map = env.op().getMatchingIndexingMap(&t);
+    const auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-    assert(map.getNumResults() == env.op().getRank(&t));
-    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      unsigned tensor = t.getOperandNumber();
-      AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(env.merger(), tensor, d, a, getDimLevelType(enc, d),
-                      filterLdx))
+    const Level lvlRank = map.getNumResults();
+    assert(!enc || lvlRank == enc.getLvlRank());
+    assert(env.op().getRank(&t) == lvlRank);
+    for (Level l = 0; l < lvlRank; l++) {
+      const unsigned tensor = t.getOperandNumber();
+      // FIXME: `toOrigDim` is deprecated.
+      const AffineExpr a = map.getResult(toOrigDim(enc, l));
+      if (!findAffine(env.merger(), tensor, l, a, enc.getLvlType(l), filterLdx))
         return false; // inadmissible affine expression
     }
   }
@@ -440,15 +443,15 @@
                                   OpOperand *skip = nullptr) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
-  unsigned n = env.merger().getNumLoops();
+  const unsigned n = env.merger().getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
   std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
-  auto iteratorTypes = env.op().getIteratorTypesArray();
+  const auto iteratorTypes = env.op().getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand &t : env.op()->getOpOperands()) {
     // Get map and encoding.
-    auto map = env.op().getMatchingIndexingMap(&t);
-    auto enc = getSparseTensorEncoding(t.get().getType());
+    const auto map = env.op().getMatchingIndexingMap(&t);
+    const auto enc = getSparseTensorEncoding(t.get().getType());
     assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
     // Skip dense tensor constraints when not requested.
     if (!(mask & SortMask::kIncludeDense) && !enc)
@@ -457,16 +460,18 @@
     // by default) puts an ordering constraint on the loop indices. For
     // example, the tensor expresion A_ijk forces the ordering i < j < k
     // on the loop indices if no explicit dimension ordering is given.
-    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      AffineExpr ta = map.getResult(toOrigDim(enc, d));
+    const Level lvlRank = map.getNumResults();
+    assert(!enc || lvlRank == enc.getLvlRank());
+    for (Level l = 0; l < lvlRank; l++) {
+      // FIXME: `toOrigDim` is deprecated.
+      AffineExpr ta = map.getResult(toOrigDim(enc, l));
       std::optional<unsigned> tldx =
-          env.merger().getLoopIdx(t.getOperandNumber(), d);
+          env.merger().getLoopIdx(t.getOperandNumber(), l);

       // Filter loops should be constructed after all the dependent loops,
       // i.e., d0 + d1 < filter_loop(d0 + d1)
       if (tldx && env.merger().isFilterLoop(*tldx)) {
-        assert(!ta.isa<AffineDimExpr>() &&
-               !isDenseDLT(getDimLevelType(enc, d)));
+        assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlType(l)));
         addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
                            tldx);
         // Now that the ordering of affine expression is captured by filter
@@ -481,10 +486,11 @@
       if (&t == skip)
         continue;

-      if (d > 0) {
-        AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
+      if (l > 0) {
+        // FIXME: `toOrigDim` is deprecated.
+        AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
         std::optional<unsigned> fldx =
-            env.merger().getLoopIdx(t.getOperandNumber(), d - 1);
+            env.merger().getLoopIdx(t.getOperandNumber(), l - 1);

         // Applying order constraints on every pair of dimExpr between two
         // compound affine expressions can sometime too strict:
@@ -576,8 +582,11 @@
 /// Generates index for load/store on sparse tensor.
 static Value genIndex(CodegenEnv &env, OpOperand *t) {
   auto map = env.op().getMatchingIndexingMap(t);
-  auto enc = getSparseTensorEncoding(t->get().getType());
-  AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
+  const auto stt = getSparseTensorType(t->get());
+  const Level lvlRank = stt.getLvlRank();
+  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+  // FIXME: `toOrigDim` is deprecated.
+  AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1));
   assert(a.getKind() == AffineExprKind::DimId);
   unsigned idx = a.cast<AffineDimExpr>().getPosition();
   return env.getLoopIdxValue(idx);
@@ -589,15 +598,16 @@
   linalg::GenericOp op = env.op();
   unsigned tensor = t->getOperandNumber();
   auto map = op.getMatchingIndexingMap(t);
-  auto enc = getSparseTensorEncoding(t->get().getType());
-  unsigned rank = map.getNumResults();
-  if (enc) {
+  const auto stt = getSparseTensorType(t->get());
+  if (stt.hasEncoding()) {
     Value pidx = env.emitter().getPidxs()[tensor].back();
     assert(pidx);
     args.push_back(pidx); // position index
   } else {
-    for (unsigned d = 0; d < rank; d++) {
-      AffineExpr a = map.getResult(d);
+    const Level lvlRank = stt.getLvlRank();
+    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+    for (Level l = 0; l < lvlRank; l++) {
+      AffineExpr a = map.getResult(l);
       args.push_back(env.emitter().genAffine(builder, a, op.getLoc()));
     }
   }
@@ -861,11 +871,14 @@
   linalg::GenericOp op = env.op();
   OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
   auto map = op.getMatchingIndexingMap(&t);
-  auto enc = getSparseTensorEncoding(t.get().getType());
-  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-    AffineExpr a = map.getResult(toOrigDim(enc, d));
+  const auto stt = getSparseTensorType(t.get());
+  const Level lvlRank = stt.getLvlRank();
+  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+  for (Level l = 0; l < lvlRank; l++) {
+    // FIXME: `toOrigDim` is deprecated.
+    AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l));
     std::optional<unsigned> sldx =
-        env.merger().getLoopIdx(t.getOperandNumber(), d);
+        env.merger().getLoopIdx(t.getOperandNumber(), l);
     if (sldx && env.merger().isFilterLoop(*sldx)) {
       if (!env.getLoopIdxValue(*sldx))
         // The filter loops has not been constructed.
@@ -1002,6 +1015,7 @@
   OpOperand *t = &op->getOpOperand(tid);
   auto enc = getSparseTensorEncoding(t->get().getType());
   // Retrieves the affine expression for the filter loop.
+  // FIXME: `toOrigDim` is deprecated.
   AffineExpr a =
       op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
   return env.emitter().enterFilterLoopOverTensorAtDim(builder, loc, tid,
@@ -1192,20 +1206,22 @@
 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, unsigned tid,
-                                             unsigned lvl) {
+                                             Level lvl) {
   // TODO: Handle affine expression on output tensor.
   linalg::GenericOp op = env.op();
   assert(tid < op.getNumDpsInputs());
   OpOperand *input = op.getDpsInputOperands()[tid];
   ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
-  auto enc = getSparseTensorEncoding(input->get().getType());
+  const auto enc = getSparseTensorEncoding(input->get().getType());
   if (enc) {
-    for (unsigned i = lvl, e = affines.size(); i < e; i++) {
-      AffineExpr affine = affines[toOrigDim(enc, i)];
-      if (isDenseDLT(getDimLevelType(enc, i)) &&
-          affine.isa<AffineConstantExpr>())
+    const Level lvlRank = enc.getLvlRank();
+    assert(affines.size() == static_cast<size_t>(lvlRank));
+    for (Level l = lvl; l < lvlRank; l++) {
+      // FIXME: `toOrigDim` is deprecated.
+      AffineExpr affine = affines[toOrigDim(enc, l)];
+      if (enc.isDenseLvl(l) && affine.isa<AffineConstantExpr>())
         env.emitter().genDenseAffineAddressAtCurLevel(
-            builder, op.getLoc(), input->getOperandNumber(), i, affine);
+            builder, op.getLoc(), input->getOperandNumber(), l, affine);
       else
         return; // break on first non-dense non-constant level
     }
@@ -1262,20 +1278,21 @@
       // We only handle affine expression on input tensors (for now).
       return;
     OpOperand *operand = &op->getOpOperand(tid);
-    auto enc = getSparseTensorEncoding(operand->get().getType());
+    const auto stt = getSparseTensorType(operand->get());
     // Non-annotated dense tensors requires no special handling.
-    if (!enc)
+    if (!stt.hasEncoding())
       return;

     ArrayRef<AffineExpr> affines =
         op.getMatchingIndexingMap(operand).getResults();
-    assert(affines.size() == enc.getDimLevelType().size());
-    for (unsigned i = 0, e = affines.size(); i < e; i++) {
-      AffineExpr exp = affines[toOrigDim(enc, i)];
+    const Level lvlRank = stt.getLvlRank();
+    assert(affines.size() == static_cast<size_t>(lvlRank));
+    for (Level l = 0; l < lvlRank; l++) {
+      // FIXME: `toOrigDim` is deprecated.
+      AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
       // Skip simple affine expression and non dense dimensions (which has
       // it own filter loop).
-      if (exp.isa<AffineDimExpr>() ||
-          !isDenseDLT(getDimLevelType(enc, i)))
+      if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
        continue;

      // Constant affine expression are handled in genLoop
@@ -1292,7 +1309,7 @@
       // might be accepting out-of-order access between consecutive
       // dense levels.
       affineTids.push_back(tid);
-      affineDims.push_back(i);
+      affineDims.push_back(l);
       exps.push_back(exp);
     }
   }
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -227,7 +227,7 @@
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

 func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
-  // expected-error@+1 {{redundant dimension argument for querying value memory size}}
+  // expected-error@+1 {{redundant level argument for querying value memory size}}
   %0 = sparse_tensor.storage_specifier.get %arg0 val_mem_sz at 0
        : !sparse_tensor.storage_specifier<#SparseVector> to i64
   return %0 : i64
@@ -238,7 +238,7 @@
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

 func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
-  // expected-error@+1 {{missing dimension argument}}
+  // expected-error@+1 {{missing level argument}}
   %0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz
        : !sparse_tensor.storage_specifier<#SparseVector> to i64
   return %0 : i64
@@ -249,7 +249,7 @@
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

 func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
-  // expected-error@+1 {{requested dimension out of bound}}
+  // expected-error@+1 {{requested level out of bound}}
   %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 1
        : !sparse_tensor.storage_specifier<#SparseVector> to i64
   return %0 : i64
@@ -654,7 +654,7 @@
 func.func @invalid_concat_dim(%arg0: tensor<2x4xf64, #DC>,
                               %arg1: tensor<3x4xf64, #DC>,
                               %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
-  // expected-error@+1 {{Failed to concatentate tensors with rank=2 on dimension=4}}
+  // expected-error@+1 {{Concat-dimension is out of bounds for dimension-rank (4 >= 2)}}
   %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 4 : index}
        : tensor<2x4xf64, #DC>,
          tensor<3x4xf64, #DC>,
@@ -670,7 +670,7 @@
 func.func @invalid_concat_rank_mismatch(%arg0: tensor<2xf64, #C>,
                                         %arg1: tensor<3x4xf64, #DC>,
                                         %arg2: tensor<4x4x4xf64, #DCC>) -> tensor<9x4xf64, #DC> {
-  // expected-error@+1 {{The input tensor $0 has a different rank (rank=1) from the output tensor (rank=2)}}
+  // expected-error@+1 {{Input tensor $0 has a different rank (rank=1) from the output tensor (rank=2)}}
   %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
        : tensor<2xf64, #C>,
          tensor<3x4xf64, #DC>,
@@ -684,7 +684,7 @@
 func.func @invalid_concat_size_mismatch_dyn(%arg0: tensor<?x4xf64, #DC>,
                                             %arg1: tensor<5x4xf64, #DC>,
                                             %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
-  // expected-error@+1 {{Only statically-sized input tensors are supported.}}
+  // expected-error@+1 {{Input tensor $0 has dynamic shape}}
   %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
        : tensor<?x4xf64, #DC>,
          tensor<5x4xf64, #DC>,
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics

+// expected-error@+1 {{expected a non-empty array for level types}}
 #a = #sparse_tensor.encoding<{dimLevelType = []}>
-func.func private @scalar(%arg0: tensor<f64, #a>) -> () // expected-error {{expected non-scalar sparse tensor}}
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
@@ -35,7 +36,8 @@
 // -----

-#a = #sparse_tensor.encoding<{dimOrdering = affine_map<(i,j) -> (i,i)>}> // expected-error {{expected a permutation affine map for dimension ordering}}
+// expected-error@+1 {{expected a permutation affine map for dimension ordering}}
+#a = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"], dimOrdering = affine_map<(i,j) -> (i,i)>}>
 func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> ()

 // -----
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2175,7 +2175,10 @@
 cc_library(
     name = "SparseTensorDialect",
     srcs = ["lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp"],
-    hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
+    hdrs = [
+        "include/mlir/Dialect/SparseTensor/IR/SparseTensor.h",
+        "include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h",
+    ],
     includes = ["include"],
     deps = [
         ":ArithDialect",