diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -75,6 +75,21 @@ return isSingletonDLT(getDimLevelType(type, d)); } +/// Convenience function to test for dense dimension (0 <= d < rank). +inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) { + return isDenseDLT(getDimLevelType(enc, d)); +} + +/// Convenience function to test for compressed dimension (0 <= d < rank). +inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) { + return isCompressedDLT(getDimLevelType(enc, d)); +} + +/// Convenience function to test for singleton dimension (0 <= d < rank). +inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) { + return isSingletonDLT(getDimLevelType(enc, d)); +} + // // Dimension level properties. // diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -310,9 +310,135 @@ return !rtp || rtp.getRank() == 0; } +enum class SparseTensorFieldKind { + DimSizes, + MemSizes, + PtrMemRef, + IdxMemRef, + ValMemRef +}; + +//===----------------------------------------------------------------------===// +// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout +// scheme. 
+//===----------------------------------------------------------------------===// + +constexpr uint64_t dimSizesIdx = 0; +constexpr uint64_t memSizesIdx = dimSizesIdx + 1; +constexpr uint64_t dataFieldIdx = memSizesIdx + 1; + +void foreachFieldInSparseTensor( + SparseTensorEncodingAttr, + llvm::function_ref); + +void foreachFieldAndTypeInSparseTensor( + RankedTensorType, + llvm::function_ref); + +unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); +unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); + +inline unsigned getFieldMemSizesIndex(unsigned fid) { + assert(fid >= dataFieldIdx); + return fid - dataFieldIdx; +} + +class SparseTensorDescriptor { +public: + SparseTensorDescriptor(Type tp, ValueRange fields) + : rType(tp.cast()), fields(fields) { + assert(getSparseTensorEncoding(tp) && + getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == + fields.size()); + } + + unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const { + unsigned fieldIdx = -1u; + foreachFieldInSparseTensor( + getSparseTensorEncoding(rType), + [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, + unsigned fDim, DimLevelType dlt) -> bool { + if (fDim == dim && kind == fKind) { + fieldIdx = fIdx; + // Returns false to break the iteration. 
+ return false; + } + return true; + }); + assert(fieldIdx != -1u); + return fieldIdx; + } + + unsigned getPtrMemRefIndex(unsigned ptrDim) const { + return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef); + } + + unsigned getIdxMemRefIndex(unsigned idxDim) const { + return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef); + } + + unsigned getValMemRefIndex() const { return fields.size() - 1; } + + unsigned getPtrMemSizesIndex(unsigned dim) const { + return getPtrMemRefIndex(dim) - dataFieldIdx; + } + + unsigned getIdxMemSizesIndex(unsigned dim) const { + return getIdxMemRefIndex(dim) - dataFieldIdx; + } + + unsigned getValMemSizesIndex() const { + return getValMemRefIndex() - dataFieldIdx; + } + + unsigned getNumFields() const { return fields.size(); } + + Value getDimSizesMemRef() const { return fields[dimSizesIdx]; } + Value getMemSizesMemRef() const { return fields[memSizesIdx]; } + + Value getPtrMemRef(unsigned ptrDim) const { + return fields[getPtrMemRefIndex(ptrDim)]; + } + + Value getIdxMemRef(unsigned idxDim) const { + return fields[getIdxMemRefIndex(idxDim)]; + } + + Value getValMemRef() const { return fields[getValMemRefIndex()]; } + + Value getField(unsigned fid) const { + assert(fid < fields.size()); + return fields[fid]; + } + + Type getPtrElementType() const { + auto *ctx = rType.getContext(); + unsigned ptrWidth = getSparseTensorEncoding(rType).getPointerBitWidth(); + Type indexType = IndexType::get(ctx); + return ptrWidth ? IntegerType::get(ctx, ptrWidth) : indexType; + } + + Type getIdxElementType() const { + auto *ctx = rType.getContext(); + unsigned idxWidth = getSparseTensorEncoding(rType).getIndexBitWidth(); + Type indexType = IndexType::get(ctx); + return idxWidth ? 
IntegerType::get(ctx, idxWidth) : indexType; + } + +private: + RankedTensorType rType; + ValueRange fields; +}; + //===----------------------------------------------------------------------===// -// SparseTensorLoopEmiter class, manages sparse tensors and helps to generate -// loop structure to (co)-iterate sparse tensors. +// SparseTensorLoopEmitter class, manages sparse tensors and helps to +// generate loop structure to (co)-iterate sparse tensors. // // An example usage: // To generate the following loops over T1 and T2 @@ -345,15 +471,15 @@ using OutputUpdater = function_ref; - /// Constructor: take an array of tensors inputs, on which the generated loops - /// will iterate on. The index of the tensor in the array is also the + /// Constructor: take an array of tensor inputs, on which the generated + /// loops will iterate on. The index of the tensor in the array is also the /// tensor id (tid) used in related functions. /// If isSparseOut is set, loop emitter assume that the sparse output tensor /// is empty, and will always generate loops on it based on the dim sizes. /// An optional array could be provided (by sparsification) to indicate the /// loop id sequence that will be generated. It is used to establish the - /// mapping between affineDimExpr to the corresponding loop index in the loop - /// stack that are maintained by the loop emitter. + /// mapping between affineDimExpr to the corresponding loop index in the + /// loop stack that are maintained by the loop emitter. explicit SparseTensorLoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false, @@ -368,8 +494,8 @@ /// Generates a list of operations to compute the affine expression. Value genAffine(OpBuilder &builder, AffineExpr a, Location loc); - /// Enters a new loop sequence, the loops within the same sequence starts from - /// the break points of previous loop instead of starting over from 0. 
+ /// Enters a new loop sequence, the loops within the same sequence start + /// from the break points of previous loop instead of starting over from 0. /// e.g., /// { /// // loop sequence start. @@ -524,10 +650,10 @@ /// scf.reduce.return %val /// } /// } - /// NOTE: only one instruction will be moved into reduce block, transformation - /// will fail if multiple instructions are used to compute the reduction - /// value. - /// Return %ret to user, while %val is provided by users (`reduc`). + /// NOTE: only one instruction will be moved into reduce block, + /// transformation will fail if multiple instructions are used to compute + /// the reduction value. Return %ret to user, while %val is provided by + /// users (`reduc`). void exitForLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc); @@ -535,9 +661,9 @@ void exitCoIterationLoop(OpBuilder &builder, Location loc, MutableArrayRef reduc); - /// A optional string attribute that should be attached to the loop generated - /// by loop emitter, it might help following passes to identify loops that - /// operates on sparse tensors more easily. + /// An optional string attribute that should be attached to the loop + /// generated by loop emitter, it might help following passes to identify + /// loops that operate on sparse tensors more easily. StringAttr loopTag; /// Whether the loop emitter needs to treat the last tensor as the output /// tensor. @@ -556,7 +682,8 @@ std::vector> idxBuffer; // to_indices std::vector valBuffer; // to_value - // Loop Stack, stores the information of all the nested loops that are alive. + // Loop Stack, stores the information of all the nested loops that are + // alive. 
std::vector loopStack; // Loop Sequence Stack, stores the unversial index for the current loop diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -90,6 +90,139 @@ return val; } +void sparse_tensor::foreachFieldInSparseTensor( + const SparseTensorEncodingAttr enc, + llvm::function_ref + callback) { + // + // Sparse tensor storage scheme for rank-dimensional tensor is organized + // as a single compound type with the following fields. Note that every + // memref with ? size actually behaves as a "vector", i.e. the stored + // size is the capacity and the used size resides in the memSizes array. + // + // struct { + // memref dimSizes ; size in each dimension + // memref memSizes ; sizes of ptrs/inds/values + // ; per-dimension d: + // ; if dense: + // + // ; if compressed: + // memref pointers-d ; pointers for sparse dim d + // memref indices-d ; indices for sparse dim d + // ; if singleton: + // memref indices-d ; indices for singleton dim d + // memref values ; values + // }; + // + // The dimSizes array and memSizes array. + assert(enc); + +#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ + if (!(callback(idx, kind, dim, dlt))) \ + return; + + RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u, + DimLevelType::Undef); + RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u, + DimLevelType::Undef); + + static_assert(dataFieldIdx == memSizesIdx + 1); + unsigned fieldIdx = dataFieldIdx; + // Per-dimension storage. + for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { + // Dimension level types apply in order to the reordered dimension. + // As a result, the compound type can be constructed directly in the given + // order. 
+ auto dlt = getDimLevelType(enc, r); + if (isCompressedDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else if (isSingletonDLT(dlt)) { + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + } else { + assert(isDenseDLT(dlt)); // no fields + } + } + // The values array. + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, + DimLevelType::Undef); + +#undef RETURN_ON_FALSE +} + +void sparse_tensor::foreachFieldAndTypeInSparseTensor( + RankedTensorType rType, + llvm::function_ref + callback) { + auto enc = getSparseTensorEncoding(rType); + assert(enc); + // Construct the basic types. + auto *context = rType.getContext(); + unsigned idxWidth = enc.getIndexBitWidth(); + unsigned ptrWidth = enc.getPointerBitWidth(); + Type indexType = IndexType::get(context); + Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType; + Type ptrType = ptrWidth ? 
IntegerType::get(context, ptrWidth) : indexType; + Type eltType = rType.getElementType(); + unsigned rank = rType.getShape().size(); + // memref dimSizes + Type dimSizeType = MemRefType::get({rank}, indexType); + // memref memSizes + Type memSizeType = + MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); + // memref pointers + Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); + // memref indices + Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); + // memref values + Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); + + foreachFieldInSparseTensor( + enc, + [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType, + callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, + unsigned dim, DimLevelType dlt) -> bool { + switch (fieldKind) { + case SparseTensorFieldKind::DimSizes: + return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::MemSizes: + return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::PtrMemRef: + return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::IdxMemRef: + return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); + case SparseTensorFieldKind::ValMemRef: + return callback(valMemType, fieldIdx, fieldKind, dim, dlt); + }; + }); +} + +unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 0; + foreachFieldInSparseTensor(enc, + [&numFields](unsigned, SparseTensorFieldKind, + unsigned, DimLevelType) -> bool { + numFields++; + return true; + }); + return numFields; +} + +unsigned +sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { + unsigned numFields = 1; // one value memref + foreachFieldInSparseTensor(enc, + [&numFields](unsigned, SparseTensorFieldKind, + unsigned, DimLevelType dlt) -> bool { + if (!isUndefDLT(dlt)) + numFields++; + return true; + }); + assert(numFields == 
getNumFieldsFromEncoding(enc) - dataFieldIdx); + return numFields; +} //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -30,11 +30,6 @@ using namespace mlir::sparse_tensor; namespace { - -static constexpr uint64_t dimSizesIdx = 0; -static constexpr uint64_t memSizesIdx = 1; -static constexpr uint64_t fieldsIdx = 2; - //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// @@ -44,6 +39,11 @@ return llvm::cast(tensor.getDefiningOp()); } +static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { + auto tuple = getTuple(tensor); + return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); +} + /// Packs the given values as a "tuple" value. static Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values) { @@ -96,7 +96,7 @@ /// Creates a straightforward counting for-loop. static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, - SmallVectorImpl &fields, + MutableArrayRef fields, Value lower = Value()) { Type indexType = builder.getIndexType(); if (!lower) @@ -127,10 +127,9 @@ // Any other query can consult the dimSizes array at field DimSizesIdx, // accounting for the reordering applied to the sparse storage. 
- auto tuple = getTuple(adaptedValue); + auto desc = getDescriptorFromTensorTuple(adaptedValue); Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim)); - return builder - .create(loc, tuple.getInputs()[dimSizesIdx], idx) + return builder.create(loc, desc.getDimSizesMemRef(), idx) .getResult(); } @@ -146,108 +145,36 @@ constantIndex(builder, loc, d)); } -/// Translates field index to memSizes index. -static unsigned getMemSizesIndex(unsigned field) { - assert(fieldsIdx <= field); - return field - fieldsIdx; -} - /// Creates a pushback op for given field and updates the fields array /// accordingly. This operation also updates the memSizes contents. static void createPushback(OpBuilder &builder, Location loc, SmallVectorImpl &fields, unsigned field, Value value, Value repeat = Value()) { - assert(fieldsIdx <= field && field < fields.size()); + assert(field < fields.size()); Type etp = fields[field].getType().cast().getElementType(); fields[field] = builder.create( loc, fields[field].getType(), fields[memSizesIdx], fields[field], - toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)), + toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(field)), repeat); } -/// Returns field index of sparse tensor type for pointers/indices, when set. 
-static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) { - assert(getSparseTensorEncoding(type)); - RankedTensorType rType = type.cast(); - unsigned field = fieldsIdx; // start past header - for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) { - if (isCompressedDim(rType, r)) { - if (r == ptrDim) - return field; - field++; - if (r == idxDim) - return field; - field++; - } else if (isSingletonDim(rType, r)) { - if (r == idxDim) - return field; - field++; - } else { - assert(isDenseDim(rType, r)); // no fields - } - } - assert(ptrDim == -1u && idxDim == -1u); - return field + 1; // return values field index -} - /// Maps a sparse tensor type to the appropriate compounded buffers. static Optional convertSparseTensorType(Type type, SmallVectorImpl &fields) { auto enc = getSparseTensorEncoding(type); if (!enc) return llvm::None; - // Construct the basic types. - auto *context = type.getContext(); - unsigned idxWidth = enc.getIndexBitWidth(); - unsigned ptrWidth = enc.getPointerBitWidth(); + RankedTensorType rType = type.cast(); - Type indexType = IndexType::get(context); - Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType; - Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType; - Type eltType = rType.getElementType(); - // - // Sparse tensor storage scheme for rank-dimensional tensor is organized - // as a single compound type with the following fields. Note that every - // memref with ? size actually behaves as a "vector", i.e. the stored - // size is the capacity and the used size resides in the memSizes array. 
- // - // struct { - // memref dimSizes ; size in each dimension - // memref memSizes ; sizes of ptrs/inds/values - // ; per-dimension d: - // ; if dense: - // - // ; if compresed: - // memref pointers-d ; pointers for sparse dim d - // memref indices-d ; indices for sparse dim d - // ; if singleton: - // memref indices-d ; indices for singleton dim d - // memref values ; values - // }; - // - unsigned rank = rType.getShape().size(); - unsigned lastField = getFieldIndex(type, -1u, -1u); - // The dimSizes array and memSizes array. - fields.push_back(MemRefType::get({rank}, indexType)); - fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType)); - // Per-dimension storage. - for (unsigned r = 0; r < rank; r++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. Clients of this type know what field is what from the sparse - // tensor type. - if (isCompressedDim(rType, r)) { - fields.push_back(MemRefType::get({ShapedType::kDynamic}, ptrType)); - fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType)); - } else if (isSingletonDim(rType, r)) { - fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType)); - } else { - assert(isDenseDim(rType, r)); // no fields - } - } - // The values array. - fields.push_back(MemRefType::get({ShapedType::kDynamic}, eltType)); - assert(fields.size() == lastField); + foreachFieldAndTypeInSparseTensor( + rType, + [&fields](Type fieldType, unsigned fieldIdx, + SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + DimLevelType /*dlt*/) -> bool { + assert(fieldIdx == fields.size()); + fields.push_back(fieldType); + return true; + }); return success(); } @@ -273,26 +200,27 @@ if (isSingletonDim(rtp, r)) { return; // nothing to do } // Keep compounding the size, but nothing needs to be initialized - // at this level. 
We will eventually reach a compressed level or - // otherwise the values array for the from-here "all-dense" case. - assert(isDenseDim(rtp, r)); - Value size = sizeAtStoredDim(builder, loc, rtp, fields, r); - linear = builder.create(loc, linear, size); + // at this level. We will eventually reach a compressed level or + // otherwise the values array for the from-here "all-dense" case. + assert(isDenseDim(rtp, r)); + Value size = sizeAtStoredDim(builder, loc, rtp, fields, r); + linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. Value valZero = constantZero(builder, loc, rtp.getElementType()); createPushback(builder, loc, fields, field, valZero, linear); - assert(fields.size() == ++field); + assert(fields.size() == field + 1); } /// Creates allocation operation. -static Value createAllocation(OpBuilder &builder, Location loc, Type type, - Value sz, bool enableInit) { - auto memType = MemRefType::get({ShapedType::kDynamic}, type); - Value buffer = builder.create(loc, memType, sz); +static Value createAllocation(OpBuilder &builder, Location loc, + MemRefType memRefType, Value sz, + bool enableInit) { + Value buffer = builder.create(loc, memRefType, sz); + Type elemType = memRefType.getElementType(); if (enableInit) { - Value fillValue = - builder.create(loc, type, builder.getZeroAttr(type)); + Value fillValue = builder.create( + loc, elemType, builder.getZeroAttr(elemType)); builder.create(loc, fillValue, buffer); } return buffer; @@ -308,72 +236,68 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, ValueRange dynSizes, bool enableInit, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(type); - assert(enc); - // Construct the basic types. - unsigned idxWidth = enc.getIndexBitWidth(); - unsigned ptrWidth = enc.getPointerBitWidth(); RankedTensorType rtp = type.cast(); - Type indexType = builder.getIndexType(); - Type idxType = idxWidth ? 
builder.getIntegerType(idxWidth) : indexType; - Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType; - Type eltType = rtp.getElementType(); - auto shape = rtp.getShape(); - unsigned rank = shape.size(); Value heuristic = constantIndex(builder, loc, 16); + + foreachFieldAndTypeInSparseTensor( + rtp, + [&](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, + unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { + assert(fields.size() == fIdx); + auto memRefTp = fType.cast(); + Value field; + switch (fKind) { + case SparseTensorFieldKind::DimSizes: + case SparseTensorFieldKind::MemSizes: + field = builder.create(loc, memRefTp); + break; + case SparseTensorFieldKind::PtrMemRef: + case SparseTensorFieldKind::IdxMemRef: + case SparseTensorFieldKind::ValMemRef: + field = + createAllocation(builder, loc, memRefTp, heuristic, enableInit); + break; + } + assert(field); + fields.push_back(field); + // Returns true to continue the iteration. + return true; + }); + + SparseTensorDescriptor desc(rtp, fields); + // Build original sizes. SmallVector sizes; + auto shape = rtp.getShape(); + unsigned rank = shape.size(); for (unsigned r = 0, o = 0; r < rank; r++) { if (ShapedType::isDynamic(shape[r])) sizes.push_back(dynSizes[o++]); else sizes.push_back(constantIndex(builder, loc, shape[r])); } - // The dimSizes array and memSizes array. - unsigned lastField = getFieldIndex(type, -1u, -1u); - Value dimSizes = - builder.create(loc, MemRefType::get({rank}, indexType)); - Value memSizes = builder.create( - loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType)); - fields.push_back(dimSizes); - fields.push_back(memSizes); - // Per-dimension storage. 
- for (unsigned r = 0; r < rank; r++) { - if (isCompressedDim(rtp, r)) { - fields.push_back( - createAllocation(builder, loc, ptrType, heuristic, enableInit)); - fields.push_back( - createAllocation(builder, loc, idxType, heuristic, enableInit)); - } else if (isSingletonDim(rtp, r)) { - fields.push_back( - createAllocation(builder, loc, idxType, heuristic, enableInit)); - } else { - assert(isDenseDim(rtp, r)); // no fields - } - } - // The values array. - fields.push_back( - createAllocation(builder, loc, eltType, heuristic, enableInit)); - assert(fields.size() == lastField); // Initialize the storage scheme to an empty tensor. Initialized memSizes // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. builder.create( - loc, ValueRange{constantZero(builder, loc, indexType)}, - ValueRange{memSizes}); // zero memSizes - Value ptrZero = constantZero(builder, loc, ptrType); - for (unsigned r = 0, field = fieldsIdx; r < rank; r++) { + loc, constantZero(builder, loc, builder.getIndexType()), + desc.getMemSizesMemRef()); // zero memSizes + + Value ptrZero = constantZero(builder, loc, desc.getPtrElementType()); + for (unsigned r = 0; r < rank; r++) { unsigned ro = toOrigDim(rtp, r); - genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r)); + // Fills dim sizes array. + genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(), + constantIndex(builder, loc, r)); + + // Pushes a leading zero to pointers memref. 
if (isCompressedDim(rtp, r)) { - createPushback(builder, loc, fields, field, ptrZero); - field += 2; - } else if (isSingletonDim(rtp, r)) { - field += 1; + unsigned fid = desc.getPtrMemRefIndex(r); + createPushback(builder, loc, fields, fid, ptrZero); } } - allocSchemeForRank(builder, loc, rtp, fields, fieldsIdx, /*rank=*/0); + allocSchemeForRank(builder, loc, rtp, fields, dataFieldIdx, /*rank=*/0); } /// Helper method that generates block specific to compressed case: @@ -408,7 +332,7 @@ Value pp1 = builder.create(loc, pos, one); Value plo = genLoad(builder, loc, fields[field], pos); Value phi = genLoad(builder, loc, fields[field], pp1); - Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1)); + Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(field + 1)); Value msz = genLoad(builder, loc, fields[memSizesIdx], psz); Value phim1 = builder.create( loc, toType(builder, loc, phi, indexType), one); @@ -430,8 +354,8 @@ builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); Value p = ifOp1.getResult(0); - // If present construct. Note that for a non-unique dimension level, we simply - // set the condition to false and rely on CSE/DCE to clean up the IR. + // If present construct. Note that for a non-unique dimension level, we + // simply set the condition to false and rely on CSE/DCE to clean up the IR. // // TODO: generate less temporary IR? // @@ -481,7 +405,7 @@ SmallVectorImpl &indices, Value value) { unsigned rank = rtp.getShape().size(); assert(rank == indices.size()); - unsigned field = fieldsIdx; // start past header + SparseTensorDescriptor desc(rtp, fields); Value pos = constantZero(builder, loc, builder.getIndexType()); // Generate code for every dimension. 
for (unsigned d = 0; d < rank; d++) { @@ -493,16 +417,15 @@ // } // pos[d] = indices.size() - 1 // - pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field, - d); - field += 2; + pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, + desc.getPtrMemRefIndex(d), d); } else if (isSingletonDim(rtp, d)) { // Create: // indices[d].push_back(i[d]) // pos[d] = pos[d-1] // - createPushback(builder, loc, fields, field, indices[d]); - field += 1; + unsigned fidx = desc.getIdxMemRefIndex(d); + createPushback(builder, loc, fields, fidx, indices[d]); } else { assert(isDenseDim(rtp, d)); // Construct the new position as: @@ -514,30 +437,32 @@ } } // Reached the actual value append/insert. - if (!isDenseDim(rtp, rank - 1)) - createPushback(builder, loc, fields, field++, value); - else - genStore(builder, loc, value, fields[field++], pos); - assert(fields.size() == field); + unsigned valIdx = desc.getValMemRefIndex(); + if (!isDenseDim(rtp, rank - 1)) { + createPushback(builder, loc, fields, valIdx, value); + } else + genStore(builder, loc, value, desc.getValMemRef(), pos); + // assert(fields.size() == field); } /// Generations insertion finalization code. static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, SmallVectorImpl &fields) { unsigned rank = rtp.getShape().size(); - unsigned field = fieldsIdx; // start past header + unsigned field = dataFieldIdx; // start past header for (unsigned d = 0; d < rank; d++) { if (isCompressedDim(rtp, d)) { // Compressed dimensions need a pointer cleanup for all entries // that were not visited during the insertion pass. // - // TODO: avoid cleanup and keep compressed scheme consistent at all times? + // TODO: avoid cleanup and keep compressed scheme consistent at all + // times? // if (d > 0) { unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth(); Type indexType = builder.getIndexType(); Type ptrType = ptrWidth ? 
builder.getIntegerType(ptrWidth) : indexType; - Value mz = constantIndex(builder, loc, getMemSizesIndex(field)); + Value mz = constantIndex(builder, loc, getFieldMemSizesIndex(field)); Value hi = genLoad(builder, loc, fields[memSizesIdx], mz); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); @@ -918,11 +843,9 @@ // Replace the requested pointer access with corresponding field. // The cast_op is inserted by type converter to intermix 1:N type // conversion. - auto tuple = getTuple(adaptor.getTensor()); - unsigned idx = Base::getIndexForOp(tuple, op); - auto fields = tuple.getInputs(); - assert(idx < fields.size()); - rewriter.replaceOp(op, fields[idx]); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + Value field = Base::getFieldForOp(desc, op); + rewriter.replaceOp(op, field); return success(); } }; @@ -933,10 +856,10 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, - ToPointersOp op) { + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToPointersOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u); + return desc.getPtrMemRef(dim); } }; @@ -946,10 +869,10 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, - ToIndicesOp op) { + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToIndicesOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim); + return desc.getIdxMemRef(dim); } }; @@ -959,10 +882,9 @@ public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. 
- static unsigned getIndexForOp(UnrealizedConversionCastOp tuple, - ToValuesOp /*op*/) { - // The last field holds the value buffer. - return tuple.getInputs().size() - 1; + static Value getFieldForOp(const SparseTensorDescriptor &desc, + ToValuesOp /*op*/) { + return desc.getValMemRef(); } }; @@ -994,12 +916,11 @@ matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values size. - auto tuple = getTuple(adaptor.getTensor()); - auto fields = tuple.getInputs(); - unsigned lastField = fields.size() - 1; + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); Value field = - constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField)); - rewriter.replaceOpWithNewOp(op, fields[memSizesIdx], field); + constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex()); + rewriter.replaceOpWithNewOp(op, desc.getMemSizesMemRef(), + field); return success(); } };