diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -562,6 +562,73 @@
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
 #undef IMPL_EXPINSERT
 
+// Copies the dimension sizes reported by the reader `p` into `dref`.
+// Writes `getRank()` entries; `dref` is assumed (not checked) to hold them.
+void _mlir_ciface_getSparseTensorReaderDimSizes(
+    void *p, StridedMemRefType<index_type, 1> *dref) {
+  assert(p);
+  ASSERT_NO_STRIDE(dref);
+  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
+  SparseTensorReader &file = *static_cast<SparseTensorReader *>(p);
+  const index_type *sizes = file.getDimSizes();
+  index_type rank = file.getRank();
+  for (uint64_t r = 0; r < rank; ++r)
+    dimSizes[r] = sizes[r];
+}
+
+// Reads the next element from the reader `p`: the `iref` buffer is passed to
+// `readCOOElement<V>` (presumably filled with the element's indices) and the
+// returned value is stored into `vref`.
+#define IMPL_GETNEXT(VNAME, V)                                                 \
+  void _mlir_ciface_getSparseTensorReaderNext##VNAME(                          \
+      void *p, StridedMemRefType<index_type, 1> *iref,                         \
+      StridedMemRefType<V, 0> *vref) {                                         \
+    assert(p &&vref);                                                          \
+    ASSERT_NO_STRIDE(iref);                                                    \
+    index_type *indices = MEMREF_GET_PAYLOAD(iref);                            \
+    SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
+    index_type rank = stfile->getRank();                                       \
+    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
+    *value = stfile->readCOOElement<V>(rank, indices);                         \
+  }
+MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
+#undef IMPL_GETNEXT
+
+// Writes the header lines "<rank> <nnz>" and the space-separated dimension
+// sizes. `rank` must be nonzero, which is only checked by `assert`.
+void _mlir_ciface_outSparseTensorWriterMetaData(
+    void *p, index_type rank, index_type nnz,
+    StridedMemRefType<index_type, 1> *dref) {
+  assert(p);
+  ASSERT_NO_STRIDE(dref);
+  assert(rank != 0);
+  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
+  SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);
+  file << rank << " " << nnz << std::endl;
+  for (index_type r = 0; r < rank - 1; ++r)
+    file << dimSizes[r] << " ";
+  file << dimSizes[rank - 1] << std::endl;
+}
+
+// Writes one element as a line: its `rank` indices, each incremented by one,
+// followed by its value. The trailing `std::endl` flushes the stream on
+// every call.
+#define IMPL_OUTNEXT(VNAME, V)                                                 \
+  void _mlir_ciface_outSparseTensorWriterNext##VNAME(                          \
+      void *p, index_type rank, StridedMemRefType<index_type, 1> *iref,        \
+      StridedMemRefType<V, 0> *vref) {                                         \
+    assert(p &&vref);                                                          \
+    ASSERT_NO_STRIDE(iref);                                                    \
+    index_type *indices = MEMREF_GET_PAYLOAD(iref);                            \
+    SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);          \
+    for (uint64_t r = 0; r < rank; ++r)                                        \
+      file << (indices[r] + 1) << " ";                                         \
+    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
+    file << *value << std::endl;                                               \
+  }
+MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
+#undef IMPL_OUTNEXT
+
 //===----------------------------------------------------------------------===//
 //
 // Public functions which accept only C-style data structures to interact
@@ -672,37 +739,10 @@
   return static_cast<SparseTensorReader *>(p)->getDimSize(d);
 }
 
-void _mlir_ciface_getSparseTensorReaderDimSizes(
-    void *p, StridedMemRefType<index_type, 1> *dref) {
-  assert(p);
-  ASSERT_NO_STRIDE(dref);
-  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
-  SparseTensorReader &file = *static_cast<SparseTensorReader *>(p);
-  const index_type *sizes = file.getDimSizes();
-  index_type rank = file.getRank();
-  for (uint64_t r = 0; r < rank; ++r)
-    dimSizes[r] = sizes[r];
-}
-
 void delSparseTensorReader(void *p) {
   delete static_cast<SparseTensorReader *>(p);
 }
 
-#define IMPL_GETNEXT(VNAME, V)                                                 \
-  void _mlir_ciface_getSparseTensorReaderNext##VNAME(                          \
-      void *p, StridedMemRefType<index_type, 1> *iref,                         \
-      StridedMemRefType<V, 0> *vref) {                                         \
-    assert(p &&vref);                                                          \
-    ASSERT_NO_STRIDE(iref);                                                    \
-    index_type *indices = MEMREF_GET_PAYLOAD(iref);                            \
-    SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
-    index_type rank = stfile->getRank();                                       \
-    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
-    *value = stfile->readCOOElement<V>(rank, indices);                         \
-  }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
-#undef IMPL_GETNEXT
-
 void *createSparseTensorWriter(char *filename) {
   SparseTensorWriter *file =
       (filename[0] == 0) ? &std::cout : new std::ofstream(filename);
@@ -718,36 +758,6 @@
     delete file;
 }
 
-void _mlir_ciface_outSparseTensorWriterMetaData(
-    void *p, index_type rank, index_type nnz,
-    StridedMemRefType<index_type, 1> *dref) {
-  assert(p);
-  ASSERT_NO_STRIDE(dref);
-  assert(rank != 0);
-  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
-  SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);
-  file << rank << " " << nnz << std::endl;
-  for (index_type r = 0; r < rank - 1; ++r)
-    file << dimSizes[r] << " ";
-  file << dimSizes[rank - 1] << std::endl;
-}
-
-#define IMPL_OUTNEXT(VNAME, V)                                                 \
-  void _mlir_ciface_outSparseTensorWriterNext##VNAME(                          \
-      void *p, index_type rank, StridedMemRefType<index_type, 1> *iref,        \
-      StridedMemRefType<V, 0> *vref) {                                         \
-    assert(p &&vref);                                                          \
-    ASSERT_NO_STRIDE(iref);                                                    \
-    index_type *indices = MEMREF_GET_PAYLOAD(iref);                            \
-    SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);          \
-    for (uint64_t r = 0; r < rank; ++r)                                        \
-      file << (indices[r] + 1) << " ";                                         \
-    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
-    file << *value << std::endl;                                               \
-  }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
-#undef IMPL_OUTNEXT
-
 } // extern "C"
 
 #undef MEMREF_GET_PAYLOAD