diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -138,7 +138,7 @@
         op.getDim().value().getZExtValue());
   } else {
     auto enc = op.getSpecifier().getType().getEncoding();
-    StorageLayout<true> layout(enc);
+    StorageLayout layout(enc);
     Optional<unsigned> dim = std::nullopt;
     if (op.getDim())
       dim = op.getDim().value().getZExtValue();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -295,11 +295,10 @@
   unsigned idxIndex;
   unsigned idxStride;
   std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d);
-  unsigned ptrIndex = desc.getPtrMemRefIndex(d);
   Value one = constantIndex(builder, loc, 1);
   Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
-  Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
-  Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
+  Value plo = genLoad(builder, loc, desc.getPtrMemRef(d), pos);
+  Value phi = genLoad(builder, loc, desc.getPtrMemRef(d), pp1);
   Value msz = desc.getIdxMemSize(builder, loc, d);
   Value idxStrideC;
   if (idxStride > 1) {
@@ -325,7 +324,7 @@
   builder.create<scf::YieldOp>(loc, eq);
   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
   if (d > 0)
-    genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos);
+    genStore(builder, loc, msz, desc.getPtrMemRef(d), pos);
   builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
   builder.setInsertionPointAfter(ifOp1);
   Value p = ifOp1.getResult(0);
@@ -352,7 +351,7 @@
   // If !present (changes fields, update next).
   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
   Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
-  genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1);
+  genStore(builder, loc, mszp1, desc.getPtrMemRef(d), pp1);
   createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d,
                  indices[d]);
   // Prepare the next dimension "as needed".
@@ -638,10 +637,8 @@
     if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
       return failure();
 
-    Location loc = op.getLoc();
-    auto desc =
-        getDescriptorFromTensorTuple(rewriter, loc, adaptor.getSource());
-    auto sz = sizeFromTensorAtDim(rewriter, loc, desc, *index);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
 
     if (!sz)
       return failure();
@@ -756,8 +753,7 @@
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
     Location loc = op->getLoc();
-    auto desc =
-        getDescriptorFromTensorTuple(rewriter, loc, adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     RankedTensorType srcType =
         op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = srcType.getElementType();
@@ -900,8 +896,7 @@
     // Replace the requested pointer access with corresponding field.
    // The cast_op is inserted by type converter to intermix 1:N type
    // conversion.
-    auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
-                                             adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     uint64_t dim = op.getDimension().getZExtValue();
     rewriter.replaceOp(op, desc.getPtrMemRef(dim));
     return success();
@@ -919,17 +914,17 @@
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
-                                             adaptor.getTensor());
+    Location loc = op.getLoc();
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     uint64_t dim = op.getDimension().getZExtValue();
-    Value field = desc.getIdxMemRef(dim);
+    Value field = desc.getIdxMemRefOrView(rewriter, loc, dim);
 
     // Insert a cast to bridge the actual type to the user expected type. If the
     // actual type and the user expected type aren't compatible, the compiler or
     // the runtime will issue an error.
     Type resType = op.getResult().getType();
     if (resType != field.getType())
-      field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
+      field = rewriter.create<memref::CastOp>(loc, resType, field);
 
     rewriter.replaceOp(op, field);
     return success();
@@ -967,8 +962,7 @@
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
-                                             adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     rewriter.replaceOp(op, desc.getValMemRef());
     return success();
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -77,8 +77,7 @@
     llvm::function_ref<bool(unsigned /*fieldIdx*/,
                             SparseTensorFieldKind /*fieldKind*/,
                             unsigned /*dim (if applicable)*/,
-                            DimLevelType /*DLT (if applicable)*/)>,
-    bool isBuffer = false);
+                            DimLevelType /*DLT (if applicable)*/)>);
 
 /// Same as above, except that it also builds the Type for the corresponding
 /// field.
@@ -90,7 +89,7 @@
                             DimLevelType /*DLT (if applicable)*/)>);
 
 /// Gets the total number of fields for the given sparse tensor encoding.
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc, bool isBuffer);
+unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
 
 /// Gets the total number of data fields (index arrays, pointer arrays, and a
 /// value array) for the given sparse tensor encoding.
@@ -107,12 +106,7 @@
 }
 
 /// Provides methods to access fields of a sparse tensor with the given
-/// encoding. When isBuffer is true, the fields are the actual buffers of the
-/// sparse tensor storage. In particular, when a linear buffer is used to
-/// store the COO data as an array-of-structures, the fields include the
-/// linear buffer (isBuffer=true) or includes the subviews of the buffer for the
-/// indices (isBuffer=false).
-template <bool isBuffer>
+/// encoding.
 class StorageLayout {
 public:
   explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
@@ -132,7 +126,7 @@
   }
 
   static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
-    return sparse_tensor::getNumFieldsFromEncoding(enc, isBuffer);
+    return sparse_tensor::getNumFieldsFromEncoding(enc);
   }
 
   static void foreachFieldInSparseTensor(
@@ -140,7 +134,7 @@
       llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
                               DimLevelType)>
           callback) {
-    return sparse_tensor::foreachFieldInSparseTensor(enc, callback, isBuffer);
+    return sparse_tensor::foreachFieldInSparseTensor(enc, callback);
   }
 
   std::pair<unsigned, unsigned>
@@ -148,7 +142,7 @@
                          std::optional<unsigned> dim) const {
     unsigned fieldIdx = -1u;
     unsigned stride = 1;
-    if (isBuffer && kind == SparseTensorFieldKind::IdxMemRef) {
+    if (kind == SparseTensorFieldKind::IdxMemRef) {
       assert(dim.has_value());
       unsigned cooStart = getCOOStart(enc);
       unsigned rank = enc.getDimLevelType().size();
@@ -222,18 +216,11 @@
   using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
                                                   ValueRange>::type;
 
-  SparseTensorDescriptorImpl(Type tp)
-      : rType(tp.cast<RankedTensorType>()), fields() {}
-
   SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
       : rType(tp.cast<RankedTensorType>()), fields(fields) {
-    sanityCheck();
-  }
-
-  void sanityCheck() {
-    assert(getSparseTensorEncoding(rType) &&
-           StorageLayout<mut>::getNumFieldsFromEncoding(
-               getSparseTensorEncoding(rType)) == fields.size());
+    assert(getSparseTensorEncoding(tp) &&
+           getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+               fields.size());
     // We should make sure the class is trivially copyable (and should be small
     // enough) such that we can pass it by value.
     static_assert(
@@ -244,22 +231,10 @@
   unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
                                std::optional<unsigned> dim) const {
     // Delegates to storage layout.
-    StorageLayout<mut> layout(getSparseTensorEncoding(rType));
+    StorageLayout layout(getSparseTensorEncoding(rType));
     return layout.getMemRefFieldIndex(kind, dim);
   }
 
-  unsigned getPtrMemRefIndex(unsigned ptrDim) const {
-    return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
-  }
-
-  unsigned getIdxMemRefIndex(unsigned idxDim) const {
-    return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
-  }
-
-  unsigned getValMemRefIndex() const {
-    return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
-  }
-
   unsigned getNumFields() const { return fields.size(); }
 
   ///
@@ -281,10 +256,6 @@
     return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
   }
 
-  Value getIdxMemRef(unsigned idxDim) const {
-    return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
-  }
-
   Value getValMemRef() const {
     return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
   }
@@ -299,15 +270,19 @@
     return fields[fidx];
   }
 
-  Value getField(unsigned fidx) const {
-    assert(fidx < fields.size());
-    return fields[fidx];
+  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+                             dim);
   }
 
-  ValueRange getMemRefFields() const {
-    ValueRange ret = fields;
-    // Drop the last metadata fields.
-    return ret.slice(0, fields.size() - 1);
+  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+                             dim);
+  }
+
+  Value getValMemSize(OpBuilder &builder, Location loc) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+                             std::nullopt);
   }
 
   Type getMemRefElementType(SparseTensorFieldKind kind,
@@ -331,23 +306,15 @@
   MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
       : SparseTensorDescriptorImpl<true>(tp, buffers) {}
 
-  ///
-  /// Getters: get the value for required field.
-  ///
-
-  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
-                             dim);
-  }
-
-  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
-                             dim);
+  Value getField(unsigned fidx) const {
+    assert(fidx < fields.size());
+    return fields[fidx];
   }
 
-  Value getValMemSize(OpBuilder &builder, Location loc) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
-                             std::nullopt);
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata fields.
+    return ret.slice(0, fields.size() - 1);
   }
 
   ///
@@ -384,7 +351,7 @@
 
   std::pair<unsigned, unsigned>
   getIdxMemRefIndexAndStride(unsigned idxDim) const {
-    StorageLayout<true> layout(getSparseTensorEncoding(rType));
+    StorageLayout layout(getSparseTensorEncoding(rType));
     return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
                                          idxDim);
   }
@@ -393,19 +360,17 @@
     auto enc = getSparseTensorEncoding(rType);
     unsigned cooStart = getCOOStart(enc);
     assert(cooStart < enc.getDimLevelType().size());
-    return getIdxMemRef(cooStart);
+    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
   }
 };
 
 class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
 public:
-  SparseTensorDescriptor(OpBuilder &builder, Location loc, Type tp,
-                         ValueArrayRef buffers);
+  SparseTensorDescriptor(Type tp, ValueArrayRef buffers)
+      : SparseTensorDescriptorImpl<false>(tp, buffers) {}
 
-private:
-  // Store the fields passed to SparseTensorDescriptorImpl when the tensor has
-  // a COO region.
-  SmallVector<Value> expandedFields;
+  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
+                           unsigned idxDim) const;
 };
 
 /// Returns the "tuple" value of the adapted tensor.
@@ -425,11 +390,9 @@
   return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
 }
 
-inline SparseTensorDescriptor
-getDescriptorFromTensorTuple(OpBuilder &builder, Location loc, Value tensor) {
+inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
   auto tuple = getTuple(tensor);
-  return SparseTensorDescriptor(builder, loc, tuple.getResultTypes()[0],
-                                tuple.getInputs());
+  return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
 }
 
 inline MutSparseTensorDescriptor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -109,41 +109,24 @@
 // SparseTensorDescriptor methods.
 //===----------------------------------------------------------------------===//
 
-sparse_tensor::SparseTensorDescriptor::SparseTensorDescriptor(
-    OpBuilder &builder, Location loc, Type tp, ValueArrayRef buffers)
-    : SparseTensorDescriptorImpl<false>(tp), expandedFields() {
-  SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
-  unsigned rank = enc.getDimLevelType().size();
+Value sparse_tensor::SparseTensorDescriptor::getIdxMemRefOrView(
+    OpBuilder &builder, Location loc, unsigned idxDim) const {
+  auto enc = getSparseTensorEncoding(rType);
   unsigned cooStart = getCOOStart(enc);
-  if (cooStart < rank) {
-    ValueRange beforeFields = buffers.drop_back(3);
-    expandedFields.append(beforeFields.begin(), beforeFields.end());
-    Value buffer = buffers[buffers.size() - 3];
-
+  unsigned idx = idxDim >= cooStart ? cooStart : idxDim;
+  Value buffer = getMemRefField(SparseTensorFieldKind::IdxMemRef, idx);
+  if (idxDim >= cooStart) {
+    unsigned rank = enc.getDimLevelType().size();
     Value stride = constantIndex(builder, loc, rank - cooStart);
-    SmallVector<Value> buffersArray(buffers.begin(), buffers.end());
-    MutSparseTensorDescriptor mutDesc(tp, buffersArray);
-    // Calculate subbuffer size as memSizes[idx] / (stride).
-    Value subBufferSize = mutDesc.getIdxMemSize(builder, loc, cooStart);
-    subBufferSize = builder.create<arith::DivUIOp>(loc, subBufferSize, stride);
-
-    // Create views of the linear idx buffer for the COO indices.
-    for (unsigned i = cooStart; i < rank; i++) {
-      Value subBuffer = builder.create<memref::SubViewOp>(
-          loc, buffer,
-          /*offset=*/ValueRange{constantIndex(builder, loc, i - cooStart)},
-          /*size=*/ValueRange{subBufferSize},
-          /*step=*/ValueRange{stride});
-      expandedFields.push_back(subBuffer);
-    }
-    expandedFields.push_back(buffers[buffers.size() - 2]); // The Values memref.
-    expandedFields.push_back(buffers.back());              // The specifier.
-    fields = expandedFields;
-  } else {
-    fields = buffers;
+    Value size = getIdxMemSize(builder, loc, cooStart);
+    size = builder.create<arith::DivUIOp>(loc, size, stride);
+    buffer = builder.create<memref::SubViewOp>(
+        loc, buffer,
+        /*offset=*/ValueRange{constantIndex(builder, loc, idxDim - cooStart)},
+        /*size=*/ValueRange{size},
+        /*step=*/ValueRange{stride});
   }
-
-  sanityCheck();
+  return buffer;
 }
 
 //===----------------------------------------------------------------------===//
@@ -156,8 +139,7 @@
     const SparseTensorEncodingAttr enc,
     llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
                             DimLevelType)>
-        callback,
-    bool isBuffer) {
+        callback) {
   assert(enc);
 
 #define RETURN_ON_FALSE(idx, kind, dim, dlt)                                   \
   if (!(callback(idx, kind, dim, dlt)))                                        \
     return;
 
   unsigned rank = enc.getDimLevelType().size();
-  unsigned cooStart = isBuffer ? getCOOStart(enc) : rank;
+  unsigned end = getCOOStart(enc);
+  if (end != rank)
+    end += 1;
   static_assert(kDataFieldStartingIdx == 0);
   unsigned fieldIdx = kDataFieldStartingIdx;
   // Per-dimension storage.
-  for (unsigned r = 0; r < rank; r++) {
+  for (unsigned r = 0; r < end; r++) {
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order.
@@ -178,8 +162,7 @@
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
     } else if (isSingletonDLT(dlt)) {
-      if (r < cooStart)
-        RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
     } else {
       assert(isDenseDLT(dlt)); // no fields
     }
@@ -231,38 +214,32 @@
           return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
         };
         llvm_unreachable("unrecognized field kind");
-      },
-      /*isBuffer=*/true);
+      });
 }
 
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc,
-                                                 bool isBuffer) {
+unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
   unsigned numFields = 0;
-  foreachFieldInSparseTensor(
-      enc,
-      [&numFields](unsigned, SparseTensorFieldKind, unsigned,
-                   DimLevelType) -> bool {
-        numFields++;
-        return true;
-      },
-      isBuffer);
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               numFields++;
+                               return true;
+                             });
   return numFields;
 }
 
 unsigned
 sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
   unsigned numFields = 0; // one value memref
-  foreachFieldInSparseTensor(
-      enc,
-      [&numFields](unsigned fidx, SparseTensorFieldKind, unsigned,
-                   DimLevelType) -> bool {
-        if (fidx >= kDataFieldStartingIdx)
-          numFields++;
-        return true;
-      },
-      /*isBuffer=*/true);
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned fidx, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               if (fidx >= kDataFieldStartingIdx)
+                                 numFields++;
+                               return true;
+                             });
   numFields -= 1; // the last field is MetaData field
-  assert(numFields == getNumFieldsFromEncoding(enc, /*isBuffer=*/true) -
-                          kDataFieldStartingIdx - 1);
+  assert(numFields ==
+         getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
  return numFields;
}
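
A minimal usage sketch of the reworked descriptor API (illustrative only, not part of the patch; the helper name and its signature are assumptions). The descriptor is now built directly from the converted tuple, and getIdxMemRefOrView decides per dimension whether to return the plain index memref field or to materialize a strided subview of the shared array-of-structures COO buffer:

    #include "SparseTensorStorageLayout.h"

    using namespace mlir;
    using namespace mlir::sparse_tensor;

    // Hypothetical helper, not in this patch: collect the index memref of
    // every dimension of an already-lowered sparse tensor descriptor.
    static SmallVector<Value> collectIdxMemRefs(OpBuilder &builder, Location loc,
                                                SparseTensorDescriptor desc,
                                                unsigned rank) {
      SmallVector<Value> idxs;
      for (unsigned d = 0; d < rank; d++)
        // Dimensions before the COO region map to a per-dimension index
        // memref field; dimensions at or past cooStart share one buffer, so a
        // subview with offset = d - cooStart, stride = rank - cooStart, and
        // size = idxMemSize(cooStart) / stride is created on demand.
        idxs.push_back(desc.getIdxMemRefOrView(builder, loc, d));
      return idxs;
    }

Because the COO subview is created only at the point of use, lowerings that never touch the index buffers (for example the patterns above that only read the pointers or values memref) no longer need a builder or location just to construct the descriptor, which is what allows getDescriptorFromTensorTuple to drop those parameters.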