diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -75,6 +75,21 @@ return isSingletonDLT(getDimLevelType(type, d)); } +/// Convenience function to test for dense dimension (0 <= d < rank). +inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) { + return isDenseDLT(getDimLevelType(enc, d)); +} + +/// Convenience function to test for compressed dimension (0 <= d < rank). +inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) { + return isCompressedDLT(getDimLevelType(enc, d)); +} + +/// Convenience function to test for singleton dimension (0 <= d < rank). +inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) { + return isSingletonDLT(getDimLevelType(enc, d)); +} + // // Dimension level properties. // diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -311,8 +311,222 @@ } //===----------------------------------------------------------------------===// -// SparseTensorLoopEmiter class, manages sparse tensors and helps to generate -// loop structure to (co)-iterate sparse tensors. +// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout +// scheme. +// +// Sparse tensor storage scheme for rank-dimensional tensor is organized +// as a single compound type with the following fields. Note that every +// memref with ? size actually behaves as a "vector", i.e. the stored +// size is the capacity and the used size resides in the memSizes array.
+// +// struct { +// memref<rank x index> dimSizes ; size in each dimension +// memref<n x index> memSizes ; sizes of ptrs/inds/values +// ; per-dimension d: +// ; if dense: +// (no fields) +// ; if compressed: +// memref<? x ptr> pointers-d ; pointers for sparse dim d +// memref<? x idx> indices-d ; indices for sparse dim d +// ; if singleton: +// memref<? x idx> indices-d ; indices for singleton dim d +// memref<? x eltType> values ; values +// }; +// +//===----------------------------------------------------------------------===// +enum class SparseTensorFieldKind { + DimSizes, + MemSizes, + PtrMemRef, + IdxMemRef, + ValMemRef +}; + +constexpr uint64_t dimSizesIdx = 0; +constexpr uint64_t memSizesIdx = dimSizesIdx + 1; +constexpr uint64_t dataFieldIdx = memSizesIdx + 1; + +/// For each field that will be allocated for the given sparse tensor encoding, +/// calls the callback with the corresponding field index, field kind, dimension +/// (for level memrefs; -1u otherwise) and DimLevelType (Undef if no level). +/// The field index always starts with zero and increments by one between two +/// callback invocations; iteration stops once the callback returns false. +/// Ideally, all other methods should rely on this function to query the +/// fields of a sparse tensor instead of relying on ad-hoc index computation. +void foreachFieldInSparseTensor( + SparseTensorEncodingAttr, + llvm::function_ref); + +/// Same as above, except that it also builds the Type for the corresponding +/// field and passes it as the first callback argument. +void foreachFieldAndTypeInSparseTensor( + RankedTensorType, + llvm::function_ref); + +/// Gets the total number of fields for the given sparse tensor encoding. +unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); + +/// Gets the total number of data fields (index arrays, pointer arrays, and a +/// value array) for the given sparse tensor encoding. +unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); + +/// Get the index of the field in memSizes (only valid for data fields).
+inline unsigned getFieldMemSizesIndex(unsigned fid) { + assert(fid >= dataFieldIdx); + return fid - dataFieldIdx; +} + +/// A helper class around an array of values that correspond to a sparse +/// tensor, provides a set of meaningful APIs to query and update a particular +/// field in a consistent way. +/// Users should not make assumptions on how a sparse tensor is laid out but +/// instead rely on this class to access the right value for the right field. +template +class SparseTensorDescriptorImpl { +private: + template <bool, typename = void> // `Dummy` below keeps the specializations + struct ArrayStorage; // partial: GCC rejects explicit ones at class scope. + + template <typename Dummy> + struct ArrayStorage<false, Dummy> { + using ValueArray = ValueRange; + }; + + template <typename Dummy> + struct ArrayStorage<true, Dummy> { + using ValueArray = SmallVectorImpl<Value> &; + }; + + // Uses ValueRange for immutable descriptors; uses SmallVectorImpl<Value> & + // for mutable descriptors. + // Using SmallVector for mutable descriptor allows users to reuse it as a tmp + // buffers to append value for some special cases, though users should be + // responsible to restore the buffer to legal states after their use. It is + // probably not a clean way, but it is the most efficient way to avoid copying + // the fields into another SmallVector. If a more clear way is wanted, we + // should change it to MutableArrayRef instead. + using Storage = typename ArrayStorage::ValueArray; + +public: + SparseTensorDescriptorImpl(Type tp, Storage fields) + : rType(tp.cast()), fields(fields) { + assert(getSparseTensorEncoding(tp) && + getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == + fields.size()); + // We should make sure the class is trivially copyable (and should be small + // enough) such that we can pass it by value. + static_assert( + std::is_trivially_copyable_v>); + } + + // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to + // SparseTensorDescriptor.
+ template > + /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc) + : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {} + + /// + /// Getters: get the field index for required field. + /// + + unsigned getPtrMemRefIndex(unsigned ptrDim) const { + return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef); + } + + unsigned getIdxMemRefIndex(unsigned idxDim) const { + return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef); + } + + unsigned getValMemRefIndex() const { return fields.size() - 1; } + + unsigned getPtrMemSizesIndex(unsigned dim) const { + return getFieldMemSizesIndex(getPtrMemRefIndex(dim)); + } + + unsigned getIdxMemSizesIndex(unsigned dim) const { + return getFieldMemSizesIndex(getIdxMemRefIndex(dim)); + } + + unsigned getValMemSizesIndex() const { + return getFieldMemSizesIndex(getValMemRefIndex()); + } + + unsigned getNumFields() const { return fields.size(); } + + /// + /// Getters: get the value for required field. + /// + + Value getDimSizesMemRef() const { return fields[dimSizesIdx]; } + Value getMemSizesMemRef() const { return fields[memSizesIdx]; } + + Value getPtrMemRef(unsigned ptrDim) const { + return fields[getPtrMemRefIndex(ptrDim)]; + } + + Value getIdxMemRef(unsigned idxDim) const { + return fields[getIdxMemRefIndex(idxDim)]; + } + + Value getValMemRef() const { return fields[getValMemRefIndex()]; } + + Value getField(unsigned fid) const { + assert(fid < fields.size()); + return fields[fid]; + } + + /// + /// Setters: update the value for required field (only enabled for + /// MutSparseTensorDescriptor).
+ /// + + template + void setField(unsigned fid, std::enable_if_t v) { + assert(fid < fields.size()); + fields[fid] = v; + } + + RankedTensorType getTensorType() const { return rType; } + Storage getFields() const { return fields; } + + Type getElementType(unsigned fidx) const { + return getField(fidx).getType().template cast().getElementType(); + } + +private: + unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const { + unsigned fieldIdx = -1u; + foreachFieldInSparseTensor( + getSparseTensorEncoding(rType), + [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, + unsigned fDim, DimLevelType dlt) -> bool { + if (fDim == dim && kind == fKind) { + fieldIdx = fIdx; + // Returns false to break the iteration. + return false; + } + return true; + }); + assert(fieldIdx != -1u); + return fieldIdx; + } + + RankedTensorType rType; + Storage fields; +}; + +using SparseTensorDescriptor = SparseTensorDescriptorImpl; +using MutSparseTensorDescriptor = SparseTensorDescriptorImpl; + +//===----------------------------------------------------------------------===// +// SparseTensorLoopEmitter class, manages sparse tensors and helps to +// generate loop structure to (co)-iterate sparse tensors. // // An example usage: // To generate the following loops over T1 and T2 @@ -345,15 +559,15 @@ using OutputUpdater = function_ref; - /// Constructor: take an array of tensors inputs, on which the generated loops - /// will iterate on. The index of the tensor in the array is also the + /// Constructor: take an array of tensors inputs, on which the generated + /// loops will iterate on. The index of the tensor in the array is also the /// tensor id (tid) used in related functions. /// If isSparseOut is set, loop emitter assume that the sparse output tensor /// is empty, and will always generate loops on it based on the dim sizes. /// An optional array could be provided (by sparsification) to indicate the /// loop id sequence that will be generated.
It is used to establish the - /// mapping between affineDimExpr to the corresponding loop index in the loop - /// stack that are maintained by the loop emitter. + /// mapping between affineDimExpr to the corresponding loop index in the + /// loop stack that is maintained by the loop emitter. explicit SparseTensorLoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false, @@ -368,8 +582,8 @@ /// Generates a list of operations to compute the affine expression. Value genAffine(OpBuilder &builder, AffineExpr a, Location loc); - /// Enters a new loop sequence, the loops within the same sequence starts from - /// the break points of previous loop instead of starting over from 0. + /// Enters a new loop sequence, the loops within the same sequence start + /// from the break points of previous loop instead of starting over from 0. /// e.g., /// { /// // loop sequence start. @@ -524,10 +738,10 @@ /// scf.reduce.return %val /// } /// } - /// NOTE: only one instruction will be moved into reduce block, transformation - /// will fail if multiple instructions are used to compute the reduction - /// value. - /// Return %ret to user, while %val is provided by users (`reduc`). + /// NOTE: only one instruction will be moved into reduce block, + /// transformation will fail if multiple instructions are used to compute + /// the reduction value. Return %ret to user, while %val is provided by + /// users (`reduc`). void exitForLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc); @@ -535,9 +749,9 @@ void exitCoIterationLoop(OpBuilder &builder, Location loc, MutableArrayRef reduc); - /// A optional string attribute that should be attached to the loop generated - /// by loop emitter, it might help following passes to identify loops that - /// operates on sparse tensors more easily.
+ /// An optional string attribute that should be attached to the loop + /// generated by loop emitter, it might help following passes to identify + /// loops that operate on sparse tensors more easily. StringAttr loopTag; /// Whether the loop emitter needs to treat the last tensor as the output /// tensor. @@ -556,7 +770,8 @@ std::vector> idxBuffer; // to_indices std::vector valBuffer; // to_value - // Loop Stack, stores the information of all the nested loops that are alive. + // Loop Stack, stores the information of all the nested loops that are + // alive. std::vector loopStack; // Loop Sequence Stack, stores the unversial index for the current loop diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -90,6 +90,116 @@ return val; } +void sparse_tensor::foreachFieldInSparseTensor( + const SparseTensorEncodingAttr enc, + llvm::function_ref + callback) { + assert(enc); + +#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ + if (!(callback(idx, kind, dim, dlt))) \ + return; + + RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u, + DimLevelType::Undef); + RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u, + DimLevelType::Undef); + + static_assert(dataFieldIdx == memSizesIdx + 1); + unsigned fieldIdx = dataFieldIdx; + // Per-dimension storage. + for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { + // Dimension level types apply in order to the reordered dimension. + // As a result, the compound type can be constructed directly in the given + // order.
+ auto dlt = getDimLevelType(enc, r); + if (isCompressedDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else if (isSingletonDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else { + assert(isDenseDLT(dlt)); // no fields + } + } + // The values array. + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, + DimLevelType::Undef); + +#undef RETURN_ON_FALSE +} + +void sparse_tensor::foreachFieldAndTypeInSparseTensor( + RankedTensorType rType, + llvm::function_ref + callback) { + auto enc = getSparseTensorEncoding(rType); + assert(enc); + // Construct the basic types. + Type indexType = IndexType::get(enc.getContext()); + Type idxType = enc.getIndexType(); + Type ptrType = enc.getPointerType(); + Type eltType = rType.getElementType(); + unsigned rank = rType.getShape().size(); + // memref<rank x index> dimSizes + Type dimSizeType = MemRefType::get({rank}, indexType); + // memref<n x index> memSizes (n = number of data fields) + Type memSizeType = + MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); + // memref<? x ptr> pointers + Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); + // memref<? x idx> indices + Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); + // memref<? x eltType> values + Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); + + foreachFieldInSparseTensor( + enc, + [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType, + callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, + unsigned dim, DimLevelType dlt) -> bool { + switch (fieldKind) { + case SparseTensorFieldKind::DimSizes: + return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::MemSizes: + return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::PtrMemRef: + return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::IdxMemRef: +
return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::ValMemRef: + return callback(valMemType, fieldIdx, fieldKind, dim, dlt); + } + llvm_unreachable("unrecognized field kind"); + }); +} + +unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; + foreachFieldInSparseTensor(enc, + [&numFields](unsigned, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + numFields++; + return true; + }); + return numFields; +} + +unsigned +sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; // counts fields at index >= dataFieldIdx + foreachFieldInSparseTensor(enc, + [&numFields](unsigned fidx, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + if (fidx >= dataFieldIdx) + numFields++; + return true; + }); + assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx); + return numFields; +} //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -36,10 +36,6 @@ static constexpr const char kInsertFuncNamePrefix[] = "_insert_"; -static constexpr uint64_t dimSizesIdx = 0; -static constexpr uint64_t memSizesIdx = 1; -static constexpr uint64_t fieldsIdx = 2; - //===----------------------------------------------------------------------===// // Helper methods.
//===----------------------------------------------------------------------===// @@ -49,6 +45,18 @@ return llvm::cast(tensor.getDefiningOp()); } +static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { + auto tuple = getTuple(tensor); + return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); +} + +static MutSparseTensorDescriptor +getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { + auto tuple = getTuple(tensor); + fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); // The returned descriptor aliases `fields`. + return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); +} + /// Packs the given values as a "tuple" value. static Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values) { @@ -56,6 +64,14 @@ .getResult(0); } +static Value genTuple(OpBuilder &builder, Location loc, // Packs all of desc's fields. + SparseTensorDescriptor desc) { + return builder + .create(loc, desc.getTensorType(), + desc.getFields()) + .getResult(0); +} + /// Flatten a list of operands that may contain sparse tensors. static void flattenOperands(ValueRange operands, SmallVectorImpl &flattened) { @@ -101,7 +117,7 @@ /// Creates a straightforward counting for-loop. static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, - SmallVectorImpl &fields, + MutableArrayRef fields, Value lower = Value()) { Type indexType = builder.getIndexType(); if (!lower) @@ -118,81 +134,46 @@ /// original dimension 'dim'. Returns None if no sparse encoding is /// attached to the given tensor type. static Optional sizeFromTensorAtDim(OpBuilder &builder, Location loc, - RankedTensorType tensorTp, - Value adaptedValue, unsigned dim) { - auto enc = getSparseTensorEncoding(tensorTp); - if (!enc) - return llvm::None; - + SparseTensorDescriptor desc, + unsigned dim) { + RankedTensorType rtp = desc.getTensorType(); // Access into static dimension can query original type directly. // Note that this is typically already done by DimOp's folding.
- auto shape = tensorTp.getShape(); + auto shape = rtp.getShape(); if (!ShapedType::isDynamic(shape[dim])) return constantIndex(builder, loc, shape[dim]); // Any other query can consult the dimSizes array at field DimSizesIdx, // accounting for the reordering applied to the sparse storage. - auto tuple = getTuple(adaptedValue); - Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim)); - return builder - .create(loc, tuple.getInputs()[dimSizesIdx], idx) + Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim)); + return builder.create(loc, desc.getDimSizesMemRef(), idx) .getResult(); } // Gets the dimension size at the given stored dimension 'd', either as a // constant for a static size, or otherwise dynamically through memSizes. -Value sizeAtStoredDim(OpBuilder &builder, Location loc, RankedTensorType rtp, - SmallVectorImpl &fields, unsigned d) { +Value sizeAtStoredDim(OpBuilder &builder, Location loc, + SparseTensorDescriptor desc, unsigned d) { + RankedTensorType rtp = desc.getTensorType(); unsigned dim = toOrigDim(rtp, d); auto shape = rtp.getShape(); if (!ShapedType::isDynamic(shape[dim])) return constantIndex(builder, loc, shape[dim]); - return genLoad(builder, loc, fields[dimSizesIdx], - constantIndex(builder, loc, d)); -} -/// Translates field index to memSizes index. -static unsigned getMemSizesIndex(unsigned field) { - assert(fieldsIdx <= field); - return field - fieldsIdx; + return genLoad(builder, loc, desc.getDimSizesMemRef(), // Loaded from the dimSizes array. + constantIndex(builder, loc, d)); } -/// Creates a pushback op for given field and updates the fields array -/// accordingly. This operation also updates the memSizes contents.
static void createPushback(OpBuilder &builder, Location loc, - SmallVectorImpl &fields, unsigned field, + MutSparseTensorDescriptor desc, unsigned fidx, Value value, Value repeat = Value()) { - assert(fieldsIdx <= field && field < fields.size()); - Type etp = fields[field].getType().cast().getElementType(); - fields[field] = builder.create( - loc, fields[field].getType(), fields[memSizesIdx], fields[field], - toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)), + Type etp = desc.getElementType(fidx); + Value field = desc.getField(fidx); + Value newField = builder.create( + loc, field.getType(), desc.getMemSizesMemRef(), field, + toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)), repeat); -} - -/// Returns field index of sparse tensor type for pointers/indices, when set. -static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) { - assert(getSparseTensorEncoding(type)); - RankedTensorType rType = type.cast(); - unsigned field = fieldsIdx; // start past header - for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) { - if (isCompressedDim(rType, r)) { - if (r == ptrDim) - return field; - field++; - if (r == idxDim) - return field; - field++; - } else if (isSingletonDim(rType, r)) { - if (r == idxDim) - return field; - field++; - } else { - assert(isDenseDim(rType, r)); // no fields - } - } - assert(ptrDim == -1u && idxDim == -1u); - return field + 1; // return values field index + desc.setField(fidx, newField); // Record the memref produced by the pushback. } /// Maps a sparse tensor type to the appropriate compounded buffers. @@ -201,64 +182,24 @@ auto enc = getSparseTensorEncoding(type); if (!enc) return llvm::None; - // Construct the basic types.
- auto *context = type.getContext(); + RankedTensorType rType = type.cast(); - Type indexType = IndexType::get(context); - Type idxType = enc.getIndexType(); - Type ptrType = enc.getPointerType(); - Type eltType = rType.getElementType(); - // - // Sparse tensor storage scheme for rank-dimensional tensor is organized - // as a single compound type with the following fields. Note that every - // memref with ? size actually behaves as a "vector", i.e. the stored - // size is the capacity and the used size resides in the memSizes array. - // - // struct { - // memref dimSizes ; size in each dimension - // memref memSizes ; sizes of ptrs/inds/values - // ; per-dimension d: - // ; if dense: - // - // ; if compresed: - // memref pointers-d ; pointers for sparse dim d - // memref indices-d ; indices for sparse dim d - // ; if singleton: - // memref indices-d ; indices for singleton dim d - // memref values ; values - // }; - // - unsigned rank = rType.getShape().size(); - unsigned lastField = getFieldIndex(type, -1u, -1u); - // The dimSizes array and memSizes array. - fields.push_back(MemRefType::get({rank}, indexType)); - fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType)); - // Per-dimension storage. - for (unsigned r = 0; r < rank; r++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. Clients of this type know what field is what from the sparse - // tensor type. - if (isCompressedDim(rType, r)) { - fields.push_back(MemRefType::get({ShapedType::kDynamic}, ptrType)); - fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType)); - } else if (isSingletonDim(rType, r)) { - fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType)); - } else { - assert(isDenseDim(rType, r)); // no fields - } - } - // The values array.
- fields.push_back(MemRefType::get({ShapedType::kDynamic}, eltType)); - assert(fields.size() == lastField); + foreachFieldAndTypeInSparseTensor( + rType, + [&fields](Type fieldType, unsigned fieldIdx, + SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + DimLevelType /*dlt*/) -> bool { + assert(fieldIdx == fields.size()); // Fields arrive in index order. + fields.push_back(fieldType); + return true; + }); return success(); } /// Generates code that allocates a sparse storage scheme for given rank. static void allocSchemeForRank(OpBuilder &builder, Location loc, - RankedTensorType rtp, - SmallVectorImpl &fields, unsigned field, - unsigned r0) { + MutSparseTensorDescriptor desc, unsigned r0) { + RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); Value linear = constantIndex(builder, loc, 1); for (unsigned r = r0; r < rank; r++) { @@ -268,7 +209,8 @@ // the desired "linear + 1" length property at all times. Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); Value ptrZero = constantZero(builder, loc, ptrType); - createPushback(builder, loc, fields, field, ptrZero, linear); + createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero, + linear); return; } if (isSingletonDim(rtp, r)) { @@ -278,23 +220,23 @@ // at this level. We will eventually reach a compressed level or // otherwise the values array for the from-here "all-dense" case. assert(isDenseDim(rtp, r)); - Value size = sizeAtStoredDim(builder, loc, rtp, fields, r); + Value size = sizeAtStoredDim(builder, loc, desc, r); linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. Value valZero = constantZero(builder, loc, rtp.getElementType()); - createPushback(builder, loc, fields, field, valZero, linear); - assert(fields.size() == ++field); + createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear); } /// Creates allocation operation.
-static Value createAllocation(OpBuilder &builder, Location loc, Type type, - Value sz, bool enableInit) { - auto memType = MemRefType::get({ShapedType::kDynamic}, type); - Value buffer = builder.create(loc, memType, sz); +static Value createAllocation(OpBuilder &builder, Location loc, + MemRefType memRefType, Value sz, + bool enableInit) { + Value buffer = builder.create(loc, memRefType, sz); + Type elemType = memRefType.getElementType(); if (enableInit) { - Value fillValue = - builder.create(loc, type, builder.getZeroAttr(type)); + Value fillValue = builder.create( + loc, elemType, builder.getZeroAttr(elemType)); builder.create(loc, fillValue, buffer); } return buffer; @@ -310,69 +252,68 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, ValueRange dynSizes, bool enableInit, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(type); - assert(enc); RankedTensorType rtp = type.cast(); - Type indexType = builder.getIndexType(); - Type idxType = enc.getIndexType(); - Type ptrType = enc.getPointerType(); - Type eltType = rtp.getElementType(); - auto shape = rtp.getShape(); - unsigned rank = shape.size(); Value heuristic = constantIndex(builder, loc, 16); + + foreachFieldAndTypeInSparseTensor( + rtp, + [&builder, &fields, loc, heuristic, + enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, + unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { + assert(fields.size() == fIdx); + auto memRefTp = fType.cast(); + Value field; + switch (fKind) { + case SparseTensorFieldKind::DimSizes: + case SparseTensorFieldKind::MemSizes: + field = builder.create(loc, memRefTp); // Static shape: no size operand. + break; + case SparseTensorFieldKind::PtrMemRef: + case SparseTensorFieldKind::IdxMemRef: + case SparseTensorFieldKind::ValMemRef: + field = + createAllocation(builder, loc, memRefTp, heuristic, enableInit); + break; + } + assert(field); + fields.push_back(field); + // Returns true to continue the iteration.
+ return true; + }); + + MutSparseTensorDescriptor desc(rtp, fields); + // Build original sizes. SmallVector sizes; + auto shape = rtp.getShape(); + unsigned rank = shape.size(); for (unsigned r = 0, o = 0; r < rank; r++) { if (ShapedType::isDynamic(shape[r])) sizes.push_back(dynSizes[o++]); else sizes.push_back(constantIndex(builder, loc, shape[r])); } - // The dimSizes array and memSizes array. - unsigned lastField = getFieldIndex(type, -1u, -1u); - Value dimSizes = - builder.create(loc, MemRefType::get({rank}, indexType)); - Value memSizes = builder.create( - loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType)); - fields.push_back(dimSizes); - fields.push_back(memSizes); - // Per-dimension storage. - for (unsigned r = 0; r < rank; r++) { - if (isCompressedDim(rtp, r)) { - fields.push_back( - createAllocation(builder, loc, ptrType, heuristic, enableInit)); - fields.push_back( - createAllocation(builder, loc, idxType, heuristic, enableInit)); - } else if (isSingletonDim(rtp, r)) { - fields.push_back( - createAllocation(builder, loc, idxType, heuristic, enableInit)); - } else { - assert(isDenseDim(rtp, r)); // no fields - } - } - // The values array. - fields.push_back( - createAllocation(builder, loc, eltType, heuristic, enableInit)); - assert(fields.size() == lastField); // Initialize the storage scheme to an empty tensor. Initialized memSizes // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property.
builder.create( - loc, ValueRange{constantZero(builder, loc, indexType)}, - ValueRange{memSizes}); // zero memSizes - Value ptrZero = constantZero(builder, loc, ptrType); - for (unsigned r = 0, field = fieldsIdx; r < rank; r++) { + loc, constantZero(builder, loc, builder.getIndexType()), + desc.getMemSizesMemRef()); // zero memSizes + + Value ptrZero = + constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType()); + for (unsigned r = 0; r < rank; r++) { unsigned ro = toOrigDim(rtp, r); - genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r)); - if (isCompressedDim(rtp, r)) { - createPushback(builder, loc, fields, field, ptrZero); - field += 2; - } else if (isSingletonDim(rtp, r)) { - field += 1; - } + // Fills dim sizes array. + genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(), + constantIndex(builder, loc, r)); + + // Pushes a leading zero to pointers memref. + if (isCompressedDim(rtp, r)) + createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero); } - allocSchemeForRank(builder, loc, rtp, fields, fieldsIdx, /*rank=*/0); + allocSchemeForRank(builder, loc, desc, /*r0=*/0); } /// Helper method that generates block specific to compressed case: @@ -396,19 +337,22 @@ /// } /// pos[d] = next static Value genCompressed(OpBuilder &builder, Location loc, - RankedTensorType rtp, SmallVectorImpl &fields, + MutSparseTensorDescriptor desc, SmallVectorImpl &indices, Value value, - Value pos, unsigned field, unsigned d) { + Value pos, unsigned d) { + RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); SmallVector types; Type indexType = builder.getIndexType(); Type boolType = builder.getIntegerType(1); + unsigned idxIndex = desc.getIdxMemRefIndex(d); + unsigned ptrIndex = desc.getPtrMemRefIndex(d); Value one = constantIndex(builder, loc, 1); Value pp1 = builder.create(loc, pos, one); - Value plo = genLoad(builder, loc, fields[field], pos); - Value phi = genLoad(builder, loc,
fields[field], pp1); - Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1)); - Value msz = genLoad(builder, loc, fields[memSizesIdx], psz); + Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos); + Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1); + Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex)); // memSizes slot of indices-d. + Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz); Value phim1 = builder.create( loc, toType(builder, loc, phi, indexType), one); // Conditional expression. @@ -418,49 +362,55 @@ scf::IfOp ifOp1 = builder.create(loc, types, lt, /*else*/ true); types.pop_back(); builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); - Value crd = genLoad(builder, loc, fields[field + 1], phim1); + Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1); Value eq = builder.create(loc, arith::CmpIPredicate::eq, toType(builder, loc, crd, indexType), indices[d]); builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); if (d > 0) - genStore(builder, loc, msz, fields[field], pos); + genStore(builder, loc, msz, desc.getField(ptrIndex), pos); builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); Value p = ifOp1.getResult(0); - // If present construct. Note that for a non-unique dimension level, we simply - // set the condition to false and rely on CSE/DCE to clean up the IR. + // If present construct. Note that for a non-unique dimension level, we + // simply set the condition to false and rely on CSE/DCE to clean up the IR. // // TODO: generate less temporary IR?
// - for (unsigned i = 0, e = fields.size(); i < e; i++) - types.push_back(fields[i].getType()); + for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) + types.push_back(desc.getField(i).getType()); types.push_back(indexType); if (!isUniqueDim(rtp, d)) p = constantI1(builder, loc, false); scf::IfOp ifOp2 = builder.create(loc, types, p, /*else*/ true); // If present (fields unaffected, update next to phim1). builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); - fields.push_back(phim1); - builder.create(loc, fields); - fields.pop_back(); + + // FIXME: temporarily appends to the descriptor's backing vector to build + // the yield operands; not clean, but it avoids copying all the fields. + desc.getFields().push_back(phim1); + builder.create(loc, desc.getFields()); + desc.getFields().pop_back(); + // If !present (changes fields, update next). builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); Value mszp1 = builder.create(loc, msz, one); - genStore(builder, loc, mszp1, fields[field], pp1); - createPushback(builder, loc, fields, field + 1, indices[d]); + genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1); + createPushback(builder, loc, desc, idxIndex, indices[d]); // Prepare the next dimension "as needed". if ((d + 1) < rank) - allocSchemeForRank(builder, loc, rtp, fields, field + 2, d + 1); - fields.push_back(msz); - builder.create(loc, fields); - fields.pop_back(); + allocSchemeForRank(builder, loc, desc, d + 1); + + desc.getFields().push_back(msz); + builder.create(loc, desc.getFields()); + desc.getFields().pop_back(); + // Update fields and return next pos. builder.setInsertionPointAfter(ifOp2); unsigned o = 0; - for (unsigned i = 0, e = fields.size(); i < e; i++) - fields[i] = ifOp2.getResult(o++); + for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) + desc.setField(i, ifOp2.getResult(o++)); return ifOp2.getResult(o); } @@ -488,11 +438,10 @@ // Construct fields and indices arrays from parameters.
ValueRange tmp = args.drop_back(rank + 1); SmallVector fields(tmp.begin(), tmp.end()); + MutSparseTensorDescriptor desc(rtp, fields); // Aliases `fields`: updates made through desc land in fields. tmp = args.take_back(rank + 1).drop_back(); SmallVector indices(tmp.begin(), tmp.end()); Value value = args.back(); - - unsigned field = fieldsIdx; // Start past header. Value pos = constantZero(builder, loc, builder.getIndexType()); // Generate code for every dimension. for (unsigned d = 0; d < rank; d++) { @@ -504,39 +453,35 @@ // } // pos[d] = indices.size() - 1 // - pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field, - d); - field += 2; + pos = genCompressed(builder, loc, desc, indices, value, pos, d); } else if (isSingletonDim(rtp, d)) { // Create: // indices[d].push_back(i[d]) // pos[d] = pos[d-1] // - createPushback(builder, loc, fields, field, indices[d]); - field += 1; + createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]); } else { assert(isDenseDim(rtp, d)); // Construct the new position as: // pos[d] = size * pos[d-1] + i[d] // - Value size = sizeAtStoredDim(builder, loc, rtp, fields, d); + Value size = sizeAtStoredDim(builder, loc, desc, d); Value mult = builder.create(loc, size, pos); pos = builder.create(loc, mult, indices[d]); } } // Reached the actual value append/insert. if (!isDenseDim(rtp, rank - 1)) - createPushback(builder, loc, fields, field++, value); + createPushback(builder, loc, desc, desc.getValMemRefIndex(), value); else - genStore(builder, loc, value, fields[field++], pos); - assert(fields.size() == field); + genStore(builder, loc, value, desc.getValMemRef(), pos); builder.create(loc, fields); } /// Generates a call to a function to perform an insertion operation. If the /// function doesn't exist yet, call `createFunc` to generate the function.
-static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp, - SmallVectorImpl &fields, +static void genInsertionCallHelper(OpBuilder &builder, + MutSparseTensorDescriptor desc, SmallVectorImpl &indices, Value value, func::FuncOp insertPoint, StringRef namePrefix, @@ -544,6 +489,7 @@ // The mangled name of the function has this format: // _[C|S|D]___ // __ + RankedTensorType rtp = desc.getTensorType(); SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); nameOstream << namePrefix; @@ -577,7 +523,7 @@ auto func = module.lookupSymbol(result.getAttr()); // Construct parameters for fields and indices. - SmallVector operands(fields.begin(), fields.end()); + SmallVector operands(desc.getFields().begin(), desc.getFields().end()); operands.append(indices.begin(), indices.end()); operands.push_back(value); Location loc = insertPoint.getLoc(); @@ -590,7 +536,7 @@ func = builder.create( loc, nameOstream.str(), FunctionType::get(context, ValueRange(operands).getTypes(), - ValueRange(fields).getTypes())); + ValueRange(desc.getFields()).getTypes())); func.setPrivate(); createFunc(builder, module, func, rtp); } @@ -598,42 +544,44 @@ // Generate a call to perform the insertion and update `fields` with values // returned from the call. func::CallOp call = builder.create(loc, func, operands); - for (size_t i = 0; i < fields.size(); i++) { - fields[i] = call.getResult(i); + for (size_t i = 0, e = desc.getNumFields(); i < e; i++) { + desc.getFields()[i] = call.getResult(i); } } /// Generations insertion finalization code. 
-static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, - SmallVectorImpl &fields) { +static void genEndInsert(OpBuilder &builder, Location loc, + MutSparseTensorDescriptor desc) { + RankedTensorType rtp = desc.getTensorType(); unsigned rank = rtp.getShape().size(); - unsigned field = fieldsIdx; // start past header for (unsigned d = 0; d < rank; d++) { if (isCompressedDim(rtp, d)) { // Compressed dimensions need a pointer cleanup for all entries // that were not visited during the insertion pass. // - // TODO: avoid cleanup and keep compressed scheme consistent at all times? + // TODO: avoid cleanup and keep compressed scheme consistent at all + // times? // if (d > 0) { Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); - Value mz = constantIndex(builder, loc, getMemSizesIndex(field)); - Value hi = genLoad(builder, loc, fields[memSizesIdx], mz); + Value ptrMemRef = desc.getPtrMemRef(d); + Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d)); + Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); // Vector of only one, but needed by createFor's prototype. 
- SmallVector inits{genLoad(builder, loc, fields[field], zero)}; + SmallVector inits{genLoad(builder, loc, ptrMemRef, zero)}; scf::ForOp loop = createFor(builder, loc, hi, inits, one); Value i = loop.getInductionVar(); Value oldv = loop.getRegionIterArg(0); - Value newv = genLoad(builder, loc, fields[field], i); + Value newv = genLoad(builder, loc, ptrMemRef, i); Value ptrZero = constantZero(builder, loc, ptrType); Value cond = builder.create( loc, arith::CmpIPredicate::eq, newv, ptrZero); scf::IfOp ifOp = builder.create(loc, TypeRange(ptrType), cond, /*else*/ true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - genStore(builder, loc, oldv, fields[field], i); + genStore(builder, loc, oldv, ptrMemRef, i); builder.create(loc, oldv); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); builder.create(loc, newv); @@ -641,14 +589,10 @@ builder.create(loc, ifOp.getResult(0)); builder.setInsertionPointAfter(loop); } - field += 2; - } else if (isSingletonDim(rtp, d)) { - field++; } else { - assert(isDenseDim(rtp, d)); + assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d)); } } - assert(fields.size() == ++field); } //===----------------------------------------------------------------------===// @@ -739,12 +683,12 @@ matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Optional index = op.getConstantIndex(); - if (!index) + if (!index || !getSparseTensorEncoding(adaptor.getSource().getType())) return failure(); - auto sz = - sizeFromTensorAtDim(rewriter, op.getLoc(), - op.getSource().getType().cast(), - adaptor.getSource(), *index); + + auto desc = getDescriptorFromTensorTuple(adaptor.getSource()); + auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index); + if (!sz) return failure(); @@ -834,16 +778,14 @@ LogicalResult matchAndRewrite(LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType srcType = - 
op.getTensor().getType().cast(); - auto tuple = getTuple(adaptor.getTensor()); - // Prepare fields. - SmallVector fields(tuple.getInputs()); + // Prepare descriptor. + SmallVector fields; + auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); // Generate optional insertion finalization code. if (op.getHasInserts()) - genEndInsert(rewriter, op.getLoc(), srcType, fields); + genEndInsert(rewriter, op.getLoc(), desc); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields)); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); return success(); } }; @@ -855,7 +797,10 @@ LogicalResult matchAndRewrite(ExpandOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (!getSparseTensorEncoding(op.getTensor().getType())) + return failure(); Location loc = op->getLoc(); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); RankedTensorType srcType = op.getTensor().getType().cast(); Type eltType = srcType.getElementType(); @@ -867,8 +812,7 @@ // dimension size, translated back to original dimension). Note that we // recursively rewrite the new DimOp on the **original** tensor. unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1); - auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(), - innerDim); + auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); assert(sz); // This for sure is a sparse tensor // Generate a memref for `sz` elements of type `t`. 
auto genAlloc = [&](Type t) { @@ -908,16 +852,15 @@ matchAndRewrite(CompressOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - RankedTensorType dstType = - op.getTensor().getType().cast(); - Type eltType = dstType.getElementType(); - auto tuple = getTuple(adaptor.getTensor()); + SmallVector fields; + auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); Value values = adaptor.getValues(); Value filled = adaptor.getFilled(); Value added = adaptor.getAdded(); Value count = adaptor.getCount(); - // Prepare fields and indices. - SmallVector fields(tuple.getInputs()); + RankedTensorType dstType = desc.getTensorType(); + Type eltType = dstType.getElementType(); + // Prepare indices. SmallVector indices(adaptor.getIndices()); // If the innermost dimension is ordered, we need to sort the indices // in the "added" array prior to applying the compression. @@ -939,19 +882,19 @@ // filled[index] = false; // yield new_memrefs // } - scf::ForOp loop = createFor(rewriter, loc, count, fields); + scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields()); Value i = loop.getInductionVar(); Value index = genLoad(rewriter, loc, added, i); Value value = genLoad(rewriter, loc, values, index); indices.push_back(index); // TODO: faster for subsequent insertions? 
auto insertPoint = op->template getParentOfType(); - genInsertionCallHelper(rewriter, dstType, fields, indices, value, - insertPoint, kInsertFuncNamePrefix, genInsertBody); + genInsertionCallHelper(rewriter, desc, indices, value, insertPoint, + kInsertFuncNamePrefix, genInsertBody); genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, index); genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index); - rewriter.create(loc, fields); + rewriter.create(loc, desc.getFields()); rewriter.setInsertionPointAfter(loop); Value result = genTuple(rewriter, loc, dstType, loop->getResults()); // Deallocate the buffers on exit of the full loop nest. @@ -973,20 +916,18 @@ LogicalResult matchAndRewrite(InsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType dstType = - op.getTensor().getType().cast(); - auto tuple = getTuple(adaptor.getTensor()); - // Prepare fields and indices. - SmallVector fields(tuple.getInputs()); + SmallVector fields; + auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + // Prepare and indices. SmallVector indices(adaptor.getIndices()); // Generate insertion. Value value = adaptor.getValue(); auto insertPoint = op->template getParentOfType(); - genInsertionCallHelper(rewriter, dstType, fields, indices, value, - insertPoint, kInsertFuncNamePrefix, genInsertBody); + genInsertionCallHelper(rewriter, desc, indices, value, insertPoint, + kInsertFuncNamePrefix, genInsertBody); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields)); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); return success(); } }; @@ -1003,11 +944,9 @@ // Replace the requested pointer access with corresponding field. // The cast_op is inserted by type converter to intermix 1:N type // conversion. 
- auto tuple = getTuple(adaptor.getTensor()); - unsigned idx = Base::getIndexForOp(tuple, op); - auto fields = tuple.getInputs(); - assert(idx < fields.size()); - rewriter.replaceOp(op, fields[idx]); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + Value field = Base::getFieldForOp(desc, op); + rewriter.replaceOp(op, field); return success(); } }; @@ -1018,10 +957,10 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, - ToPointersOp op) { + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToPointersOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u); + return desc.getPtrMemRef(dim); } }; @@ -1031,10 +970,10 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, - ToIndicesOp op) { + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToIndicesOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim); + return desc.getIdxMemRef(dim); } }; @@ -1044,10 +983,9 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp tuple, - ToValuesOp /*op*/) { - // The last field holds the value buffer. - return tuple.getInputs().size() - 1; + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToValuesOp /*op*/) { + return desc.getValMemRef(); } }; @@ -1079,12 +1017,11 @@ matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values size. 
- auto tuple = getTuple(adaptor.getTensor()); - auto fields = tuple.getInputs(); - unsigned lastField = fields.size() - 1; + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); Value field = - constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField)); - rewriter.replaceOpWithNewOp(op, fields[memSizesIdx], field); + constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex()); + rewriter.replaceOpWithNewOp(op, desc.getMemSizesMemRef(), + field); return success(); } };