diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -124,7 +124,7 @@
 }
 
 bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
-  return !getDimOrdering() || getDimOrdering().isIdentity();
+  return !getImpl() || !getDimOrdering() || getDimOrdering().isIdentity();
 }
 
 std::optional
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   CodegenEnv.cpp
   CodegenUtils.cpp
+  DimLvlMapping.cpp
   LoopEmitter.cpp
   SparseBufferRewriting.cpp
   SparseStorageSpecifierToLLVM.cpp
@@ -42,3 +43,13 @@
   MLIRTransforms
   MLIRVectorDialect
 )
+
+# To make sure we adhere to the style guide:
+#
+include(CheckCXXCompilerFlag)
+check_cxx_compiler_flag(-Wweak-vtables
+  COMPILER_SUPPORTS_WARNING_WEAK_VTABLES)
+if(COMPILER_SUPPORTS_WARNING_WEAK_VTABLES)
+  target_compile_options(MLIRSparseTensorTransforms PUBLIC
+    "-Wweak-vtables")
+endif()
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/DimLvlMapping.h b/mlir/lib/Dialect/SparseTensor/Transforms/DimLvlMapping.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/DimLvlMapping.h
@@ -0,0 +1,311 @@
+//===- DimLvlMapping.h - Utilities for dimension<->level -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines classes for managing all aspects of conversion
+// between dimensions and levels. The classes are designed to be used
+// by both SparseTensorConversionPass and SparseTensorCodegenPass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_DIMLVLMAPPING_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_DIMLVLMAPPING_H_
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+//===----------------------------------------------------------------------===//
+//
+// Type aliases to help code be more self-documenting. Unfortunately
+// these are not type-checked, so they only provide documentation rather
+// than doing anything to prevent mixups.
+//
+//===----------------------------------------------------------------------===//
+
+/// The type of dimension identifiers and dimension-ranks. We use the
+/// same type for both identifiers and ranks because the latter are used
+/// mainly for ordering-comparisons against the former (just like how
+/// one-past-the-end iterators are used).
+using Dimension = uint64_t;
+
+/// The type of level identifiers and level-ranks. We use the same
+/// type for both identifiers and ranks because the latter are used
+/// mainly for ordering-comparisons against the former (just like how
+/// one-past-the-end iterators are used).
+using Level = uint64_t;
+
+/// The type for individual components of a compile-time shape.
+/// We avoid calling this "size" because we use the term "sizes" to
+/// indicate the actual run-time sizes, whereas this type also allows
+/// the value `ShapedType::kDynamic`.
+using Ship = int64_t;
+
+//===----------------------------------------------------------------------===//
+//
+// Helper functions.
+//
+//===----------------------------------------------------------------------===//
+
+/// Constructs a `constantIndex` when the potentially-dynamic size is
+/// not in fact dynamic; when it is dynamic, returns a null `Value` instead.
+inline Value constantSize(OpBuilder &builder, Location loc, Ship s) {
+  if (ShapedType::isDynamic(s))
+    return nullptr;
+  assert(s > 0);
+  return constantIndex(builder, loc, s);
+}
+
+//===----------------------------------------------------------------------===//
+/// A class for encapsulating and managing non-codegen aspects of
+/// converting between dimensions and levels.
+///
+/// This class memoizes the level-sizes to avoid redundant computation.
+/// We mark the methods `getLvlShape`, `getLvlShip`, etc. as "const"
+/// to indicate intent and to ensure a uniform API (with respect to the
+/// dimension variants of those methods). However, since these methods
+/// may update the memo table, they are only logically const, not
+/// immutably const. Therefore, this class is not threadsafe (for now).
+class DimLvlMapping {
+public:
+  DimLvlMapping(RankedTensorType rtp)
+      : DimLvlMapping(rtp, getSparseTensorEncoding(rtp)) {}
+
+  // We memoize `lvlRank` and `dim2lvl` to avoid redundant checks later on.
+  DimLvlMapping(ShapedType stp, SparseTensorEncodingAttr enc)
+      : stp(stp), enc(enc),
+        lvlRank(isSparse() ? enc.getDimLevelType().size() : getDimRank()),
+        dim2lvl(enc.hasIdDimOrdering() ? AffineMap() : enc.getDimOrdering()),
+        lvlShape(isIdentity() ? 0 : lvlRank, kEmptyMemo) {
+    assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
+  }
+
+  /// Constructs a new `DimLvlMapping` for `getEncoding().withoutOrdering()`.
+  DimLvlMapping withoutOrdering() const {
+    return DimLvlMapping(stp, enc.withoutOrdering());
+  }
+
+  MLIRContext *getContext() const { return stp.getContext(); }
+
+  ShapedType getShapedType() const { return stp; }
+
+  Type getElementType() const { return stp.getElementType(); }
+
+  /// Returns the encoding (or nullptr for dense tensors).
+  SparseTensorEncodingAttr getEncoding() const { return enc; }
+
+  /// Returns true for tensors which have an encoding, and false for
+  /// those which do not. Therefore tensors with an all-dense encoding
+  /// return true.
+  bool isSparse() const { return static_cast<bool>(enc); }
+
+  /// Returns true for tensors which do not have an encoding, and false
+  /// for tensors which do. Therefore tensors with an all-dense encoding
+  /// return false.
+  bool isDense() const { return !enc; }
+
+  /// Returns true if the dimToLvl mapping is the identity.
+  bool isIdentity() const { return !dim2lvl; }
+
+  /// Returns the dimToLvl mapping (or nullptr for the identity).
+  AffineMap getDimToLvlMap() const { return dim2lvl; }
+
+  /// Returns the dimToLvl mapping, where the identity map is expanded out
+  /// into a full `AffineMap`. This method is provided as a convenience,
+  /// but for most purposes other methods (`isIdentity`, `getDimToLvlMap`,
+  /// `getLvlShape`, etc.) will be more helpful.
+  AffineMap getExpandedDimToLvlMap() const {
+    return dim2lvl
+               ? dim2lvl
+               : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext());
+  }
+
+  /// Returns the dimension-rank.
+  Dimension getDimRank() const { return stp.getRank(); }
+
+  /// Returns the level-rank.
+  Level getLvlRank() const { return lvlRank; }
+
+  /// Returns the dimension-shape.
+  ArrayRef<Ship> getDimShape() const { return stp.getShape(); }
+
+  /// Returns the level-shape.
+  ArrayRef<Ship> getLvlShape() const;
+
+  /// Returns the requested dimension-ship.
+  Ship getDimShip(Dimension d) const {
+    assert(d < getDimRank());
+    return getDimShape()[d];
+  }
+
+  /// Returns the requested level-ship.
+  Ship getLvlShip(Level l) const {
+    if (isIdentity())
+      return getDimShip(l);
+    assert(l < lvlRank); // Ensure the precondition of `setLvlShip`.
+    setLvlShip(l);
+    return lvlShape[l];
+  }
+
+  /// Returns true if the dimension-shape is static.
+  bool hasStaticDimShape() const { return stp.hasStaticShape(); }
+
+  /// Returns true if the level-shape is static.
+  bool hasStaticLvlShape() const {
+    return !ShapedType::isDynamicShape(getLvlShape());
+  }
+
+  /// Returns true if the dimension-ship is dynamic.
+  bool isDynamicDim(Dimension d) const {
+    return ShapedType::isDynamic(getDimShip(d));
+  }
+
+  /// Returns true if the level-ship is dynamic.
+  bool isDynamicLvl(Level l) const {
+    return ShapedType::isDynamic(getLvlShip(l));
+  }
+
+private:
+  /// Computes the requested level-ship, without performing the assertions
+  /// or checks that `getLvlShape` would want to hoist out of its loop.
+  ///
+  /// Preconditions:
+  /// * `dim2lvl` is non-null.
+  /// * `l` is valid for `lvlRank`.
+  void setLvlShip(Level l) const;
+
+  // The value indicating level-ships which have not yet been memoized.
+  // Since we disallow zero as a level-size, that means we can safely
+  // use zero to indicate the lack of a memo.
+  static constexpr Ship kEmptyMemo = 0;
+
+  const ShapedType stp;
+  const SparseTensorEncodingAttr enc;
+  // Memoized to avoid frequent redundant conditionals.
+  const Level lvlRank;
+  const AffineMap dim2lvl;
+  // Memoized to avoid redundant computation.
+  mutable SmallVector<Ship> lvlShape;
+};
+
+//===----------------------------------------------------------------------===//
+/// Abstract base class for performing codegen associated with `DimLvlMapping`.
+class DimLvlBuilder : public DimLvlMapping {
+  /// Out-of-line virtual method to ensure we avoid weak-vtables:
+  ///
+  virtual void anchor();
+
+public:
+  DimLvlBuilder(OpBuilder &builder, Location loc, ShapedType stp,
+                SparseTensorEncodingAttr enc)
+      : DimLvlMapping(stp, enc), builder(builder), loc(loc) {}
+
+  DimLvlBuilder(OpBuilder &builder, Location loc, RankedTensorType rtp)
+      : DimLvlMapping(rtp), builder(builder), loc(loc) {}
+
+  DimLvlBuilder(OpBuilder &builder, Location loc, DimLvlMapping dlm)
+      : DimLvlMapping(dlm), builder(builder), loc(loc) {}
+
+protected:
+  // Since this class is virtual, we must disallow public copying in
+  // order to avoid "slicing". And since this class has data members,
+  // that means making the copy constructor protected rather than deleted.
+  //
+  DimLvlBuilder(const DimLvlBuilder &) = default;
+  // Copy-assignment would be implicitly deleted (because the base class
+  // has const fields), so we explicitly delete it for clarity.
+  DimLvlBuilder &operator=(const DimLvlBuilder &) = delete;
+
+public:
+  virtual ~DimLvlBuilder() = default;
+
+  /// Generates code to look up a dimension-size. This should only
+  /// generate the lookup itself, and not perform any dim<->lvl
+  /// conversion or other logic.
+  virtual Value lookupDimSizeImpl(Value tensor, Dimension d) const = 0;
+
+  /// Generates code to look up a level-size.
+  /// This should only generate the lookup itself, and not perform any
+  /// dim<->lvl conversion or other logic.
+  virtual Value lookupLvlSizeImpl(Value tensor, Level l) const = 0;
+
+  /// Looks up a dimension-size by returning a constant from the shape
+  /// (for static sizes), or by calling `lookupDimSizeImpl` (for dynamic
+  /// sizes of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic
+  /// sizes of dense tensors).
+  Value lookupDimSize(Value tensor, Dimension d) const {
+    const auto sz = constantSize(builder, loc, getDimShip(d));
+    return sz ? sz
+              : isSparse() ? lookupDimSizeImpl(tensor, d)
+                           : linalg::createOrFoldDimOp(builder, loc, tensor, d);
+  }
+
+  /// Looks up a level-size by returning a statically-computed constant
+  /// (when possible), or by calling `lookupLvlSizeImpl` (when dynamic).
+  Value lookupLvlSize(Value tensor, Level l) const {
+    const auto sz = constantSize(builder, loc, getLvlShip(l));
+    return sz ? sz : lookupLvlSizeImpl(tensor, l);
+  }
+
+  /// Populates the array with the dimension-shape of the `ShapedType`,
+  /// where dynamic sizes are represented by zero.
+  void reflectDimShape(SmallVectorImpl<Value> &out) const;
+
+  /// Returns an array with the dimension-shape of the `ShapedType`,
+  /// where dynamic sizes are represented by zero.
+  SmallVector<Value> reflectDimShape() const {
+    SmallVector<Value> out;
+    reflectDimShape(out);
+    return out;
+  }
+
+  /// Populates the array with the dimension-sizes of the given tensor.
+  void lookupDimSizes(Value tensor, SmallVectorImpl<Value> &out) const;
+
+  /// Returns an array with the dimension-sizes of the given tensor.
+  SmallVector<Value> lookupDimSizes(Value tensor) const {
+    SmallVector<Value> out;
+    lookupDimSizes(tensor, out);
+    return out;
+  }
+
+  /// Populates the array with the level-sizes of the given tensor.
+  void lookupLvlSizes(Value tensor, SmallVectorImpl<Value> &out) const;
+
+  /// Returns an array with the level-sizes of the given tensor.
+  SmallVector<Value> lookupLvlSizes(Value tensor) const {
+    SmallVector<Value> out;
+    lookupLvlSizes(tensor, out);
+    return out;
+  }
+
+  /// Generates code for computing level-sizes from the given
+  /// dimension-sizes, and populates the array with the results.
+  void computeLvlSizes(ValueRange dimSizes, SmallVectorImpl<Value> &out) const;
+
+  /// Returns an array with the level-sizes for the given dimension-sizes.
+  SmallVector<Value> computeLvlSizes(ValueRange dimSizes) const {
+    // TODO: when `isIdentity`, we'd like to just return `dimSizes`
+    // directly rather than copying.
+    SmallVector<Value> out;
+    computeLvlSizes(dimSizes, out);
+    return out;
+  }
+
+protected:
+  OpBuilder &builder;
+  Location loc;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_DIMLVLMAPPING_H_
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/DimLvlMapping.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/DimLvlMapping.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/DimLvlMapping.cpp
@@ -0,0 +1,104 @@
+//===- DimLvlMapping.cpp - Utilities for dimension<->level ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines classes for managing all aspects of conversion
+// between dimensions and levels.
+// The classes are designed to be used by both SparseTensorConversionPass
+// and SparseTensorCodegenPass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "DimLvlMapping.h"
+#include "CodegenUtils.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+//===----------------------------------------------------------------------===//
+
+ArrayRef<Ship> DimLvlMapping::getLvlShape() const {
+  if (isIdentity())
+    return getDimShape();
+  for (Level l = 0; l < lvlRank; ++l)
+    setLvlShip(l);
+  return lvlShape;
+}
+
+// TODO: Figure out how to make `setLvlShip` and `computeLvlSizes` share logic.
+void DimLvlMapping::setLvlShip(Level l) const {
+  if (lvlShape[l] != kEmptyMemo)
+    return; // Already been set.
+  // TODO: The following implementation only handles permutations;
+  // we need to generalize this to handle arbitrary AffineExpr.
+  //
+  // There's no need to assert `isPermutation` here, because
+  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
+  // which is all we care about (for supporting permutations).
+  lvlShape[l] = getDimShip(dim2lvl.getDimPosition(l));
+  // Ensure that the computation doesn't result in a value that
+  // coincides with the empty memo.
+  assert(lvlShape[l] != kEmptyMemo);
+}
+
+//===----------------------------------------------------------------------===//
+/// Out-of-line virtual method to ensure we avoid weak-vtables:
+///
+void DimLvlBuilder::anchor() {}
+
+void DimLvlBuilder::reflectDimShape(SmallVectorImpl<Value> &out) const {
+  out.clear();
+  out.reserve(getDimRank());
+  const auto zero = constantIndex(builder, loc, 0);
+  for (auto s : getDimShape()) {
+    const auto sz = constantSize(builder, loc, s);
+    out.push_back(sz ? sz : zero);
+  }
+}
+
+void DimLvlBuilder::lookupDimSizes(Value tensor,
+                                   SmallVectorImpl<Value> &out) const {
+  const Dimension dimRank = getDimRank();
+  out.clear();
+  out.reserve(dimRank);
+  for (Dimension d = 0; d < dimRank; ++d)
+    out.push_back(lookupDimSize(tensor, d));
+}
+
+void DimLvlBuilder::lookupLvlSizes(Value tensor,
+                                   SmallVectorImpl<Value> &out) const {
+  out.clear();
+  out.reserve(getLvlRank());
+  // We inline `lookupLvlSize` to avoid the redundant checks of its
+  // underlying call to `getLvlShip`.
+  for (const auto &l : llvm::enumerate(getLvlShape())) {
+    const auto sz = constantSize(builder, loc, l.value());
+    out.push_back(sz ? sz : lookupLvlSizeImpl(tensor, l.index()));
+  }
+}
+
+// TODO: Figure out how to make `setLvlShip` and `computeLvlSizes` share logic.
+void DimLvlBuilder::computeLvlSizes(ValueRange dimSizes,
+                                    SmallVectorImpl<Value> &out) const {
+  out.clear();
+  out.reserve(getLvlRank());
+  if (isIdentity()) {
+    for (auto sz : dimSizes)
+      out.push_back(sz);
+    return;
+  }
+  const auto dim2lvl = getDimToLvlMap();
+  // TODO: The following implementation only handles permutations;
+  // we need to generalize this to handle arbitrary AffineExpr.
+  //
+  // There's no need to assert `isPermutation` before calling `getDimPosition`,
+  // because the latter already checks that the expr isa `AffineDimExpr`,
+  // which is all we care about (for supporting permutations).
+  for (const auto &l : llvm::enumerate(getLvlShape())) {
+    const auto sz = constantSize(builder, loc, l.value());
+    out.push_back(sz ? sz : dimSizes[dim2lvl.getDimPosition(l.index())]);
+  }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -17,6 +17,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
+#include "DimLvlMapping.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -38,6 +39,11 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
+template <typename T>
+static RankedTensorType getRankedTensorType(T t) {
+  return t.getType().template cast<RankedTensorType>();
+}
+
 /// Maps each sparse tensor type to an opaque pointer.
 static std::optional<Type> convertSparseTensorTypes(Type type) {
   if (getSparseTensorEncoding(type) != nullptr)
@@ -57,144 +63,54 @@
                         operands);
 }
 
-/// Generates call to lookup a level-size. N.B., this only generates
-/// the raw function call, and therefore (intentionally) does not perform
-/// any dim<->lvl conversion or other logic.
-static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
-                            uint64_t lvl) {
-  StringRef name = "sparseLvlSize";
-  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
-  Type iTp = builder.getIndexType();
-  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
-      .getResult(0);
-}
+/// Specialize `DimLvlBuilder` for generating calls into the runtime library.
+class RuntimeDimLvlBuilder final : public DimLvlBuilder {
+public:
+  using DimLvlBuilder::DimLvlBuilder;
 
-/// Generates call to lookup a dimension-size. N.B., this only generates
-/// the raw function call, and therefore (intentionally) does not perform
-/// any dim<->lvl conversion or other logic.
-static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
-                            uint64_t dim) {
-  StringRef name = "sparseDimSize";
-  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
-  Type iTp = builder.getIndexType();
-  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
-      .getResult(0);
-}
+  ~RuntimeDimLvlBuilder() final = default;
 
-/// Looks up a level-size by returning a statically-computed constant
-/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
-static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
-                                 SparseTensorEncodingAttr enc, ShapedType stp,
-                                 Value tensor, unsigned lvl) {
-  // Only sparse tensors have "levels" to query.
-  assert(enc);
-  auto dimOrder = enc.getDimOrdering();
-  // TODO: The following implementation only handles permutations;
-  // we'll need to generalize this to handle arbitrary AffineExpr.
-  //
-  // There's no need to assert `isPermutation` here: because
-  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
-  // which is all we care about (for supporting permutations).
-  unsigned dim = dimOrder ? dimOrder.getDimPosition(lvl) : lvl;
-  auto s = stp.getShape()[dim];
-  if (s != ShapedType::kDynamic)
-    return constantIndex(builder, loc, s);
-  // If we cannot statically compute the size from the shape, then we
-  // must dynamically query it. (In principle we could also dynamically
-  // compute it, but since we already did so to construct the `tensor`
-  // in the first place, we might as well query rather than recompute.)
-  return genLvlSizeCall(builder, loc, tensor, lvl);
-}
-
-/// Looks up a dimension-size by returning a constant from the shape
-/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
-/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
-/// of dense tensors).
-static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
-                                 SparseTensorEncodingAttr enc, ShapedType stp,
-                                 Value tensor, unsigned dim) {
-  auto s = stp.getShape()[dim];
-  if (s != ShapedType::kDynamic)
-    return constantIndex(builder, loc, s);
-  if (enc)
-    return genDimSizeCall(builder, loc, tensor, dim);
-  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
-}
-
-/// Populates the array with the dimension-sizes of the given tensor.
-static void fillDimSizes(OpBuilder &builder, Location loc,
-                         SparseTensorEncodingAttr enc, ShapedType stp,
-                         Value tensor, SmallVectorImpl<Value> &out) {
-  unsigned dimRank = stp.getRank();
-  out.reserve(dimRank);
-  for (unsigned d = 0; d < dimRank; d++)
-    out.push_back(createOrFoldDimCall(builder, loc, enc, stp, tensor, d));
-}
-
-/// Returns an array with the dimension-sizes of the given tensor.
-static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
-                                      SparseTensorEncodingAttr enc,
-                                      ShapedType stp, Value tensor) {
-  SmallVector<Value> out;
-  fillDimSizes(builder, loc, enc, stp, tensor, out);
-  return out;
-}
-
-/// Populates the array with the dimension-shape of the given `ShapedType`,
-/// where dynamic sizes are represented by zero.
-static void fillDimShape(OpBuilder &builder, Location loc, ShapedType stp,
-                         SmallVectorImpl<Value> &out) {
-  auto shape = stp.getShape();
-  unsigned dimRank = stp.getRank();
-  out.reserve(dimRank);
-  for (unsigned d = 0; d < dimRank; d++) {
-    auto s = shape[d] == ShapedType::kDynamic ? 0 : shape[d];
-    out.push_back(constantIndex(builder, loc, s));
+  Value lookupLvlSizeImpl(Value tensor, Level lvl) const final {
+    StringRef name = "sparseLvlSize";
+    SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
+    Type iTp = builder.getIndexType();
+    return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
+        .getResult(0);
   }
-}
 
-/// Returns an array with the dimension-shape of the given `ShapedType`,
-/// where dynamic sizes are represented by zero.
-static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
-                                      ShapedType stp) {
-  SmallVector<Value> out;
-  fillDimShape(builder, loc, stp, out);
-  return out;
-}
+  Value lookupDimSizeImpl(Value tensor, Dimension dim) const final {
+    StringRef name = "sparseDimSize";
+    SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
+    Type iTp = builder.getIndexType();
+    return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
+        .getResult(0);
+  }
+};
 
 /// Populates the given sizes array for concatenation from type (for static
 /// sizes) and from an already-converted opaque pointer source (for dynamic
 /// sizes).
-static void concatSizesFromInputs(OpBuilder &builder,
-                                  SmallVectorImpl<Value> &sizes, Location loc,
-                                  ShapedType dstTp, ValueRange srcs,
-                                  unsigned dim) {
-  auto dstShape = dstTp.getShape();
+static void concatDimSizesFromInputs(OpBuilder &builder, Location loc,
+                                     ShapedType dstTp, ValueRange srcs,
+                                     Dimension dim,
+                                     SmallVectorImpl<Value> &sizes) {
+  assert(dim < dstTp.getRank());
+  // First, fill the sizes from an arbitrary source tensor.
+  RuntimeDimLvlBuilder(builder, loc, getRankedTensorType(srcs[0]))
+      .lookupDimSizes(srcs[0], sizes);
 
-  auto srcTp = srcs[0].getType().cast<RankedTensorType>();
-  auto srcEnc = getSparseTensorEncoding(srcTp);
-  // We first fills the sizes from an input tensor, and then
-  // compute the size of the concatenation dimension if necessary.
-  if (srcEnc)
-    // Reuses sizes from an arbitrary input tensor is fine.
-    fillDimSizes(builder, loc, srcEnc, srcTp, srcs[0], sizes);
-  else
-    sizesFromSrc(builder, sizes, loc, srcs[0]);
-
-  // Sum up on the `dim` if the dimension is dynamic.
-  if (dstShape[dim] != ShapedType::kDynamic) {
-    // Faithfully take the static size.
-    sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
-  } else {
-    // Else, compute the shape dynamically.
-    for (size_t i = 1, sz = srcs.size(); i < sz; i++) {
-      auto srcTp = srcs[i].getType().cast<RankedTensorType>();
-      auto encSrc = getSparseTensorEncoding(srcTp);
-      Value srcSz =
-          createOrFoldDimCall(builder, loc, encSrc, srcTp, srcs[i], dim);
-      // Sum up all the sizes.
-      sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
-    }
+  // Now, compute the size of the concatenation dimension.
+  // If the `dstTp` has a static size, then faithfully take it.
+  if (const auto sz = constantSize(builder, loc, dstTp.getShape()[dim])) {
+    sizes[dim] = sz;
+    return;
+  }
+  // Otherwise, dynamically sum up the actual sizes of the sources.
+  for (const auto &src : llvm::drop_begin(srcs)) {
+    sizes[dim] = builder.create<arith::AddIOp>(
+        loc, sizes[dim],
+        RuntimeDimLvlBuilder(builder, loc, getRankedTensorType(src))
+            .lookupDimSize(src, dim));
   }
 }
@@ -235,8 +151,7 @@
   /// MLIR buffers as needed, and returning `this` for method chaining.
   /// This method does not set the action and pointer arguments, since
   /// those are handled by `genNewCall` instead.
-  NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes,
-                            ShapedType stp);
+  NewCallParams &genBuffers(DimLvlMapping dlm, ValueRange dimSizes);
 
   /// (Re)sets the C++ template type parameters, and returns `this`
   /// for method chaining. This is already done as part of `genBuffers`,
@@ -308,50 +223,46 @@
 // TODO: see the note at `_mlir_ciface_newSparseTensor` about how
 // the meaning of the various arguments (e.g., "sizes" vs "shapes")
 // is inconsistent between the different actions.
-NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
-                                         ValueRange dimSizes, ShapedType stp) {
-  const unsigned lvlRank = enc.getDimLevelType().size();
-  const unsigned dimRank = stp.getRank();
+NewCallParams &NewCallParams::genBuffers(DimLvlMapping dlm,
+                                         ValueRange dimSizes) {
+  const unsigned lvlRank = dlm.getLvlRank();
+  const unsigned dimRank = dlm.getDimRank();
   // Sparsity annotations.
-  params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, enc);
+  params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, dlm.getEncoding());
   // Dimension-sizes array of the enveloping tensor. Useful for either
   // verification of external data, or for construction of internal data.
   assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
   params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes);
   // The level-sizes array must be passed as well, since for arbitrary
   // dim2lvl mappings it cannot be trivially reconstructed at runtime.
-  // For now however, since we're still assuming permutations, we will
-  // initialize this parameter alongside the `dim2lvl` and `lvl2dim`
-  // parameters below. We preinitialize `lvlSizes` for code symmetry.
-  SmallVector<Value> lvlSizes(lvlRank);
+  SmallVector<Value> lvlSizes =
+      RuntimeDimLvlBuilder(builder, loc, dlm).computeLvlSizes(dimSizes);
+  params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
   // The dimension-to-level mapping and its inverse. We must preinitialize
   // `dim2lvl` so that the true branch below can perform random-access
   // `operator[]` assignment. We preinitialize `lvl2dim` for code symmetry.
   SmallVector<Value> dim2lvl(dimRank);
   SmallVector<Value> lvl2dim(lvlRank);
-  auto dimOrder = enc.getDimOrdering();
-  if (dimOrder) {
+  if (!dlm.isIdentity()) {
+    const auto dimOrder = dlm.getEncoding().getDimOrdering();
     assert(dimOrder.isPermutation());
     for (unsigned l = 0; l < lvlRank; l++) {
       // The `d`th source variable occurs in the `l`th result position.
       uint64_t d = dimOrder.getDimPosition(l);
       dim2lvl[d] = constantIndex(builder, loc, l);
       lvl2dim[l] = constantIndex(builder, loc, d);
-      lvlSizes[l] = dimSizes[d];
     }
   } else {
     assert(dimRank == lvlRank && "Rank mismatch");
-    for (unsigned i = 0; i < lvlRank; i++) {
+    for (unsigned i = 0; i < lvlRank; i++)
       dim2lvl[i] = lvl2dim[i] = constantIndex(builder, loc, i);
-      lvlSizes[i] = dimSizes[i];
-    }
   }
-  params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
   params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim);
-  params[kParamDim2Lvl] =
-      dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
+  params[kParamDim2Lvl] = dlm.isIdentity()
+                              ? params[kParamLvl2Dim]
+                              : allocaBuffer(builder, loc, dim2lvl);
   // Secondary and primary types encoding.
-  setTemplateTypes(enc, stp);
+  setTemplateTypes(dlm.getEncoding(), dlm.getShapedType());
   // Finally, make note that initialization is complete.
   assert(isInitialized() && "Initialization failed");
   // And return `this` for method chaining.
@@ -505,37 +416,36 @@
 genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
                         ConversionPatternRewriter &rewriter) {
   Location loc = op.getLoc();
-  auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
-  auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
-  auto encSrc = getSparseTensorEncoding(srcTp);
-  auto encDst = getSparseTensorEncoding(dstTp);
-  if (!encDst || !encSrc)
+  const auto srcTp = getRankedTensorType(op.getSrc());
+  const auto dstTp = getRankedTensorType(op.getResult());
+  RuntimeDimLvlBuilder srcDLM(rewriter, loc, srcTp);
+  RuntimeDimLvlBuilder dstDLM(rewriter, loc, dstTp);
+  if (srcDLM.isDense() || dstDLM.isDense())
     return failure();
-  Type elemTp = srcTp.getElementType();
-  assert(elemTp == dstTp.getElementType() &&
+  const Type elemTp = srcDLM.getElementType();
+  assert(elemTp == dstDLM.getElementType() &&
          "reshape should not change element type");
   // Start an iterator over the source tensor (in original index order).
-  const auto noPerm = encSrc.withoutOrdering();
-  SmallVector<Value> srcDimSizes =
-      getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
+  const SmallVector<Value> srcDimSizes =
+      srcDLM.lookupDimSizes(adaptor.getSrc());
   NewCallParams params(rewriter, loc);
-  Value iter = params.genBuffers(noPerm, srcDimSizes, srcTp)
-                   .genNewCall(Action::kToIterator, adaptor.getSrc());
+  const Value iter = params.genBuffers(srcDLM.withoutOrdering(), srcDimSizes)
+                         .genNewCall(Action::kToIterator, adaptor.getSrc());
   // Start a new COO for the destination tensor.
   SmallVector<Value> dstDimSizes;
-  if (dstTp.hasStaticShape())
+  if (dstDLM.hasStaticDimShape())
     // Static "shapes" are in fact "sizes".
-    fillDimShape(rewriter, loc, dstTp, dstDimSizes);
+    dstDLM.reflectDimShape(dstDimSizes);
   else
     genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes,
-                       dstTp.getShape(), op.getReassociationIndices());
-  Value coo = params.genBuffers(encDst, dstDimSizes, dstTp)
-                  .genNewCall(Action::kEmptyCOO);
-  Value dstPerm = params.getDim2LvlMap();
+                       dstDLM.getDimShape(), op.getReassociationIndices());
+  const Value coo =
+      params.genBuffers(dstDLM, dstDimSizes).genNewCall(Action::kEmptyCOO);
+  const Value dstPerm = params.getDim2LvlMap();
   // Construct a while loop over the iterator.
-  Type iTp = rewriter.getIndexType();
-  Value srcIdx = genAlloca(rewriter, loc, srcTp.getRank(), iTp);
-  Value dstIdx = genAlloca(rewriter, loc, dstTp.getRank(), iTp);
+  const Type iTp = rewriter.getIndexType();
+  Value srcIdx = genAlloca(rewriter, loc, srcDLM.getDimRank(), iTp);
+  Value dstIdx = genAlloca(rewriter, loc, dstDLM.getDimRank(), iTp);
   Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
   SmallVector<Value> noArgs;
   SmallVector<Type> noTypes;
@@ -573,22 +483,19 @@
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value t, RankedTensorType tensorTp,
                                  function_ref bodyBuilder) {
-  auto enc = getSparseTensorEncoding(tensorTp);
-  assert(enc && "Generating Sparse Tensor COO Loop on a Dense Tensor!");
-
-  unsigned rank = tensorTp.getRank();
-  Type elemTp = tensorTp.getElementType();
+  RuntimeDimLvlBuilder dlm(rewriter, loc, tensorTp);
+  assert(dlm.isSparse() &&
+         "Generating Sparse Tensor COO Loop on a Dense Tensor!");
   // Start an iterator over the tensor (in original index order).
-  const auto noPerm = enc.withoutOrdering();
-  SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
   Value iter = NewCallParams(rewriter, loc)
-                   .genBuffers(noPerm, dimSizes, tensorTp)
+                   .genBuffers(dlm.withoutOrdering(), dlm.lookupDimSizes(t))
                    .genNewCall(Action::kToIterator, t);
   // Construct a while loop over the iterator.
-  Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
-  Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+  Value srcIdx =
+      genAlloca(rewriter, loc, dlm.getDimRank(), rewriter.getIndexType());
+  Value elemPtr = genAllocaScalar(rewriter, loc, dlm.getElementType());
   SmallVector<Value> noArgs;
   SmallVector<Type> noTypes;
   auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
@@ -599,8 +506,8 @@
   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
   rewriter.setInsertionPointToStart(after);
 
-  bool hasDenseDim = llvm::any_of(
-      enc.getDimLevelType(), [](DimLevelType dlt) { return isDenseDLT(dlt); });
+  bool hasDenseDim =
+      llvm::any_of(dlm.getEncoding().getDimLevelType(), isDenseDLT);
   if (hasDenseDim) {
     Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
     Value isZero = genIsNonzero(rewriter, loc, elemV);
@@ -619,7 +526,7 @@
   rewriter.setInsertionPointAfter(whileOp);
   // Free memory for iterator.
-  genDelIteratorCall(rewriter, loc, elemTp, iter);
+  genDelIteratorCall(rewriter, loc, dlm.getElementType(), iter);
 }
 
 // Generate loop that iterates over a dense tensor.
@@ -686,10 +593,10 @@
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto stp = op.getSource().getType().cast<RankedTensorType>();
+    const auto rtp = getRankedTensorType(op.getSource());
+    RuntimeDimLvlBuilder dlm(rewriter, op->getLoc(), rtp);
     // Only rewrite sparse DimOp.
-    auto enc = getSparseTensorEncoding(stp);
-    if (!enc)
+    if (dlm.isDense())
      return failure();
     // Only rewrite DimOp with constant index.
    std::optional<int64_t> dim = op.getConstantIndex();
@@ -697,8 +604,7 @@
      return failure();
     // Generate the call.
     Value src = adaptor.getOperands()[0];
-    rewriter.replaceOp(
-        op, createOrFoldDimCall(rewriter, op->getLoc(), enc, stp, src, *dim));
+    rewriter.replaceOp(op, dlm.lookupDimSize(src, *dim));
     return success();
   }
 };
@@ -741,21 +647,21 @@
   matchAndRewrite(NewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto stp = op.getType().cast<RankedTensorType>();
-    auto enc = getSparseTensorEncoding(stp);
-    if (!enc)
+    const auto rtp = getRankedTensorType(op);
+    const RuntimeDimLvlBuilder dlm(rewriter, loc, rtp);
+    const auto enc = dlm.getEncoding();
+    if (dlm.isDense())
       return failure();
-    const unsigned dimRank = stp.getRank();
-    const unsigned lvlRank = enc.getDimLevelType().size();
+    const unsigned dimRank = dlm.getDimRank();
+    const unsigned lvlRank = dlm.getLvlRank();
     // Construct the dimShape.
-    const auto dimShape = stp.getShape();
-    SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
+    SmallVector<Value> dimShapeValues = dlm.reflectDimShape();
     Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
     // Allocate `SparseTensorReader` and perform all initial setup that
     // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
     Type opaqueTp = getOpaquePointerType(rewriter);
     Value valTp =
-        constantPrimaryTypeEncoding(rewriter, loc, stp.getElementType());
+        constantPrimaryTypeEncoding(rewriter, loc, rtp.getElementType());
     Value reader =
         createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
                        opaqueTp,
@@ -773,7 +679,7 @@
     //
     // FIXME: reduce redundancy vs `NewCallParams::genBuffers`.
     Value dimSizesBuffer;
-    if (!stp.hasStaticShape()) {
+    if (!dlm.hasStaticDimShape()) {
       Type indexTp = rewriter.getIndexType();
       auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
       dimSizesBuffer =
@@ -786,6 +692,7 @@
     Value dim2lvlBuffer;
     if (auto dimOrder = enc.getDimOrdering()) {
       assert(dimOrder.isPermutation() && "Got non-permutation");
+      const auto dimShape = dlm.getDimShape();
       // We preinitialize `dim2lvlValues` since we need random-access writing.
       // And we preinitialize the others for stylistic consistency.
       SmallVector<Value> lvlSizeValues(lvlRank);
@@ -799,7 +706,7 @@
         dim2lvlValues[d] = lvl;
         lvl2dimValues[l] = dim;
         lvlSizeValues[l] =
-            (dimShape[d] == ShapedType::kDynamic)
+            ShapedType::isDynamic(dimShape[d])
                 ? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
                 : dimShapeValues[d];
       }
@@ -848,27 +755,25 @@
       return rewriter.notifyMatchFailure(op,
                                          "sparse tensor copy not implemented");
     Location loc = op.getLoc();
-    RankedTensorType resType = op.getType();
-    auto enc = getSparseTensorEncoding(resType);
-    if (!enc)
+    const DimLvlMapping dlm(op.getType());
+    if (dlm.isDense())
       return failure();
     // Gather all dimension sizes as SSA values.
-    SmallVector<Value> sizes;
+    const Dimension dimRank = dlm.getDimRank();
+    SmallVector<Value> dimSizes;
+    dimSizes.reserve(dimRank);
     unsigned int operandCtr = 0;
-    for (int64_t i = 0; i < resType.getRank(); ++i) {
-      if (resType.isDynamicDim(i)) {
-        sizes.push_back(adaptor.getOperands()[operandCtr++]);
-      } else {
-        sizes.push_back(
-            rewriter.create<arith::ConstantIndexOp>(loc, op.getStaticSize(i)));
-      }
+    for (Dimension d = 0; d < dimRank; ++d) {
+      dimSizes.push_back(
+          dlm.isDynamicDim(d)
+              ? adaptor.getOperands()[operandCtr++]
+              : constantIndex(rewriter, loc, op.getStaticSize(d)));
     }
     // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
-    rewriter.replaceOp(op,
-                       NewCallParams(rewriter, loc)
-                           .genBuffers(enc, sizes, resType.cast<RankedTensorType>())
-                           .genNewCall(Action::kEmpty));
+    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
+                               .genBuffers(dlm, dimSizes)
+                               .genNewCall(Action::kEmpty));
     return success();
   }
 };
@@ -887,27 +792,31 @@
   LogicalResult
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Type resType = op.getType();
-    Type srcType = op.getSource().getType();
-    auto encDst = getSparseTensorEncoding(resType);
-    auto encSrc = getSparseTensorEncoding(srcType);
-    Value src = adaptor.getOperands()[0];
-    if (encDst && encSrc) {
+    const Location loc = op->getLoc();
+    const auto srcTp = getRankedTensorType(op.getSource());
+    const auto dstTp = getRankedTensorType(op);
+    RuntimeDimLvlBuilder srcDLM(rewriter, loc, srcTp);
+    DimLvlMapping dstDLM(dstTp);
+    if (srcDLM.isDense() && dstDLM.isDense())
+      return failure();
+
+    const Type elemTp = srcDLM.getElementType();
+    const Value src = adaptor.getOperands()[0];
+    if (srcDLM.isSparse() && dstDLM.isSparse()) {
+      const auto dstEnc = dstDLM.getEncoding();
+      const auto srcEnc = srcDLM.getEncoding();
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
-      if (encDst == encSrc) {
+      if (dstEnc == srcEnc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
         return success();
       }
       NewCallParams params(rewriter, loc);
-      ShapedType stp = srcType.cast<ShapedType>();
-      SmallVector<Value> dimSizes =
-          getDimSizes(rewriter, loc, encSrc, stp, src);
+      const SmallVector<Value> dimSizes = srcDLM.lookupDimSizes(src);
       bool useDirectConversion;
       switch (options.sparseToSparseStrategy) {
       case SparseToSparseConversionStrategy::kViaCOO:
@@ -915,37 +824,39 @@
         break;
       case SparseToSparseConversionStrategy::kDirect:
         useDirectConversion = true;
-        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
+        assert(canUseDirectConversion(dstEnc.getDimLevelType()) &&
                "Unsupported target for direct sparse-to-sparse conversion");
         break;
       case SparseToSparseConversionStrategy::kAuto:
-        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
+        useDirectConversion = canUseDirectConversion(dstEnc.getDimLevelType());
         break;
       }
       if (useDirectConversion) {
-        rewriter.replaceOp(op, params.genBuffers(encDst, dimSizes, stp)
-                                   .genNewCall(Action::kSparseToSparse, src));
+        rewriter.replaceOp(
+            op, params.genBuffers(DimLvlMapping(srcTp, dstEnc), dimSizes)
+                    .genNewCall(Action::kSparseToSparse, src));
       } else { // use via-COO conversion.
        // Set up encoding with right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
-        auto enc = SparseTensorEncodingAttr::get(
-            op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
-            encDst.getHigherOrdering(), encSrc.getPointerBitWidth(),
-            encSrc.getIndexBitWidth());
+        const auto mixedEnc = SparseTensorEncodingAttr::get(
+            op->getContext(), dstEnc.getDimLevelType(), dstEnc.getDimOrdering(),
+            dstEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
+            srcEnc.getIndexBitWidth());
        // TODO: This is the only place where `kToCOO` (or `kToIterator`)
        // is called with a non-identity permutation.
        // Is there any clean way to push the permutation over to the
        // `kFromCOO` side instead?
-        Value coo = params.genBuffers(enc, dimSizes, stp)
+        Value coo = params.genBuffers(DimLvlMapping(srcTp, mixedEnc), dimSizes)
                         .genNewCall(Action::kToCOO, src);
-        Value dst = params.setTemplateTypes(encDst, stp)
+        Value dst = params.setTemplateTypes(dstEnc, srcTp)
                         .genNewCall(Action::kFromCOO, coo);
-        genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
+        genDelCOOCall(rewriter, loc, elemTp, coo);
         rewriter.replaceOp(op, dst);
       }
       return success();
     }
-    if (!encDst && encSrc) {
+    if (srcDLM.isSparse() && dstDLM.isDense()) {
+      const auto srcEnc = srcDLM.getEncoding();
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = new SparseTensorIterator(src);
@@ -953,28 +864,25 @@
      //     dst[elem.indices] = elem.value;
      //   }
      //   delete iter;
-      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
-      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
-      unsigned rank = dstTensorTp.getRank();
-      Type elemTp = dstTensorTp.getElementType();
+      const unsigned dstDimRank = dstDLM.getDimRank();
      // Fabricate a no-permutation encoding for NewCallParams
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
-      encDst = SparseTensorEncodingAttr::get(
+      const auto dstEnc = SparseTensorEncodingAttr::get(
           op->getContext(),
-          SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
-          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-      SmallVector<Value> dimSizes =
-          getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
+          SmallVector<DimLevelType>(dstDimRank, DimLevelType::Dense),
+          AffineMap(), AffineMap(), srcEnc.getPointerBitWidth(),
+          srcEnc.getIndexBitWidth());
+      SmallVector<Value> dimSizes = srcDLM.lookupDimSizes(src);
       Value iter = NewCallParams(rewriter, loc)
-                       .genBuffers(encDst, dimSizes, dstTensorTp)
+                       .genBuffers(DimLvlMapping(dstTp, dstEnc), dimSizes)
                        .genNewCall(Action::kToIterator, src);
-      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+      Value ind = genAlloca(rewriter, loc, dstDimRank, rewriter.getIndexType());
       Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
       Block *insertionBlock = rewriter.getInsertionBlock();
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
-      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
+      Value dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
       SmallVector<Value> noArgs;
       SmallVector<Type> noTypes;
       auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
@@ -984,12 +892,12 @@
       rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
       Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
       rewriter.setInsertionPointToStart(after);
-      SmallVector<Value> ivs = loadIndices(rewriter, loc, rank, ind);
+      SmallVector<Value> ivs = loadIndices(rewriter, loc, dstDimRank, ind);
       insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, ivs);
       rewriter.create<scf::YieldOp>(loc);
       rewriter.setInsertionPointAfter(whileOp);
       genDelIteratorCall(rewriter, loc, elemTp, iter);
-      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
+      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
      // Deallocate the buffer.
      if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
        rewriter.setInsertionPoint(insertionBlock->getTerminator());
@@ -997,10 +905,7 @@
      }
      return success();
    }
-    if (!encDst && !encSrc) {
-      // dense => dense
-      return failure();
-    }
+    assert(srcDLM.isDense() && dstDLM.isSparse());
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
@@ -1027,30 +932,28 @@
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddElt().
-    ShapedType stp = resType.cast<ShapedType>();
-    unsigned rank = stp.getRank();
+    const unsigned dstDimRank = dstDLM.getDimRank();
     SmallVector<Value> sizes;
     sizesFromSrc(rewriter, sizes, loc, src);
     NewCallParams params(rewriter, loc);
-    Value coo =
-        params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO);
-    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
-    Value perm = params.getDim2LvlMap();
-    Type eltType = stp.getElementType();
-    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
+    Value coo = params.genBuffers(dstDLM, sizes).genNewCall(Action::kEmptyCOO);
+    Value ind = genAlloca(rewriter, loc, dstDimRank, rewriter.getIndexType());
+    const Value perm = params.getDim2LvlMap();
+    Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
     genDenseTensorOrSparseConstantIterLoop(
-        rewriter, loc, src, rank,
-        [&](OpBuilder &builder, Location loc, Value val, ValueRange indices) {
-          for (unsigned i = 0; i < rank; i++) {
-            Value idx = constantIndex(builder, loc, i);
-            builder.create<memref::StoreOp>(loc, indices[i], ind, idx);
+        rewriter, loc, src, dstDimRank,
+        [&](OpBuilder &builder, Location loc, Value val, ValueRange dimInd) {
+          // TODO: rewrite this to use `storeIndices`
+          for (unsigned d = 0; d < dstDimRank; d++) {
+            Value dim = constantIndex(builder, loc, d);
+            builder.create<memref::StoreOp>(loc, dimInd[d], ind, dim);
          }
          builder.create<memref::StoreOp>(loc, val, elemPtr);
-          genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm);
+          genAddEltCall(builder, loc, elemTp, coo, elemPtr, ind, perm);
        });
     // Final call to construct sparse tensor storage.
     Value dst = params.genNewCall(Action::kFromCOO, coo);
-    genDelCOOCall(rewriter, loc, eltType, coo);
+    genDelCOOCall(rewriter, loc, elemTp, coo);
     rewriter.replaceOp(op, dst);
     return success();
   }
@@ -1192,7 +1095,7 @@
    // index order. All values are passed by reference through stack
    // allocated memrefs.
     Location loc = op->getLoc();
-    auto tp = op.getTensor().getType().cast<RankedTensorType>();
+    const auto tp = getRankedTensorType(op.getTensor());
     auto elemTp = tp.getElementType();
     unsigned rank = tp.getRank();
     auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -1217,19 +1120,16 @@
   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    RankedTensorType srcType =
-        op.getTensor().getType().cast<RankedTensorType>();
+    const auto srcType = getRankedTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
     Type idxType = rewriter.getIndexType();
     // All initialization should be done on entry of the loop nest.
     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
     // Get the cardinality of valid coordinates for the innermost level.
-    auto srcEnc = getSparseTensorEncoding(srcType);
-    unsigned lvlRank =
-        srcEnc ? srcEnc.getDimLevelType().size() : srcType.getRank();
-    Value sz = createOrFoldLvlCall(rewriter, loc, srcEnc, srcType,
-                                   adaptor.getTensor(), lvlRank - 1);
+    RuntimeDimLvlBuilder srcDLM(rewriter, loc, srcType);
+    Value sz =
+        srcDLM.lookupLvlSize(adaptor.getTensor(), srcDLM.getLvlRank() - 1);
     // Allocate temporary buffers for values, filled-switch, and indices.
     // We do not use stack buffers for this, since the expanded size may
     // be rather large (as it envelops a single expanded dense dimension).
@@ -1272,7 +1172,7 @@
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
     Value tensor = adaptor.getTensor();
-    auto tp = op.getTensor().getType().cast<RankedTensorType>();
+    const auto tp = getRankedTensorType(op.getTensor());
     Type elemTp = tp.getElementType();
     unsigned rank = tp.getRank();
     auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -1326,7 +1226,7 @@
    //     a[ adjustForOffset(elem.indices) ] = elem.value
    //   return a
     Location loc = op.getLoc();
-    auto dstTp = op.getType().cast<RankedTensorType>();
+    const auto dstTp = getRankedTensorType(op);
     auto encDst = getSparseTensorEncoding(dstTp);
     Type elemTp = dstTp.getElementType();
     uint64_t concatDim = op.getDimension().getZExtValue();
@@ -1343,15 +1243,16 @@
     SmallVector<Value> sizes;
     NewCallParams params(rewriter, loc);
-    concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(),
-                          concatDim);
+    concatDimSizesFromInputs(rewriter, loc, dstTp, op.getInputs(), concatDim,
+                             sizes);
     bool allDense = false;
     Value dstTensor;
     if (encDst) {
       allDense = encDst.isAllDense();
      // Start a new COO or an initialized annotated all dense sparse tensor.
-      dst = params.genBuffers(encDst, sizes, dstTp)
+      RuntimeDimLvlBuilder dstDLM(rewriter, loc, dstTp, encDst);
+      dst = params.genBuffers(dstDLM, sizes)
                .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);
       dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
       if (allDense) {
@@ -1381,9 +1282,9 @@
     for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
       Value orignalOp = std::get<0>(it);  // Input (with encoding) from Op
       Value adaptedOp = std::get<1>(it);  // Input (type converted) from adaptor
-      RankedTensorType srcTp = orignalOp.getType().cast<RankedTensorType>();
-      auto encSrc = getSparseTensorEncoding(srcTp);
-      if (encSrc) {
+      const auto srcTp = getRankedTensorType(orignalOp);
+      RuntimeDimLvlBuilder srcDLM(rewriter, loc, srcTp);
+      if (srcDLM.isSparse()) {
         genSparseCOOIterationLoop(
             rewriter, loc, adaptedOp, srcTp,
             [&](OpBuilder &builder, Location loc, Value idx,
@@ -1432,8 +1333,7 @@
        }
        // Accumulate offset.
        // TODO: avoid calling sparseDimSize multiple times by caching the result!
-        Value curDim = createOrFoldDimCall(rewriter, loc, encSrc, srcTp,
-                                           adaptedOp, concatDim);
+        Value curDim = srcDLM.lookupDimSize(adaptedOp, concatDim);
         offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
       }
@@ -1462,30 +1362,29 @@
   LogicalResult
   matchAndRewrite(OutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    const Location loc = op->getLoc();
+    const auto srcTp = getRankedTensorType(op.getTensor());
+    RuntimeDimLvlBuilder srcDLM(rewriter, loc, srcTp);
    // Convert to default permuted COO.
-    Value src = adaptor.getOperands()[0];
-    auto encSrc = getSparseTensorEncoding(srcType);
-    SmallVector<Value> dimSizes =
-        getDimSizes(rewriter, loc, encSrc, srcType, src);
-    const auto enc = encSrc.withoutOrdering();
-    Value coo = NewCallParams(rewriter, loc)
-                    .genBuffers(enc, dimSizes, srcType)
-                    .genNewCall(Action::kToCOO, src);
+    Value coo;
+    {
+      const Value src = adaptor.getOperands()[0];
+      const SmallVector<Value> dimSizes = srcDLM.lookupDimSizes(src);
+      coo = NewCallParams(rewriter, loc)
+                .genBuffers(srcDLM.withoutOrdering(), dimSizes)
+                .genNewCall(Action::kToCOO, src);
+    }
    // Then output the tensor to external file with indices in the externally
    // visible lexicographic index order. A sort is required if the source was
    // not in that order yet (note that the sort can be dropped altogether if
    // external format does not care about the order at all, but here we assume
    // it does).
-    Value sort = constantI1(rewriter, loc,
-                            encSrc.getDimOrdering() &&
-                                !encSrc.getDimOrdering().isIdentity());
+    const Value sort = constantI1(rewriter, loc, !srcDLM.isIdentity());
     SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
-    Type eltType = srcType.getElementType();
-    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
+    const Type elemTp = srcDLM.getElementType();
+    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
     createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
-    genDelCOOCall(rewriter, loc, eltType, coo);
+    genDelCOOCall(rewriter, loc, elemTp, coo);
     rewriter.eraseOp(op);
     return success();
   }
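
Reviewer note: the three snippets below are editorial sketches, not part of the patch. First, a minimal example of how the new classes are meant to compose inside a conversion pattern. `MyOp` and its `getTensor()` accessor are hypothetical stand-ins; `RuntimeDimLvlBuilder`, `NewCallParams`, `Action`, and `getRankedTensorType` are the entities defined in the diff above.

// Hypothetical usage sketch (assumes a ConversionPattern over `MyOp`).
LogicalResult
matchAndRewrite(MyOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  const Location loc = op.getLoc();
  RuntimeDimLvlBuilder dlm(rewriter, loc, getRankedTensorType(op.getTensor()));
  if (dlm.isDense()) // This lowering only applies to sparse tensors.
    return failure();
  // Constants for static dimension-sizes; `sparseDimSize` calls otherwise.
  const SmallVector<Value> dimSizes = dlm.lookupDimSizes(adaptor.getTensor());
  // `genBuffers` derives lvlSizes, dim2lvl, and lvl2dim from the same mapping.
  Value result = NewCallParams(rewriter, loc)
                     .genBuffers(dlm, dimSizes)
                     .genNewCall(Action::kEmpty);
  rewriter.replaceOp(op, result);
  return success();
}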
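
Second, the memo-table discipline in `DimLvlMapping` (zero as the "absent" sentinel behind logically-const getters) restated in isolation. This is a standalone sketch with standard-library containers standing in for `AffineMap` and `SmallVector`, not code from the patch.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Standalone sketch of the level-shape memoization scheme. Entry l of
// `lvlToDim` names the dimension whose size becomes the size of level l
// (what `dim2lvl.getDimPosition(l)` computes in the patch).
class LvlShapeMemo {
public:
  LvlShapeMemo(std::vector<int64_t> dimShape, std::vector<unsigned> lvlToDim)
      : dimShape(std::move(dimShape)), lvlToDim(std::move(lvlToDim)),
        lvlShape(this->lvlToDim.size(), kEmptyMemo) {}

  // Logically const: fills in the memo table on first use.
  int64_t getLvlSize(unsigned l) const {
    assert(l < lvlShape.size());
    if (lvlShape[l] == kEmptyMemo)
      lvlShape[l] = dimShape[lvlToDim[l]];
    assert(lvlShape[l] != kEmptyMemo && "zero level-sizes are disallowed");
    return lvlShape[l];
  }

private:
  // Zero is an illegal level-size, so it safely marks "not yet memoized".
  static constexpr int64_t kEmptyMemo = 0;
  const std::vector<int64_t> dimShape;
  const std::vector<unsigned> lvlToDim;
  mutable std::vector<int64_t> lvlShape; // the memo table
};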
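
Finally, the `dim2lvl`/`lvl2dim` buffers that `genBuffers` populates are simply a permutation and its inverse. A sketch with plain integers in place of the generated `constantIndex` values:

#include <cstdint>
#include <vector>

// Sketch of the inverse-permutation construction in `genBuffers`: for a
// permutation where level l is stored from dimension lvl2dim[l] (i.e.
// dimOrder.getDimPosition(l) in the patch), the inverse map sends each
// dimension to the level that stores it.
std::vector<uint64_t> invertPermutation(const std::vector<uint64_t> &lvl2dim) {
  std::vector<uint64_t> dim2lvl(lvl2dim.size());
  for (uint64_t l = 0; l < lvl2dim.size(); ++l)
    dim2lvl[lvl2dim[l]] = l;
  return dim2lvl;
}

// For example, lvl2dim = {2, 0, 1} gives dim2lvl = {1, 2, 0}: dimension 0
// is stored at level 1, dimension 1 at level 2, and dimension 2 at level 0.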