diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -122,7 +122,8 @@
 /// The actions performed by @newSparseTensor.
 enum class Action : uint32_t {
   kEmpty = 0,
-  kFromFile = 1,
+  // newSparseTensor no longer handles `kFromFile=1`, so we leave this
+  // number reserved to help catch any code that still needs updating.
   kFromCOO = 2,
   kSparseToSparse = 3,
   kEmptyCOO = 4,
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -103,6 +103,23 @@
   SparseTensorReader(const SparseTensorReader &) = delete;
   SparseTensorReader &operator=(const SparseTensorReader &) = delete;

+  /// Factory method to allocate a new reader, open the file, read the
+  /// header, and validate that the actual contents of the file match
+  /// the expected `dimShape` and `valTp`.
+  static SparseTensorReader *create(const char *filename, uint64_t dimRank,
+                                    const uint64_t *dimShape,
+                                    PrimaryType valTp) {
+    SparseTensorReader *reader = new SparseTensorReader(filename);
+    reader->openFile();
+    reader->readHeader();
+    if (!reader->canReadAs(valTp))
+      MLIR_SPARSETENSOR_FATAL(
+          "Tensor element type %d not compatible with values in file %s\n",
+          static_cast<int>(valTp), filename);
+    reader->assertMatchesShape(dimRank, dimShape);
+    return reader;
+  }
+
   // This dtor tries to avoid leaking the `file`. (Though it's better
   // to call `closeFile` explicitly when possible, since there are
   // circumstances where dtors are not called reliably.)
@@ -173,10 +190,51 @@
   /// to the `indices` array.
   template <typename V>
   V readCOOElement(uint64_t rank, uint64_t *indices) {
-    char *linePtr = readCOOIndices(rank, indices);
+    assert(rank == getRank() && "rank mismatch");
+    char *linePtr = readCOOIndices(indices);
     return detail::readCOOValue<V>(&linePtr, isPattern());
   }

+  /// Allocates a new COO object for `lvlSizes`, initializes it by reading
+  /// all the elements from the file and applying `dim2lvl` to their indices,
+  /// and then closes the file.
+  ///
+  /// Preconditions:
+  /// * `lvlSizes` must be valid for `lvlRank`.
+  /// * `dim2lvl` must be valid for `getRank()`.
+  /// * `dim2lvl` maps indices valid for `getDimSizes()` to indices
+  ///   valid for `lvlSizes`.
+  /// * the file's actual value type can be read as `V`.
+  ///
+  /// Asserts:
+  /// * `isValid()`
+  /// * `dim2lvl` is a permutation, and therefore also `lvlRank == getRank()`.
+  ///   (This requirement will be lifted once we functionalize `dim2lvl`.)
+  //
+  // NOTE: This method is factored out of `readSparseTensor` primarily to
+  // reduce code bloat (since the bulk of the code doesn't care about the
+  // `<P, I>` type template parameters). But we leave it public since it's
+  // perfectly reasonable for clients to use.
+  template <typename V>
+  SparseTensorCOO<V> *readCOO(uint64_t lvlRank, const uint64_t *lvlSizes,
+                              const uint64_t *dim2lvl);
+
+  /// Allocates a new sparse-tensor storage object with the given encoding,
+  /// initializes it by reading all the elements from the file, and then
+  /// closes the file. Preconditions/assertions are as per `readCOO`
+  /// and `SparseTensorStorage::newFromCOO`.
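+  /// The returned storage object is heap-allocated, and ownership of it
+  /// passes to the caller.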
+  template <typename P, typename I, typename V>
+  SparseTensorStorage<P, I, V> *
+  readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes,
+                   const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
+                   const uint64_t *dim2lvl) {
+    auto *lvlCOO = readCOO<V>(lvlRank, lvlSizes, dim2lvl);
+    auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
+        getRank(), getDimSizes(), lvlRank, lvlTypes, lvl2dim, *lvlCOO);
+    delete lvlCOO;
+    return tensor;
+  }
+
 private:
   /// Attempts to read a line from the file. Is private because there's
   /// no reason for client code to call it.
@@ -187,7 +245,9 @@
   /// buffer where the element's value should be parsed from. This method
   /// has been factored out from `readCOOElement` to minimize code bloat
   /// for the generated library.
-  char *readCOOIndices(uint64_t rank, uint64_t *indices);
+  ///
+  /// Precondition: `indices` is valid for `getRank()`.
+  char *readCOOIndices(uint64_t *indices);

   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
@@ -209,72 +269,49 @@

 //===----------------------------------------------------------------------===//

-/// Reads a sparse tensor with the given filename into a memory-resident
-/// sparse tensor.
-///
-/// Preconditions:
-/// * `dimShape` and `dim2lvl` must be valid for `dimRank`.
-/// * `lvlTypes` and `lvl2dim` must be valid for `lvlRank`.
-/// * `dim2lvl` is the inverse of `lvl2dim`.
-///
-/// Asserts:
-/// * the file's actual value type can be read as `valTp`.
-/// * the file's actual dimension-sizes match the expected `dimShape`.
-/// * `dim2lvl` is a permutation, and therefore also `dimRank == lvlRank`.
-//
-// TODO: As currently written, this function uses `dim2lvl` in two
-// places: first, to construct the level-sizes from the file's actual
-// dimension-sizes; and second, to map the file's dimension-indices into
-// level-indices. The latter can easily generalize to arbitrary mappings,
-// however the former cannot. Thus, once we functionalize the mappings,
-// this function will need both the sizes-to-sizes and indices-to-indices
-// variants of the `dim2lvl` mapping. For the `lvl2dim` direction we only
-// need the indices-to-indices variant, for handing off to `newFromCOO`.
-template <typename P, typename I, typename V>
-inline SparseTensorStorage<P, I, V> *
-openSparseTensor(uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-                 const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-                 const uint64_t *dim2lvl, const char *filename,
-                 PrimaryType valTp) {
-  // Read the file's header and check the file's actual element type and
-  // dimension-sizes against the expected element type and dimension-shape.
-  SparseTensorReader stfile(filename);
-  stfile.openFile();
-  stfile.readHeader();
-  if (!stfile.canReadAs(valTp))
-    MLIR_SPARSETENSOR_FATAL(
-        "Tensor element type %d not compatible with values in file %s\n",
-        static_cast<int>(valTp), filename);
-  stfile.assertMatchesShape(dimRank, dimShape);
-  const uint64_t *dimSizes = stfile.getDimSizes();
-  // Construct the level-sizes from the file's dimension-sizes
-  // TODO: This doesn't generalize to arbitrary mappings. (See above.)
-  assert(dimRank == lvlRank && "Rank mismatch");
+template <typename V>
+SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
+                                                const uint64_t *lvlSizes,
+                                                const uint64_t *dim2lvl) {
+  assert(isValid() && "Attempt to readCOO() before readHeader()");
+  // Construct a `PermutationRef` for the `pushforward` below.
+  // TODO: This specific implementation does not generalize to arbitrary
+  // mappings, but once we functionalize the `dim2lvl` argument we can
+  // simply use that function instead.
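+  // (For example, with the transpose mapping `dim2lvl = {1, 0}`,
+  // `pushforward` sends dimension-indices (i, j) to level-indices (j, i).)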
+  const uint64_t dimRank = getRank();
+  assert(lvlRank == dimRank && "Rank mismatch");
   detail::PermutationRef d2l(dimRank, dim2lvl);
-  std::vector<uint64_t> lvlSizes = d2l.pushforward(dimRank, dimSizes);
   // Prepare a COO object with the number of nonzeros as initial capacity.
-  uint64_t nnz = stfile.getNNZ();
-  auto *lvlCOO = new SparseTensorCOO<V>(lvlSizes, nnz);
+  const uint64_t nnz = getNNZ();
+  auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, nnz);
   // Read all nonzero elements.
   std::vector<uint64_t> dimInd(dimRank);
   std::vector<uint64_t> lvlInd(lvlRank);
+  // Do some manual LICM, to avoid assertions in the for-loop.
+  const bool addSymmetric = (isSymmetric() && dimRank == 2);
+  const bool isPattern_ = isPattern();
   for (uint64_t k = 0; k < nnz; ++k) {
-    const V value = stfile.readCOOElement<V>(dimRank, dimInd.data());
+    // We inline `readCOOElement` here in order to avoid redundant
+    // assertions, since they're guaranteed by the call to `isValid()`
+    // and the construction of `dimInd` above.
+    char *linePtr = readCOOIndices(dimInd.data());
+    const V value = detail::readCOOValue<V>(&linePtr, isPattern_);
     d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
     // TODO:
     lvlCOO->add(lvlInd, value);
     // We currently chose to deal with symmetric matrices by fully
     // constructing them. In the future, we may want to make symmetry
     // implicit for storage reasons.
-    if (stfile.isSymmetric() && lvlInd[0] != lvlInd[1])
-      lvlCOO->add({lvlInd[1], lvlInd[0]}, value);
+    if (addSymmetric && dimInd[0] != dimInd[1]) {
+      // Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
+      std::swap(dimInd[0], dimInd[1]);
+      d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
+      lvlCOO->add(lvlInd, value);
+    }
   }
-  // Close the file, convert the COO to SparseTensorStorage, and return.
-  stfile.closeFile();
-  auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
-      dimRank, dimSizes, lvlRank, lvlTypes, lvl2dim, *lvlCOO);
-  delete lvlCOO;
-  return tensor;
+  // Close the file and return the COO.
+  closeFile();
+  return lvlCOO;
 }

 /// Writes the sparse tensor to `filename` in extended FROSTT format.
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -220,8 +220,28 @@
 /// Creates a SparseTensorReader for reading a sparse tensor from a file
 /// with the given file name. This opens the file and reads the header
 /// metadata based on the sparse tensor format derived from the suffix of
 /// the file name.
+//
+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_createCheckedSparseTensorReader` instead.
 MLIR_CRUNNERUTILS_EXPORT void *createSparseTensorReader(char *filename);

+/// Constructs a new SparseTensorReader object, opens the file, reads the
+/// header, and validates that the actual contents of the file match
+/// the expected `dimShapeRef` and `valTp`.
+MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_createCheckedSparseTensorReader(
+    char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
+    PrimaryType valTp);
+
+/// Constructs a new sparse-tensor storage object with the given encoding,
+/// initializes it by reading all the elements from the file, and then
+/// closes the file.
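+/// The returned storage object is owned by the caller; the reader `p`
+/// itself is not freed by this call and must still be released via
+/// `delSparseTensorReader`.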
+MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
+    void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
+    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
+    StridedMemRefType<index_type, 1> *lvl2dimRef,
+    StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType ptrTp,
+    OverheadType indTp, PrimaryType valTp);
+
 /// Returns the rank of the sparse tensor being read.
 MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderRank(void *p);

@@ -235,10 +255,19 @@
 MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
                                                                  index_type d);

-/// Returns all dimension sizes for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
+/// SparseTensorReader method to copy the dimension-sizes into the
+/// provided memref.
+//
+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_getSparseTensorReaderDimSizes` instead.
+MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_copySparseTensorReaderDimSizes(
     void *p, StridedMemRefType<index_type, 1> *dref);

+/// SparseTensorReader method to obtain direct access to the
+/// dimension-sizes array.
+MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
+    StridedMemRefType<index_type, 1> *out, void *p);
+
 /// Releases the SparseTensorReader. This also closes the file associated with
 /// the reader.
 MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -95,7 +95,9 @@
 static void sizesFromPtr(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                          Location loc, SparseTensorEncodingAttr &enc,
                          ShapedType stp, Value src) {
-  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+  unsigned rank = stp.getRank();
+  sizes.reserve(rank);
+  for (unsigned i = 0; i < rank; i++)
     sizes.push_back(sizeFromPtrAtDim(builder, loc, enc, stp, src, i));
 }

@@ -103,7 +105,9 @@
 static void sizesFromType(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp) {
   auto shape = stp.getShape();
-  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
+  unsigned rank = stp.getRank();
+  sizes.reserve(rank);
+  for (unsigned i = 0; i < rank; i++) {
     uint64_t s = shape[i] == ShapedType::kDynamic ? 0 : shape[i];
     sizes.push_back(constantIndex(builder, loc, s));
   }
@@ -167,6 +171,17 @@
   return buffer;
 }

+/// Generates a temporary buffer for the level-types of the given encoding.
+static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
+                               SparseTensorEncodingAttr enc) {
+  SmallVector<Value> lvlTypes;
+  auto dlts = enc.getDimLevelType();
+  lvlTypes.reserve(dlts.size());
+  for (auto dlt : dlts)
+    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
+  return genBuffer(builder, loc, lvlTypes);
+}
+
 /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
 /// the "swiss army knife" method of the sparse runtime support library
 /// for materializing sparse tensors into the computation. This abstraction
@@ -262,11 +277,7 @@
   const unsigned lvlRank = enc.getDimLevelType().size();
   const unsigned dimRank = stp.getRank();
   // Sparsity annotations.
-  SmallVector<Value> lvlTypes;
-  for (auto dlt : enc.getDimLevelType())
-    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
-  assert(lvlTypes.size() == lvlRank && "Level-rank mismatch");
-  params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes);
+  params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, enc);
   // Dimension-sizes array of the enveloping tensor. Useful for either
   // verification of external data, or for construction of internal data.
   assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
@@ -715,19 +726,98 @@
   matchAndRewrite(NewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type resType = op.getType();
-    auto enc = getSparseTensorEncoding(resType);
+    auto stp = op.getType().cast<ShapedType>();
+    auto enc = getSparseTensorEncoding(stp);
     if (!enc)
       return failure();
-    // Generate the call to construct tensor from ptr. The sizes are
-    // inferred from the result type of the new operator.
-    SmallVector<Value> sizes;
-    ShapedType stp = resType.cast<ShapedType>();
-    sizesFromType(rewriter, sizes, loc, stp);
-    Value ptr = adaptor.getOperands()[0];
-    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
-                               .genBuffers(enc, sizes, stp)
-                               .genNewCall(Action::kFromFile, ptr));
+    const unsigned dimRank = stp.getRank();
+    const unsigned lvlRank = enc.getDimLevelType().size();
+    // Construct the dimShape.
+    const auto dimShape = stp.getShape();
+    SmallVector<Value> dimShapeValues;
+    sizesFromType(rewriter, dimShapeValues, loc, stp);
+    Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
+    // Allocate `SparseTensorReader` and perform all initial setup that
+    // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
+    Type opaqueTp = getOpaquePointerType(rewriter);
+    Value valTp =
+        constantPrimaryTypeEncoding(rewriter, loc, stp.getElementType());
+    Value reader =
+        createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
+                       opaqueTp,
+                       {adaptor.getOperands()[0], dimShapeBuffer, valTp},
+                       EmitCInterface::On)
+            .getResult(0);
+    // Construct the lvlSizes. If the dimShape is static, then it's
+    // identical to dimSizes: so we can compute lvlSizes entirely at
+    // compile-time. If dimShape is dynamic, then we'll need to generate
+    // code for computing lvlSizes from the `reader`'s actual dimSizes.
+    //
+    // TODO: For now we're still assuming `dim2lvl` is a permutation.
+    // But since we're computing lvlSizes here (rather than in the runtime),
+    // we can easily generalize that simply by adjusting this code.
+    //
+    // FIXME: reduce redundancy vs `NewCallParams::genBuffers`.
+    Value dimSizesBuffer;
+    if (!stp.hasStaticShape()) {
+      Type indexTp = rewriter.getIndexType();
+      auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
+      dimSizesBuffer =
+          createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
+                         reader, EmitCInterface::On)
+              .getResult(0);
+    }
+    Value lvlSizesBuffer;
+    Value lvl2dimBuffer;
+    Value dim2lvlBuffer;
+    if (auto dimOrder = enc.getDimOrdering()) {
+      assert(dimOrder.isPermutation() && "Got non-permutation");
+      // We preinitialize `dim2lvlValues` since we need random-access writing.
+      // And we preinitialize the others for stylistic consistency.
+      SmallVector<Value> lvlSizeValues(lvlRank);
+      SmallVector<Value> lvl2dimValues(lvlRank);
+      SmallVector<Value> dim2lvlValues(dimRank);
+      for (unsigned l = 0; l < lvlRank; l++) {
+        // The `d`th source variable occurs in the `l`th result position.
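+        // (For example, with `dimOrdering = affine_map<(i, j, k) -> (k, i, j)>`
+        // this loop produces `lvl2dim = {2, 0, 1}` and `dim2lvl = {1, 2, 0}`.)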
+        uint64_t d = dimOrder.getDimPosition(l);
+        Value lvl = constantIndex(rewriter, loc, l);
+        Value dim = constantIndex(rewriter, loc, d);
+        dim2lvlValues[d] = lvl;
+        lvl2dimValues[l] = dim;
+        lvlSizeValues[l] =
+            (dimShape[d] == ShapedType::kDynamic)
+                ? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
+                : dimShapeValues[d];
+      }
+      lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues);
+      lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues);
+      dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues);
+    } else {
+      assert(dimRank == lvlRank && "Rank mismatch");
+      SmallVector<Value> iotaValues;
+      iotaValues.reserve(lvlRank);
+      for (unsigned i = 0; i < lvlRank; i++)
+        iotaValues.push_back(constantIndex(rewriter, loc, i));
+      lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
+      dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues);
+    }
+    // Use the `reader` to parse the file.
+    SmallVector<Value> params{
+        reader,
+        lvlSizesBuffer,
+        genLvlTypesBuffer(rewriter, loc, enc),
+        lvl2dimBuffer,
+        dim2lvlBuffer,
+        constantPointerTypeEncoding(rewriter, loc, enc),
+        constantIndexTypeEncoding(rewriter, loc, enc),
+        valTp};
+    Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
+                                  opaqueTp, params, EmitCInterface::On)
+                       .getResult(0);
+    // Free the memory for `reader`.
+    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
+                   EmitCInterface::Off);
+    rewriter.replaceOp(op, tensor);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -866,9 +866,8 @@
     Location loc = op.getLoc();
     auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
-    if (!encDst) {
+    if (!encDst)
       return failure();
-    }

     // Create a sparse tensor reader.
     Value fileName = op.getSource();
@@ -886,7 +885,7 @@
     // the sparse tensor reader.
     SmallVector<Value> dynSizesArray;
     if (!dstTp.hasStaticShape()) {
-      createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", {},
+      createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {},
                      {reader, dimSizes}, EmitCInterface::On)
          .getResult(0);
       ArrayRef<int64_t> dstShape = dstTp.getShape();
@@ -930,7 +929,7 @@
                                  ArrayRef<Value>(cooBuffer));
     rewriter.setInsertionPointToStart(forOp.getBody());

-    SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
+    SmallString<29> getNextFuncName{"getSparseTensorReaderNext",
                                     primaryTypeFunctionSuffix(eltTp)};
     Value indices = dimSizes; // Reuse the indices memref to store indices.
     createFuncCall(rewriter, loc, getNextFuncName, {}, {reader, indices, value},
@@ -1013,7 +1012,7 @@
     Value indices = dimSizes; // Reuse the dimSizes buffer for indices.
     Type eltTp = srcTp.getElementType();
-    SmallString<18> outNextFuncName{"outSparseTensorWriterNext",
+    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                     primaryTypeFunctionSuffix(eltTp)};
     Value value = genAllocaScalar(rewriter, loc, eltTp);
     ModuleOp module = op->getParentOfType<ModuleOp>();
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
@@ -53,12 +53,11 @@
     MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename);
 }

-char *SparseTensorReader::readCOOIndices(uint64_t rank, uint64_t *indices) {
-  assert(rank == getRank() && "Rank mismatch");
+char *SparseTensorReader::readCOOIndices(uint64_t *indices) {
   readLine();
   // Local variable for tracking the parser's position in the `line` buffer.
   char *linePtr = line;
-  for (uint64_t r = 0; r < rank; ++r) {
+  for (uint64_t rank = getRank(), r = 0; r < rank; ++r) {
     // Parse the 1-based index.
     uint64_t idx = strtoul(linePtr, &linePtr, 10);
     // Store the 0-based index.
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -243,15 +243,29 @@

 #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)

-// We make this a function rather than a macro mainly for type safety
-// reasons. This function does not modify the vector, but it cannot
-// be marked `const` because it is stored into the non-`const` memref.
-template <typename T>
-static void vectorToMemref(std::vector<T> &v, StridedMemRefType<T, 1> &ref) {
-  ref.basePtr = ref.data = v.data();
+/// Initializes the memref with the provided size and data pointer. This
+/// is designed for functions which want to "return" a memref that aliases
+/// into memory owned by some other object (e.g., `SparseTensorStorage`),
+/// without doing any actual copying. (The "return" is in scarequotes
+/// because the `_mlir_ciface_` calling convention migrates any returned
+/// memrefs into an out-parameter passed before all the other function
+/// parameters.)
+///
+/// We make this a function rather than a macro mainly for type safety
+/// reasons. This function does not modify the data pointer, but it
+/// cannot be marked `const` because it is stored into the (necessarily)
+/// non-`const` memref. This function is templated over the `DataSizeT`
+/// to work around signedness warnings due to many data types having
+/// varying signedness across different platforms. The templating allows
+/// this function to ensure that it does the right thing and never
+/// introduces errors due to implicit conversions.
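+///
+/// For example, `aliasIntoMemref(v->size(), v->data(), *ref)` makes `ref`
+/// view the vector's contents without copying; the caller must ensure the
+/// underlying storage outlives the memref.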
+template <typename DataSizeT, typename T>
+static inline void aliasIntoMemref(DataSizeT size, T *data,
+                                   StridedMemRefType<T, 1> &ref) {
+  ref.basePtr = ref.data = data;
   ref.offset = 0;
-  using SizeT = typename std::remove_reference_t<decltype(ref.sizes[0])>;
-  ref.sizes[0] = detail::checkOverflowCast<SizeT>(v.size());
+  using MemrefSizeT = typename std::remove_reference_t<decltype(ref.sizes[0])>;
+  ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size);
   ref.strides[0] = 1;
 }

@@ -272,11 +286,6 @@
   case Action::kEmpty:                                                         \
     return SparseTensorStorage<P, I, V>::newEmpty(                            \
         dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, lvl2dim);             \
-  case Action::kFromFile: {                                                   \
-    char *filename = static_cast<char *>(ptr);                                \
-    return openSparseTensor<P, I, V>(dimRank, dimSizes, lvlRank, lvlTypes,    \
-                                     lvl2dim, dim2lvl, filename, v);          \
-  }                                                                           \
   case Action::kFromCOO: {                                                    \
     assert(ptr && "Received nullptr for SparseTensorCOO object");             \
     auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr);                      \
@@ -468,7 +477,7 @@
     std::vector<V> *v;                                                        \
     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);            \
     assert(v);                                                                \
-    vectorToMemref(*v, *ref);                                                 \
+    aliasIntoMemref(v->size(), v->data(), *ref);                              \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
 #undef IMPL_SPARSEVALUES
@@ -480,7 +489,7 @@
     std::vector<TYPE> *v;                                                     \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);               \
     assert(v);                                                                \
-    vectorToMemref(*v, *ref);                                                 \
+    aliasIntoMemref(v->size(), v->data(), *ref);                              \
   }
 #define IMPL_SPARSEPOINTERS(PNAME, P)                                          \
   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
@@ -574,16 +583,37 @@
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
 #undef IMPL_EXPINSERT

-void _mlir_ciface_getSparseTensorReaderDimSizes(
+void *_mlir_ciface_createCheckedSparseTensorReader(
+    char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
+    PrimaryType valTp) {
+  ASSERT_NO_STRIDE(dimShapeRef);
+  const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef);
+  const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef);
+  auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp);
+  return static_cast<void *>(reader);
+}
+
+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_getSparseTensorReaderDimSizes` instead.
+void _mlir_ciface_copySparseTensorReaderDimSizes(
     void *p, StridedMemRefType<index_type, 1> *dref) {
   assert(p);
+  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
   ASSERT_NO_STRIDE(dref);
+  const uint64_t dimRank = MEMREF_GET_USIZE(dref);
+  ASSERT_USIZE_EQ(dref, reader.getRank());
   index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
-  SparseTensorReader &file = *static_cast<SparseTensorReader *>(p);
-  const index_type *sizes = file.getDimSizes();
-  index_type rank = file.getRank();
-  for (index_type r = 0; r < rank; ++r)
-    dimSizes[r] = sizes[r];
+  const index_type *fileSizes = reader.getDimSizes();
+  for (uint64_t d = 0; d < dimRank; ++d)
+    dimSizes[d] = fileSizes[d];
+}
+
+void _mlir_ciface_getSparseTensorReaderDimSizes(
+    StridedMemRefType<index_type, 1> *out, void *p) {
+  assert(out && p);
+  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
+  auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes());
+  aliasIntoMemref(reader.getRank(), dimSizes, *out);
 }

@@ -591,16 +621,126 @@
 #define IMPL_GETNEXT(VNAME, V)                                                 \
   void _mlir_ciface_getSparseTensorReaderNext##VNAME(                          \
       void *p, StridedMemRefType<index_type, 1> *iref,                         \
       StridedMemRefType<V, 0> *vref) {                                         \
     assert(p &&vref);                                                          \
+    auto &reader = *static_cast<SparseTensorReader *>(p);                      \
     ASSERT_NO_STRIDE(iref);                                                    \
+    const uint64_t rank = MEMREF_GET_USIZE(iref);                              \
     index_type *indices = MEMREF_GET_PAYLOAD(iref);                            \
-    SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
-    index_type rank = stfile->getRank();                                       \
     V *value = MEMREF_GET_PAYLOAD(vref);                                       \
-    *value = stfile->readCOOElement<V>(rank, indices);                         \
+    *value = reader.readCOOElement<V>(rank, indices);                          \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT

+void *_mlir_ciface_newSparseTensorFromReader(
+    void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
+    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
+    StridedMemRefType<index_type, 1> *lvl2dimRef,
+    StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType ptrTp,
+    OverheadType indTp, PrimaryType valTp) {
+  assert(p);
+  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
+  ASSERT_NO_STRIDE(lvlSizesRef);
+  ASSERT_NO_STRIDE(lvlTypesRef);
+  ASSERT_NO_STRIDE(lvl2dimRef);
+  ASSERT_NO_STRIDE(dim2lvlRef);
+  const uint64_t dimRank = reader.getRank();
+  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
+  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
+  ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
+  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
+  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
+  const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
+  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
+  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
+  //
+  // FIXME(wrengr): Really need to define a separate x-macro for handling
+  // all this. (Or ideally some better, entirely-different approach)
+#define CASE(p, i, v, P, I, V)                                                 \
+  if (ptrTp == OverheadType::p && indTp == OverheadType::i &&                  \
+      valTp == PrimaryType::v)                                                 \
+    return static_cast<void *>(reader.readSparseTensor<P, I, V>(               \
+        lvlRank, lvlSizes, lvlTypes, lvl2dim, dim2lvl));
+#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
+  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
+  // This is safe because of the static_assert above.
+  if (ptrTp == OverheadType::kIndex)
+    ptrTp = OverheadType::kU64;
+  if (indTp == OverheadType::kIndex)
+    indTp = OverheadType::kU64;
+  // Double matrices with all combinations of overhead storage.
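+  // (For example, the first `CASE` below expands to:
+  //    if (ptrTp == OverheadType::kU64 && indTp == OverheadType::kU64 &&
+  //        valTp == PrimaryType::kF64)
+  //      return static_cast<void *>(
+  //          reader.readSparseTensor<uint64_t, uint64_t, double>(
+  //              lvlRank, lvlSizes, lvlTypes, lvl2dim, dim2lvl));)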
+  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
+  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
+  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
+  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
+  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
+  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
+  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
+  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
+  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
+  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
+  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
+  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
+  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
+  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
+  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
+  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
+  // Float matrices with all combinations of overhead storage.
+  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
+  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
+  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
+  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
+  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
+  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
+  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
+  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
+  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
+  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
+  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
+  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
+  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
+  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
+  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
+  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
+  // Two-byte floats with both overheads of the same type.
+  CASE_SECSAME(kU64, kF16, uint64_t, f16);
+  CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
+  CASE_SECSAME(kU32, kF16, uint32_t, f16);
+  CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
+  CASE_SECSAME(kU16, kF16, uint16_t, f16);
+  CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
+  CASE_SECSAME(kU8, kF16, uint8_t, f16);
+  CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
+  // Integral matrices with both overheads of the same type.
+  CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
+  CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
+  CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
+  CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
+  CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
+  CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
+  CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
+  CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
+  CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
+  CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
+  CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
+  CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
+  CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
+  CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
+  CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
+  CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
+  // Complex matrices with wide overhead.
+  CASE_SECSAME(kU64, kC64, uint64_t, complex64);
+  CASE_SECSAME(kU64, kC32, uint64_t, complex32);
+
+  // Unsupported case (add above if needed).
+  // TODO: better pretty-printing of enum values!
+  MLIR_SPARSETENSOR_FATAL(
+      "unsupported combination of types: <P=%d, I=%d, V=%d>\n",
+      static_cast<int>(ptrTp), static_cast<int>(indTp),
+      static_cast<int>(valTp));
+#undef CASE_SECSAME
+#undef CASE
+}
+
 void _mlir_ciface_outSparseTensorWriterMetaData(
     void *p, index_type rank, index_type nnz,
     StridedMemRefType<index_type, 1> *dref) {
@@ -686,14 +826,14 @@

 void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
   assert(out && "Received nullptr for out-parameter");
-  SparseTensorReader stfile(filename);
-  stfile.openFile();
-  stfile.readHeader();
-  stfile.closeFile();
-  const uint64_t rank = stfile.getRank();
-  const uint64_t *dimSizes = stfile.getDimSizes();
-  out->reserve(rank);
-  out->assign(dimSizes, dimSizes + rank);
+  SparseTensorReader reader(filename);
+  reader.openFile();
+  reader.readHeader();
+  reader.closeFile();
+  const uint64_t dimRank = reader.getRank();
+  const uint64_t *dimSizes = reader.getDimSizes();
+  out->reserve(dimRank);
+  out->assign(dimSizes, dimSizes + dimRank);
 }

 // We can't use `static_cast` here because `DimLevelType` is an enum-class.
@@ -718,11 +858,13 @@
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
 #undef IMPL_CONVERTFROMMLIRSPARSETENSOR

+// FIXME: update `SparseTensorCodegenPass` to use
+// `_mlir_ciface_createCheckedSparseTensorReader` instead.
 void *createSparseTensorReader(char *filename) {
-  SparseTensorReader *stfile = new SparseTensorReader(filename);
-  stfile->openFile();
-  stfile->readHeader();
-  return static_cast<void *>(stfile);
+  SparseTensorReader *reader = new SparseTensorReader(filename);
+  reader->openFile();
+  reader->readHeader();
+  return static_cast<void *>(reader);
 }

 index_type getSparseTensorReaderRank(void *p) {
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -77,16 +77,15 @@

 // CHECK-LABEL: func @sparse_new1d(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 // CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<1xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<1xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
 // CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK: call @delSparseTensorReader(%[[Reader]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@@ -95,16 +94,16 @@

 // CHECK-LABEL: func @sparse_new2d(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
+// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
 // CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
 // CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK: call @delSparseTensorReader(%[[Reader]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
@@ -113,18 +112,20 @@

 // CHECK-LABEL: func @sparse_new3d(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK-DAG: %[[FromFile:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
+// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
 // CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<3xindex> to memref<?xindex>
 // CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
 // CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
 // CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Lvl2Dim]], %[[Dim2Lvl]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Lvl2Dim]], %[[Dim2Lvl]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK: call @delSparseTensorReader(%[[Reader]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?x?xf32, #SparseTensor>
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -12,7 +12,7 @@
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[R:.*]] = call @createSparseTensorReader(%[[A]])
 // CHECK: %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
-// CHECK: call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
+// CHECK: call @copySparseTensorReaderDimSizes(%[[R]], %[[DS]])
 // CHECK: %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK: %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
@@ -26,7 +26,7 @@
   func.func private @getSparseTensorReaderRank(!TensorReader) -> (index)
   func.func private @getSparseTensorReaderNNZ(!TensorReader) -> (index)
   func.func private @getSparseTensorReaderIsSymmetric(!TensorReader) -> (i1)
-  func.func private @getSparseTensorReaderDimSizes(!TensorReader,
+  func.func private @copySparseTensorReaderDimSizes(!TensorReader,
     memref<?xindex>) -> () attributes { llvm.emit_c_interface }
   func.func private @getSparseTensorReaderNextF32(!TensorReader,
     memref<?xindex>, memref<f32>) -> () attributes { llvm.emit_c_interface }
@@ -98,7 +98,7 @@
       : (!TensorReader) -> i1
     vector.print %symmetric : i1
     %dimSizes = memref.alloc(%rank) : memref<?xindex>
-    func.call @getSparseTensorReaderDimSizes(%tensor, %dimSizes)
+    func.call @copySparseTensorReaderDimSizes(%tensor, %dimSizes)
       : (!TensorReader, memref<?xindex>) -> ()
     call @dumpi(%dimSizes) : (memref<?xindex>) -> ()
     %x0s, %x1s, %vs = call @readTensorFile(%tensor)
@@ -132,7 +132,7 @@
     %rank = call @getSparseTensorReaderRank(%tensor0) : (!TensorReader) -> index
     %nnz = call @getSparseTensorReaderNNZ(%tensor0) : (!TensorReader) -> index
     %dimSizes = memref.alloc(%rank) : memref<?xindex>
-    func.call @getSparseTensorReaderDimSizes(%tensor0,%dimSizes)
+    func.call @copySparseTensorReaderDimSizes(%tensor0, %dimSizes)
      : (!TensorReader, memref<?xindex>) -> ()
     call @outSparseTensorWriterMetaData(%tensor1, %rank, %nnz, %dimSizes)
       : (!TensorWriter, index, index, memref<?xindex>) -> ()