diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -75,6 +75,21 @@
   return isSingletonDLT(getDimLevelType(type, d));
 }
 
+/// Convenience function to test for dense dimension (0 <= d < rank).
+inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) {
+  return isDenseDLT(getDimLevelType(enc, d));
+}
+
+/// Convenience function to test for compressed dimension (0 <= d < rank).
+inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) {
+  return isCompressedDLT(getDimLevelType(enc, d));
+}
+
+/// Convenience function to test for singleton dimension (0 <= d < rank).
+inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) {
+  return isSingletonDLT(getDimLevelType(enc, d));
+}
+
 //
 // Dimension level properties.
 //
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -311,8 +311,282 @@
 }
 
 //===----------------------------------------------------------------------===//
-// SparseTensorLoopEmiter class, manages sparse tensors and helps to generate
-// loop structure to (co)-iterate sparse tensors.
+// SparseTensorDescriptor and helpers that manage the sparse tensor memory
+// layout scheme.
+//
+// The sparse tensor storage scheme for a rank-dimensional tensor is organized
+// as a single compound type with the following fields. Note that every
+// memref with ? size actually behaves as a "vector", i.e. the stored
+// size is the capacity and the used size resides in the memSizes array.
+//
+// struct {
+//   memref<rank x index> dimSizes     ; size in each dimension
+//   memref<n x index> memSizes        ; sizes of ptrs/inds/values
+//   ; per-dimension d:
+//   ;  if dense:
+//      <nothing>
+//   ;  if compressed:
+//      memref<? x ptr>  pointers-d    ; pointers for sparse dim d
+//      memref<? x idx>  indices-d     ; indices for sparse dim d
+//   ;  if singleton:
+//      memref<? x idx>  indices-d     ; indices for singleton dim d
+//   memref<? x eltType> values        ; values
+// };
+//
+// The dimSizes and memSizes arrays occupy fixed leading field positions,
+// followed by the per-dimension data fields and the trailing values array.
+//
+//===----------------------------------------------------------------------===//
+enum class SparseTensorFieldKind {
+  DimSizes,
+  MemSizes,
+  PtrMemRef,
+  IdxMemRef,
+  ValMemRef
+};
+
+constexpr uint64_t dimSizesIdx = 0;
+constexpr uint64_t memSizesIdx = dimSizesIdx + 1;
+constexpr uint64_t dataFieldIdx = memSizesIdx + 1;
+
+/// For each field that will be allocated for the given sparse tensor
+/// encoding, calls the callback with the corresponding field index, field
+/// kind, dimension (for sparse tensor level memrefs), and dimension level
+/// type. The field index always starts with zero and increments by one
+/// between two callback invocations.
+/// Ideally, all other methods should rely on this function to query the
+/// sparse tensor fields instead of relying on ad-hoc index computations.
+void foreachFieldInSparseTensor(
+    SparseTensorEncodingAttr,
+    llvm::function_ref<bool(unsigned /*fieldIdx*/,
+                            SparseTensorFieldKind /*fieldKind*/,
+                            unsigned /*dim*/, DimLevelType /*dlt*/)>);
+
+/// Same as above, except that it also builds the Type for the corresponding
+/// field.
+void foreachFieldAndTypeInSparseTensor(
+    RankedTensorType,
+    llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
+                            SparseTensorFieldKind /*fieldKind*/,
+                            unsigned /*dim*/, DimLevelType /*dlt*/)>);
+
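+// As an illustration (not part of the API), for a 2-d CSR-like tensor with
+// dimension level types (dense, compressed), foreachFieldInSparseTensor
+// visits, in this order:
+//   fieldIdx 0: DimSizes            memref<2 x index>   dimSizes
+//   fieldIdx 1: MemSizes            memref<3 x index>   memSizes
+//   fieldIdx 2: PtrMemRef (dim 1)   memref<? x ptr>     pointers-1
+//   fieldIdx 3: IdxMemRef (dim 1)   memref<? x idx>     indices-1
+//   fieldIdx 4: ValMemRef           memref<? x eltType> values
+// so that getNumFieldsFromEncoding returns 5 and
+// getNumDataFieldsFromEncoding returns 3 for this encoding.
+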
+/// Gets the total number of fields for the given sparse tensor encoding.
+unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+/// Gets the total number of data fields (index arrays, pointer arrays, and a
+/// value array) for the given sparse tensor encoding.
+unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+/// Gets the index of the field in memSizes (only valid for data fields).
+inline unsigned getFieldMemSizesIndex(unsigned fid) {
+  assert(fid >= dataFieldIdx);
+  return fid - dataFieldIdx;
+}
+
+/// A helper class around an array of values that correspond to a sparse
+/// tensor. It provides a set of meaningful APIs to query and update a
+/// particular field in a consistent way.
+/// Users should not make assumptions about how a sparse tensor is laid out
+/// but instead rely on this class to access the right value for the right
+/// field.
+template <bool mut>
+class SparseTensorDescriptorImpl {
+private:
+  template <bool v>
+  struct ArrayStorage;
+
+  template <>
+  struct ArrayStorage<false> {
+    using ValueArray = ValueRange;
+  };
+
+  template <>
+  struct ArrayStorage<true> {
+    using ValueArray = SmallVectorImpl<Value> &;
+  };
+
+  // Uses ValueRange for immutable descriptors; uses SmallVectorImpl<Value> &
+  // for mutable descriptors.
+  // Using SmallVector for the mutable descriptor allows users to reuse it as
+  // a temporary buffer to append values for some special cases, though users
+  // are responsible for restoring the buffer to a legal state after such use.
+  // It is probably not a clean way, but it is the most efficient way to avoid
+  // copying the fields into another SmallVector. If a cleaner way is wanted,
+  // we should change it to MutableArrayRef instead.
+  using Storage = typename ArrayStorage<mut>::ValueArray;
+
+public:
+  SparseTensorDescriptorImpl(Type tp, Storage fields)
+      : rType(tp.cast<RankedTensorType>()), fields(fields) {
+    assert(getSparseTensorEncoding(tp) &&
+           getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+               fields.size());
+    // We should make sure the class is trivially copyable (and should be small
+    // enough) such that we can pass it by value.
+    static_assert(
+        std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+  }
+
+  // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
+  // SparseTensorDescriptor.
+  template <typename T = SparseTensorDescriptorImpl<true>>
+  /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
+      : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
+
+  ///
+  /// Getters: get the field index for a required field.
+  ///
+
+  unsigned getPtrMemRefIndex(unsigned ptrDim) const {
+    return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef);
+  }
+
+  unsigned getIdxMemRefIndex(unsigned idxDim) const {
+    return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef);
+  }
+
+  unsigned getDataFieldIndex(unsigned dim, SparseTensorFieldKind kind) const {
+    if (kind == SparseTensorFieldKind::ValMemRef)
+      return getValMemRefIndex();
+    return getFieldIndex(dim, kind);
+  }
+
+  unsigned getValMemRefIndex() const { return fields.size() - 1; }
+
+  unsigned getPtrMemSizesIndex(unsigned dim) const {
+    return getPtrMemRefIndex(dim) - dataFieldIdx;
+  }
+
+  unsigned getIdxMemSizesIndex(unsigned dim) const {
+    return getIdxMemRefIndex(dim) - dataFieldIdx;
+  }
+
+  unsigned getValMemSizesIndex() const {
+    return getValMemRefIndex() - dataFieldIdx;
+  }
+
+  unsigned getFieldMemSizesIndex(unsigned dim,
+                                 SparseTensorFieldKind kind) const {
+    return getFieldMemSizesIndex(getDataFieldIndex(dim, kind));
+  }
+
+  unsigned getNumFields() const { return fields.size(); }
+
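+  // For example, for a 2-d tensor with dimension level types
+  // (dense, compressed), the field layout is {dimSizes, memSizes,
+  // pointers-1, indices-1, values}, so (illustrative values only)
+  // getPtrMemRefIndex(1) == 2, getIdxMemRefIndex(1) == 3, and
+  // getValMemRefIndex() == 4, while the corresponding memSizes slots are
+  // getPtrMemSizesIndex(1) == 0, getIdxMemSizesIndex(1) == 1, and
+  // getValMemSizesIndex() == 2.
+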
+  ///
+  /// Getters: get the value for a required field.
+  ///
+
+  Value getDimSizesMemRef() const { return fields[dimSizesIdx]; }
+  Value getMemSizesMemRef() const { return fields[memSizesIdx]; }
+
+  Value getPtrMemRef(unsigned ptrDim) const {
+    return fields[getPtrMemRefIndex(ptrDim)];
+  }
+
+  Value getIdxMemRef(unsigned idxDim) const {
+    return fields[getIdxMemRefIndex(idxDim)];
+  }
+
+  Value getValMemRef() const { return fields[getValMemRefIndex()]; }
+
+  Value getField(unsigned fid) const {
+    assert(fid < fields.size());
+    return fields[fid];
+  }
+
+  ///
+  /// Setters: update the value for a required field (only enabled for
+  /// MutSparseTensorDescriptor).
+  ///
+
+  template <typename T = Value>
+  void setDimSizesMemRef(std::enable_if_t<mut, T> v) {
+    fields[dimSizesIdx] = v;
+  }
+
+  template <typename T = Value>
+  void setMemSizesMemRef(std::enable_if_t<mut, T> v) {
+    fields[memSizesIdx] = v;
+  }
+
+  template <typename T = Value>
+  void setPtrMemRef(unsigned ptrDim, std::enable_if_t<mut, T> v) {
+    fields[getPtrMemRefIndex(ptrDim)] = v;
+  }
+
+  template <typename T = Value>
+  void setIdxMemRef(unsigned idxDim, std::enable_if_t<mut, T> v) {
+    fields[getIdxMemRefIndex(idxDim)] = v;
+  }
+
+  template <typename T = Value>
+  void setLvlMemRef(unsigned dim, SparseTensorFieldKind kind,
+                    std::enable_if_t<mut, T> v) {
+    fields[getDataFieldIndex(dim, kind)] = v;
+  }
+
+  template <typename T = Value>
+  void setValMemRef(std::enable_if_t<mut, T> v) {
+    fields[getValMemRefIndex()] = v;
+  }
+
+  template <typename T = Value>
+  void setField(unsigned fid, std::enable_if_t<mut, T> v) {
+    assert(fid < fields.size());
+    fields[fid] = v;
+  }
+
+  RankedTensorType getTensorType() const { return rType; }
+  Storage getFields() const { return fields; }
+
+  Type getElementType(unsigned fidx) const {
+    return fields[fidx].getType().template cast<ShapedType>().getElementType();
+  }
+
+  // TODO: a better place for these functions would be in
+  // SparseTensorEncodingAttr.
+  Type getPtrElementType() const {
+    auto *ctx = rType.getContext();
+    unsigned ptrWidth = getSparseTensorEncoding(rType).getPointerBitWidth();
+    Type indexType = IndexType::get(ctx);
+    return ptrWidth ? IntegerType::get(ctx, ptrWidth) : indexType;
+  }
+
+  Type getIdxElementType() const {
+    auto *ctx = rType.getContext();
+    unsigned idxWidth = getSparseTensorEncoding(rType).getIndexBitWidth();
+    Type indexType = IndexType::get(ctx);
+    return idxWidth ? IntegerType::get(ctx, idxWidth) : indexType;
+  }
+
+private:
+  unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const {
+    unsigned fieldIdx = -1u;
+    foreachFieldInSparseTensor(
+        getSparseTensorEncoding(rType),
+        [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
+                               unsigned fDim, DimLevelType dlt) -> bool {
+          if (fDim == dim && kind == fKind) {
+            fieldIdx = fIdx;
+            // Returns false to break the iteration.
+            return false;
+          }
+          return true;
+        });
+    assert(fieldIdx != -1u);
+    return fieldIdx;
+  }
+
+  RankedTensorType rType;
+  Storage fields;
+};
+
+using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
+using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
+
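+// A minimal usage sketch, mirroring the patterns in SparseTensorCodegen.cpp
+// (getDescriptorFromTensorTuple is the helper added there by this patch;
+// tensor, tuple, rtp, d, fidx, and newField are placeholders):
+//
+//   SparseTensorDescriptor desc = getDescriptorFromTensorTuple(tensor);
+//   Value vals = desc.getValMemRef();   // trailing values memref
+//   Value ptrs = desc.getPtrMemRef(d);  // pointers of compressed dim d
+//
+//   SmallVector<Value> fields(tuple.getInputs());
+//   MutSparseTensorDescriptor mutDesc(rtp, fields);
+//   mutDesc.setField(fidx, newField);   // writes through to `fields`
+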
+//===----------------------------------------------------------------------===//
+// SparseTensorLoopEmitter class, manages sparse tensors and helps to
+// generate loop structure to (co)-iterate sparse tensors.
 //
 // An example usage:
 // To generate the following loops over T1 and T2
@@ -345,15 +619,15 @@
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
 
-  /// Constructor: take an array of tensors inputs, on which the generated loops
-  /// will iterate on. The index of the tensor in the array is also the
+  /// Constructor: take an array of tensor inputs, over which the generated
+  /// loops will iterate. The index of the tensor in the array is also the
   /// tensor id (tid) used in related functions.
   /// If isSparseOut is set, loop emitter assume that the sparse output tensor
   /// is empty, and will always generate loops on it based on the dim sizes.
   /// An optional array could be provided (by sparsification) to indicate the
   /// loop id sequence that will be generated. It is used to establish the
-  /// mapping between affineDimExpr to the corresponding loop index in the loop
-  /// stack that are maintained by the loop emitter.
+  /// mapping between an affineDimExpr and the corresponding loop index in the
+  /// loop stack that is maintained by the loop emitter.
   explicit SparseTensorLoopEmitter(ValueRange tensors,
                                    StringAttr loopTag = nullptr,
                                    bool hasOutput = false,
@@ -368,8 +642,8 @@
   /// Generates a list of operations to compute the affine expression.
   Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
 
-  /// Enters a new loop sequence, the loops within the same sequence starts from
-  /// the break points of previous loop instead of starting over from 0.
+  /// Enters a new loop sequence; the loops within the same sequence start
+  /// from the break points of the previous loop instead of starting from 0.
   /// e.g.,
   /// {
   ///   // loop sequence start.
@@ -524,10 +798,10 @@
   ///        scf.reduce.return %val
   ///      }
   ///    }
-  /// NOTE: only one instruction will be moved into reduce block, transformation
-  /// will fail if multiple instructions are used to compute the reduction
-  /// value.
-  /// Return %ret to user, while %val is provided by users (`reduc`).
+  /// NOTE: only one instruction will be moved into the reduce block; the
+  /// transformation will fail if multiple instructions are used to compute
+  /// the reduction value. %ret is returned to the user, while %val is
+  /// provided by the user (`reduc`).
   void exitForLoop(RewriterBase &rewriter, Location loc,
                    MutableArrayRef<Value> reduc);
 
@@ -535,9 +809,9 @@
   void exitCoIterationLoop(OpBuilder &builder, Location loc,
                            MutableArrayRef<Value> reduc);
 
-  /// A optional string attribute that should be attached to the loop generated
-  /// by loop emitter, it might help following passes to identify loops that
-  /// operates on sparse tensors more easily.
+  /// An optional string attribute that should be attached to the loop
+  /// generated by the loop emitter; it might help subsequent passes to
+  /// identify loops that operate on sparse tensors more easily.
   StringAttr loopTag;
   /// Whether the loop emitter needs to treat the last tensor as the output
   /// tensor.
@@ -556,7 +830,8 @@
   std::vector<std::vector<Value>> idxBuffer; // to_indices
   std::vector<Value> valBuffer;              // to_value
 
-  // Loop Stack, stores the information of all the nested loops that are alive.
+  // Loop Stack, stores the information of all the nested loops that are
+  // alive.
std::vector loopStack; // Loop Sequence Stack, stores the unversial index for the current loop diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -90,6 +90,118 @@ return val; } +void sparse_tensor::foreachFieldInSparseTensor( + const SparseTensorEncodingAttr enc, + llvm::function_ref + callback) { + assert(enc); + +#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ + if (!(callback(idx, kind, dim, dlt))) \ + return; + + RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u, + DimLevelType::Undef); + RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u, + DimLevelType::Undef); + + static_assert(dataFieldIdx == memSizesIdx + 1); + unsigned fieldIdx = dataFieldIdx; + // Per-dimension storage. + for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { + // Dimension level types apply in order to the reordered dimension. + // As a result, the compound type can be constructed directly in the given + // order. + auto dlt = getDimLevelType(enc, r); + if (isCompressedDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else if (isSingletonDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else { + assert(isDenseDLT(dlt)); // no fields + } + } + // The values array. + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, + DimLevelType::Undef); + +#undef RETURN_ON_FALSE +} + +void sparse_tensor::foreachFieldAndTypeInSparseTensor( + RankedTensorType rType, + llvm::function_ref + callback) { + auto enc = getSparseTensorEncoding(rType); + assert(enc); + // Construct the basic types. + auto *context = rType.getContext(); + unsigned idxWidth = enc.getIndexBitWidth(); + unsigned ptrWidth = enc.getPointerBitWidth(); + Type indexType = IndexType::get(context); + Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType; + Type ptrType = ptrWidth ? 
IntegerType::get(context, ptrWidth) : indexType; + Type eltType = rType.getElementType(); + unsigned rank = rType.getShape().size(); + // memref dimSizes + Type dimSizeType = MemRefType::get({rank}, indexType); + // memref memSizes + Type memSizeType = + MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); + // memref pointers + Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); + // memref indices + Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); + // memref values + Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); + + foreachFieldInSparseTensor( + enc, + [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType, + callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, + unsigned dim, DimLevelType dlt) -> bool { + switch (fieldKind) { + case SparseTensorFieldKind::DimSizes: + return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::MemSizes: + return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::PtrMemRef: + return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::IdxMemRef: + return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::ValMemRef: + return callback(valMemType, fieldIdx, fieldKind, dim, dlt); + }; + }); +} + +unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; + foreachFieldInSparseTensor(enc, + [&numFields](unsigned, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + numFields++; + return true; + }); + return numFields; +} + +unsigned +sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; // one value memref + foreachFieldInSparseTensor(enc, + [&numFields](unsigned fidx, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + if (fidx >= dataFieldIdx) + numFields++; + return true; + }); + assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx); + return numFields; +} //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -30,11 +30,6 @@ using namespace mlir::sparse_tensor; namespace { - -static constexpr uint64_t dimSizesIdx = 0; -static constexpr uint64_t memSizesIdx = 1; -static constexpr uint64_t fieldsIdx = 2; - //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// @@ -44,6 +39,11 @@ return llvm::cast(tensor.getDefiningOp()); } +static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { + auto tuple = getTuple(tensor); + return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); +} + /// Packs the given values as a "tuple" value. 
 static Value genTuple(OpBuilder &builder, Location loc, Type tp,
                       ValueRange values) {
@@ -51,6 +51,14 @@
       .getResult(0);
 }
 
+static Value genTuple(OpBuilder &builder, Location loc,
+                      SparseTensorDescriptor desc) {
+  return builder
+      .create<UnrealizedConversionCastOp>(loc, desc.getTensorType(),
+                                          desc.getFields())
+      .getResult(0);
+}
+
 /// Flatten a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
@@ -96,7 +104,7 @@
 /// Creates a straightforward counting for-loop.
 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
-                            SmallVectorImpl<Value> &fields,
+                            MutableArrayRef<Value> fields,
                             Value lower = Value()) {
   Type indexType = builder.getIndexType();
   if (!lower)
@@ -113,81 +121,58 @@
 /// original dimension 'dim'. Returns None if no sparse encoding is
 /// attached to the given tensor type.
 static Optional<Value> sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                                           RankedTensorType tensorTp,
-                                           Value adaptedValue, unsigned dim) {
-  auto enc = getSparseTensorEncoding(tensorTp);
-  if (!enc)
-    return llvm::None;
-
+                                           SparseTensorDescriptor desc,
+                                           unsigned dim) {
+  RankedTensorType rtp = desc.getTensorType();
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
-  auto shape = tensorTp.getShape();
+  auto shape = rtp.getShape();
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);
 
   // Any other query can consult the dimSizes array at field DimSizesIdx,
   // accounting for the reordering applied to the sparse storage.
-  auto tuple = getTuple(adaptedValue);
-  Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim));
-  return builder
-      .create<memref::LoadOp>(loc, tuple.getInputs()[dimSizesIdx], idx)
+  Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim));
+  return builder.create<memref::LoadOp>(loc, desc.getDimSizesMemRef(), idx)
      .getResult();
 }
 
 // Gets the dimension size at the given stored dimension 'd', either as a
 // constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc, RankedTensorType rtp,
-                      SmallVectorImpl<Value> &fields, unsigned d) {
+Value sizeAtStoredDim(OpBuilder &builder, Location loc,
+                      SparseTensorDescriptor desc, unsigned d) {
+  RankedTensorType rtp = desc.getTensorType();
   unsigned dim = toOrigDim(rtp, d);
   auto shape = rtp.getShape();
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);
-  return genLoad(builder, loc, fields[dimSizesIdx],
-                 constantIndex(builder, loc, d));
-}
-/// Translates field index to memSizes index.
-static unsigned getMemSizesIndex(unsigned field) {
-  assert(fieldsIdx <= field);
-  return field - fieldsIdx;
+  return genLoad(builder, loc, desc.getDimSizesMemRef(),
+                 constantIndex(builder, loc, d));
 }
 
-/// Creates a pushback op for given field and updates the fields array
-/// accordingly. This operation also updates the memSizes contents.
static void createPushback(OpBuilder &builder, Location loc, - SmallVectorImpl &fields, unsigned field, + MutSparseTensorDescriptor desc, unsigned fidx, Value value, Value repeat = Value()) { - assert(fieldsIdx <= field && field < fields.size()); - Type etp = fields[field].getType().cast().getElementType(); - fields[field] = builder.create( - loc, fields[field].getType(), fields[memSizesIdx], fields[field], - toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)), + Type etp = desc.getElementType(fidx); + Value field = desc.getField(fidx); + Value newField = builder.create( + loc, field.getType(), desc.getMemSizesMemRef(), field, + toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)), repeat); -} - -/// Returns field index of sparse tensor type for pointers/indices, when set. -static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) { - assert(getSparseTensorEncoding(type)); - RankedTensorType rType = type.cast(); - unsigned field = fieldsIdx; // start past header - for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) { - if (isCompressedDim(rType, r)) { - if (r == ptrDim) - return field; - field++; - if (r == idxDim) - return field; - field++; - } else if (isSingletonDim(rType, r)) { - if (r == idxDim) - return field; - field++; - } else { - assert(isDenseDim(rType, r)); // no fields - } - } - assert(ptrDim == -1u && idxDim == -1u); - return field + 1; // return values field index + desc.setField(fidx, newField); } /// Maps a sparse tensor type to the appropriate compounded buffers. @@ -196,66 +181,24 @@ auto enc = getSparseTensorEncoding(type); if (!enc) return llvm::None; - // Construct the basic types. - auto *context = type.getContext(); - unsigned idxWidth = enc.getIndexBitWidth(); - unsigned ptrWidth = enc.getPointerBitWidth(); + RankedTensorType rType = type.cast(); - Type indexType = IndexType::get(context); - Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType; - Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType; - Type eltType = rType.getElementType(); - // - // Sparse tensor storage scheme for rank-dimensional tensor is organized - // as a single compound type with the following fields. Note that every - // memref with ? size actually behaves as a "vector", i.e. the stored - // size is the capacity and the used size resides in the memSizes array. - // - // struct { - // memref dimSizes ; size in each dimension - // memref memSizes ; sizes of ptrs/inds/values - // ; per-dimension d: - // ; if dense: - // - // ; if compresed: - // memref pointers-d ; pointers for sparse dim d - // memref indices-d ; indices for sparse dim d - // ; if singleton: - // memref indices-d ; indices for singleton dim d - // memref values ; values - // }; - // - unsigned rank = rType.getShape().size(); - unsigned lastField = getFieldIndex(type, -1u, -1u); - // The dimSizes array and memSizes array. - fields.push_back(MemRefType::get({rank}, indexType)); - fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType)); - // Per-dimension storage. - for (unsigned r = 0; r < rank; r++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. Clients of this type know what field is what from the sparse - // tensor type. 
- if (isCompressedDim(rType, r)) { - fields.push_back(MemRefType::get({ShapedType::kDynamic}, ptrType)); - fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType)); - } else if (isSingletonDim(rType, r)) { - fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType)); - } else { - assert(isDenseDim(rType, r)); // no fields - } - } - // The values array. - fields.push_back(MemRefType::get({ShapedType::kDynamic}, eltType)); - assert(fields.size() == lastField); + foreachFieldAndTypeInSparseTensor( + rType, + [&fields](Type fieldType, unsigned fieldIdx, + SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + DimLevelType /*dlt*/) -> bool { + assert(fieldIdx == fields.size()); + fields.push_back(fieldType); + return true; + }); return success(); } /// Generates code that allocates a sparse storage scheme for given rank. static void allocSchemeForRank(OpBuilder &builder, Location loc, - RankedTensorType rtp, - SmallVectorImpl &fields, unsigned field, - unsigned r0) { + MutSparseTensorDescriptor desc, unsigned r0) { + RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); Value linear = constantIndex(builder, loc, 1); for (unsigned r = r0; r < rank; r++) { @@ -263,36 +206,35 @@ // Append linear x pointers, initialized to zero. Since each compressed // dimension initially already has a single zero entry, this maintains // the desired "linear + 1" length property at all times. - unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth(); - Type indexType = builder.getIndexType(); - Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType; + Type ptrType = desc.getPtrElementType(); Value ptrZero = constantZero(builder, loc, ptrType); - createPushback(builder, loc, fields, field, ptrZero, linear); + unsigned fidx = desc.getPtrMemRefIndex(r); + createPushback(builder, loc, desc, fidx, ptrZero, linear); return; } if (isSingletonDim(rtp, r)) { return; // nothing to do } // Keep compounding the size, but nothing needs to be initialized - // at this level. We will eventually reach a compressed level or - // otherwise the values array for the from-here "all-dense" case. - assert(isDenseDim(rtp, r)); - Value size = sizeAtStoredDim(builder, loc, rtp, fields, r); - linear = builder.create(loc, linear, size); + // at this level. We will eventually reach a compressed level or + // otherwise the values array for the from-here "all-dense" case. + assert(isDenseDim(rtp, r)); + Value size = sizeAtStoredDim(builder, loc, desc, r); + linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. Value valZero = constantZero(builder, loc, rtp.getElementType()); - createPushback(builder, loc, fields, field, valZero, linear); - assert(fields.size() == ++field); + createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear); } /// Creates allocation operation. 
-static Value createAllocation(OpBuilder &builder, Location loc, Type type, - Value sz, bool enableInit) { - auto memType = MemRefType::get({ShapedType::kDynamic}, type); - Value buffer = builder.create(loc, memType, sz); +static Value createAllocation(OpBuilder &builder, Location loc, + MemRefType memRefType, Value sz, + bool enableInit) { + Value buffer = builder.create(loc, memRefType, sz); + Type elemType = memRefType.getElementType(); if (enableInit) { - Value fillValue = - builder.create(loc, type, builder.getZeroAttr(type)); + Value fillValue = builder.create( + loc, elemType, builder.getZeroAttr(elemType)); builder.create(loc, fillValue, buffer); } return buffer; @@ -308,72 +250,69 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, ValueRange dynSizes, bool enableInit, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(type); - assert(enc); - // Construct the basic types. - unsigned idxWidth = enc.getIndexBitWidth(); - unsigned ptrWidth = enc.getPointerBitWidth(); RankedTensorType rtp = type.cast(); - Type indexType = builder.getIndexType(); - Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType; - Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType; - Type eltType = rtp.getElementType(); - auto shape = rtp.getShape(); - unsigned rank = shape.size(); Value heuristic = constantIndex(builder, loc, 16); + + foreachFieldAndTypeInSparseTensor( + rtp, + [&](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, + unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { + assert(fields.size() == fIdx); + auto memRefTp = fType.cast(); + Value field; + switch (fKind) { + case SparseTensorFieldKind::DimSizes: + case SparseTensorFieldKind::MemSizes: + field = builder.create(loc, memRefTp); + break; + case SparseTensorFieldKind::PtrMemRef: + case SparseTensorFieldKind::IdxMemRef: + case SparseTensorFieldKind::ValMemRef: + field = + createAllocation(builder, loc, memRefTp, heuristic, enableInit); + break; + } + assert(field); + fields.push_back(field); + // Returns true to ontinue the iteration. + return true; + }); + + MutSparseTensorDescriptor desc(rtp, fields); + // Build original sizes. SmallVector sizes; + auto shape = rtp.getShape(); + unsigned rank = shape.size(); for (unsigned r = 0, o = 0; r < rank; r++) { if (ShapedType::isDynamic(shape[r])) sizes.push_back(dynSizes[o++]); else sizes.push_back(constantIndex(builder, loc, shape[r])); } - // The dimSizes array and memSizes array. - unsigned lastField = getFieldIndex(type, -1u, -1u); - Value dimSizes = - builder.create(loc, MemRefType::get({rank}, indexType)); - Value memSizes = builder.create( - loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType)); - fields.push_back(dimSizes); - fields.push_back(memSizes); - // Per-dimension storage. - for (unsigned r = 0; r < rank; r++) { - if (isCompressedDim(rtp, r)) { - fields.push_back( - createAllocation(builder, loc, ptrType, heuristic, enableInit)); - fields.push_back( - createAllocation(builder, loc, idxType, heuristic, enableInit)); - } else if (isSingletonDim(rtp, r)) { - fields.push_back( - createAllocation(builder, loc, idxType, heuristic, enableInit)); - } else { - assert(isDenseDim(rtp, r)); // no fields - } - } - // The values array. - fields.push_back( - createAllocation(builder, loc, eltType, heuristic, enableInit)); - assert(fields.size() == lastField); // Initialize the storage scheme to an empty tensor. 
Initialized memSizes // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. builder.create( - loc, ValueRange{constantZero(builder, loc, indexType)}, - ValueRange{memSizes}); // zero memSizes - Value ptrZero = constantZero(builder, loc, ptrType); - for (unsigned r = 0, field = fieldsIdx; r < rank; r++) { + loc, constantZero(builder, loc, builder.getIndexType()), + desc.getMemSizesMemRef()); // zero memSizes + + Value ptrZero = constantZero(builder, loc, desc.getPtrElementType()); + for (unsigned r = 0; r < rank; r++) { unsigned ro = toOrigDim(rtp, r); - genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r)); + // Fills dim sizes array. + genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(), + constantIndex(builder, loc, r)); + + // Pushes a leading zero to pointers memref. if (isCompressedDim(rtp, r)) { - createPushback(builder, loc, fields, field, ptrZero); - field += 2; - } else if (isSingletonDim(rtp, r)) { - field += 1; + unsigned fidx = + desc.getDataFieldIndex(r, SparseTensorFieldKind::PtrMemRef); + createPushback(builder, loc, desc, fidx, ptrZero); } } - allocSchemeForRank(builder, loc, rtp, fields, fieldsIdx, /*rank=*/0); + allocSchemeForRank(builder, loc, desc, /*rank=*/0); } /// Helper method that generates block specific to compressed case: @@ -397,19 +336,22 @@ /// } /// pos[d] = next static Value genCompressed(OpBuilder &builder, Location loc, - RankedTensorType rtp, SmallVectorImpl &fields, + MutSparseTensorDescriptor desc, SmallVectorImpl &indices, Value value, - Value pos, unsigned field, unsigned d) { + Value pos, unsigned d) { + RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); SmallVector types; Type indexType = builder.getIndexType(); Type boolType = builder.getIntegerType(1); + unsigned idxIndex = desc.getIdxMemRefIndex(d); + unsigned ptrIndex = desc.getPtrMemRefIndex(d); Value one = constantIndex(builder, loc, 1); Value pp1 = builder.create(loc, pos, one); - Value plo = genLoad(builder, loc, fields[field], pos); - Value phi = genLoad(builder, loc, fields[field], pp1); - Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1)); - Value msz = genLoad(builder, loc, fields[memSizesIdx], psz); + Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos); + Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1); + Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex)); + Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz); Value phim1 = builder.create( loc, toType(builder, loc, phi, indexType), one); // Conditional expression. @@ -419,49 +361,55 @@ scf::IfOp ifOp1 = builder.create(loc, types, lt, /*else*/ true); types.pop_back(); builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); - Value crd = genLoad(builder, loc, fields[field + 1], phim1); + Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1); Value eq = builder.create(loc, arith::CmpIPredicate::eq, toType(builder, loc, crd, indexType), indices[d]); builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); if (d > 0) - genStore(builder, loc, msz, fields[field], pos); + genStore(builder, loc, msz, desc.getField(ptrIndex), pos); builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); Value p = ifOp1.getResult(0); - // If present construct. 
Note that for a non-unique dimension level, we simply - // set the condition to false and rely on CSE/DCE to clean up the IR. + // If present construct. Note that for a non-unique dimension level, we + // simply set the condition to false and rely on CSE/DCE to clean up the IR. // // TODO: generate less temporary IR? // - for (unsigned i = 0, e = fields.size(); i < e; i++) - types.push_back(fields[i].getType()); + for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) + types.push_back(desc.getField(i).getType()); types.push_back(indexType); if (!isUniqueDim(rtp, d)) p = constantI1(builder, loc, false); scf::IfOp ifOp2 = builder.create(loc, types, p, /*else*/ true); // If present (fields unaffected, update next to phim1). builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); - fields.push_back(phim1); - builder.create(loc, fields); - fields.pop_back(); + + // FIXME: This does not looks like a clean way, but probably the most + // efficient way. + desc.getFields().push_back(phim1); + builder.create(loc, desc.getFields()); + desc.getFields().pop_back(); + // If !present (changes fields, update next). builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); Value mszp1 = builder.create(loc, msz, one); - genStore(builder, loc, mszp1, fields[field], pp1); - createPushback(builder, loc, fields, field + 1, indices[d]); + genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1); + createPushback(builder, loc, desc, idxIndex, indices[d]); // Prepare the next dimension "as needed". if ((d + 1) < rank) - allocSchemeForRank(builder, loc, rtp, fields, field + 2, d + 1); - fields.push_back(msz); - builder.create(loc, fields); - fields.pop_back(); + allocSchemeForRank(builder, loc, desc, d + 1); + + desc.getFields().push_back(msz); + builder.create(loc, desc.getFields()); + desc.getFields().pop_back(); + // Update fields and return next pos. builder.setInsertionPointAfter(ifOp2); unsigned o = 0; - for (unsigned i = 0, e = fields.size(); i < e; i++) - fields[i] = ifOp2.getResult(o++); + for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) + desc.setField(i, ifOp2.getResult(o++)); return ifOp2.getResult(o); } @@ -476,12 +424,12 @@ /// /// TODO: better unord/not-unique; also generalize, optimize, specialize! /// -static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, - SmallVectorImpl &fields, +static void genInsert(OpBuilder &builder, Location loc, + MutSparseTensorDescriptor desc, SmallVectorImpl &indices, Value value) { + RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); assert(rank == indices.size()); - unsigned field = fieldsIdx; // start past header Value pos = constantZero(builder, loc, builder.getIndexType()); // Generate code for every dimension. 
for (unsigned d = 0; d < rank; d++) { @@ -493,67 +441,67 @@ // } // pos[d] = indices.size() - 1 // - pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field, - d); - field += 2; + pos = genCompressed(builder, loc, desc, indices, value, pos, d); } else if (isSingletonDim(rtp, d)) { // Create: // indices[d].push_back(i[d]) // pos[d] = pos[d-1] // - createPushback(builder, loc, fields, field, indices[d]); - field += 1; + unsigned fidx = desc.getIdxMemRefIndex(d); + createPushback(builder, loc, desc, fidx, indices[d]); } else { assert(isDenseDim(rtp, d)); // Construct the new position as: // pos[d] = size * pos[d-1] + i[d] // - Value size = sizeAtStoredDim(builder, loc, rtp, fields, d); + Value size = sizeAtStoredDim(builder, loc, desc, d); Value mult = builder.create(loc, size, pos); pos = builder.create(loc, mult, indices[d]); } } // Reached the actual value append/insert. - if (!isDenseDim(rtp, rank - 1)) - createPushback(builder, loc, fields, field++, value); - else - genStore(builder, loc, value, fields[field++], pos); - assert(fields.size() == field); + unsigned valIdx = desc.getValMemRefIndex(); + if (!isDenseDim(rtp, rank - 1)) { + createPushback(builder, loc, desc, valIdx, value); + } else + genStore(builder, loc, value, desc.getValMemRef(), pos); } /// Generations insertion finalization code. -static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, - SmallVectorImpl &fields) { +static void genEndInsert(OpBuilder &builder, Location loc, + MutSparseTensorDescriptor desc) { + RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); - unsigned field = fieldsIdx; // start past header for (unsigned d = 0; d < rank; d++) { if (isCompressedDim(rtp, d)) { // Compressed dimensions need a pointer cleanup for all entries // that were not visited during the insertion pass. // - // TODO: avoid cleanup and keep compressed scheme consistent at all times? + // TODO: avoid cleanup and keep compressed scheme consistent at all + // times? // if (d > 0) { unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth(); Type indexType = builder.getIndexType(); Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType; - Value mz = constantIndex(builder, loc, getMemSizesIndex(field)); - Value hi = genLoad(builder, loc, fields[memSizesIdx], mz); + Value ptrMemRef = desc.getPtrMemRef(d); + Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d)); + Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); // Vector of only one, but needed by createFor's prototype. 
- SmallVector inits{genLoad(builder, loc, fields[field], zero)}; + SmallVector inits{genLoad(builder, loc, ptrMemRef, zero)}; scf::ForOp loop = createFor(builder, loc, hi, inits, one); Value i = loop.getInductionVar(); Value oldv = loop.getRegionIterArg(0); - Value newv = genLoad(builder, loc, fields[field], i); + Value newv = genLoad(builder, loc, ptrMemRef, i); Value ptrZero = constantZero(builder, loc, ptrType); Value cond = builder.create( loc, arith::CmpIPredicate::eq, newv, ptrZero); scf::IfOp ifOp = builder.create(loc, TypeRange(ptrType), cond, /*else*/ true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - genStore(builder, loc, oldv, fields[field], i); + genStore(builder, loc, oldv, ptrMemRef, i); builder.create(loc, oldv); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); builder.create(loc, newv); @@ -561,14 +509,10 @@ builder.create(loc, ifOp.getResult(0)); builder.setInsertionPointAfter(loop); } - field += 2; - } else if (isSingletonDim(rtp, d)) { - field++; } else { - assert(isDenseDim(rtp, d)); + assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d)); } } - assert(fields.size() == ++field); } //===----------------------------------------------------------------------===// @@ -659,12 +603,12 @@ matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Optional index = op.getConstantIndex(); - if (!index) + if (!index || !getSparseTensorEncoding(adaptor.getSource().getType())) return failure(); - auto sz = - sizeFromTensorAtDim(rewriter, op.getLoc(), - op.getSource().getType().cast(), - adaptor.getSource(), *index); + + auto desc = getDescriptorFromTensorTuple(adaptor.getSource()); + auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index); + if (!sz) return failure(); @@ -760,10 +704,11 @@ // Prepare fields. SmallVector fields(tuple.getInputs()); // Generate optional insertion finalization code. + MutSparseTensorDescriptor desc(srcType, fields); if (op.getHasInserts()) - genEndInsert(rewriter, op.getLoc(), srcType, fields); + genEndInsert(rewriter, op.getLoc(), desc); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields)); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); return success(); } }; @@ -775,7 +720,10 @@ LogicalResult matchAndRewrite(ExpandOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (!getSparseTensorEncoding(op.getTensor().getType())) + return failure(); Location loc = op->getLoc(); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); RankedTensorType srcType = op.getTensor().getType().cast(); Type eltType = srcType.getElementType(); @@ -787,8 +735,7 @@ // dimension size, translated back to original dimension). Note that we // recursively rewrite the new DimOp on the **original** tensor. unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1); - auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(), - innerDim); + auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); assert(sz); // This for sure is a sparse tensor // Generate a memref for `sz` elements of type `t`. auto genAlloc = [&](Type t) { @@ -839,6 +786,7 @@ // Prepare fields and indices. SmallVector fields(tuple.getInputs()); SmallVector indices(adaptor.getIndices()); + MutSparseTensorDescriptor desc(dstType, fields); // If the innermost dimension is ordered, we need to sort the indices // in the "added" array prior to applying the compression. 
unsigned rank = dstType.getShape().size(); @@ -859,17 +807,17 @@ // filled[index] = false; // yield new_memrefs // } - scf::ForOp loop = createFor(rewriter, loc, count, fields); + scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields()); Value i = loop.getInductionVar(); Value index = genLoad(rewriter, loc, added, i); Value value = genLoad(rewriter, loc, values, index); indices.push_back(index); // TODO: faster for subsequent insertions? - genInsert(rewriter, loc, dstType, fields, indices, value); + genInsert(rewriter, loc, desc, indices, value); genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, index); genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index); - rewriter.create(loc, fields); + rewriter.create(loc, desc.getFields()); rewriter.setInsertionPointAfter(loop); Value result = genTuple(rewriter, loc, dstType, loop->getResults()); // Deallocate the buffers on exit of the full loop nest. @@ -897,11 +845,12 @@ // Prepare fields and indices. SmallVector fields(tuple.getInputs()); SmallVector indices(adaptor.getIndices()); + MutSparseTensorDescriptor desc(dstType, fields); // Generate insertion. Value value = adaptor.getValue(); - genInsert(rewriter, op->getLoc(), dstType, fields, indices, value); + genInsert(rewriter, op->getLoc(), desc, indices, value); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields)); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); return success(); } }; @@ -918,11 +867,9 @@ // Replace the requested pointer access with corresponding field. // The cast_op is inserted by type converter to intermix 1:N type // conversion. - auto tuple = getTuple(adaptor.getTensor()); - unsigned idx = Base::getIndexForOp(tuple, op); - auto fields = tuple.getInputs(); - assert(idx < fields.size()); - rewriter.replaceOp(op, fields[idx]); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + Value field = Base::getFieldForOp(desc, op); + rewriter.replaceOp(op, field); return success(); } }; @@ -933,10 +880,10 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, - ToPointersOp op) { + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToPointersOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u); + return desc.getPtrMemRef(dim); } }; @@ -946,10 +893,10 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, - ToIndicesOp op) { + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToIndicesOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim); + return desc.getIdxMemRef(dim); } }; @@ -959,10 +906,9 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp tuple, - ToValuesOp /*op*/) { - // The last field holds the value buffer. 
- return tuple.getInputs().size() - 1; + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToValuesOp /*op*/) { + return desc.getValMemRef(); } }; @@ -994,12 +940,11 @@ matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values size. - auto tuple = getTuple(adaptor.getTensor()); - auto fields = tuple.getInputs(); - unsigned lastField = fields.size() - 1; + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); Value field = - constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField)); - rewriter.replaceOpWithNewOp(op, fields[memSizesIdx], field); + constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex()); + rewriter.replaceOpWithNewOp(op, desc.getMemSizesMemRef(), + field); return success(); } };