diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -240,8 +240,9 @@
 /// Returns the next element for the sparse tensor being read.
 #define IMPL_GETNEXT(VNAME, V)                                                 \
-  MLIR_CRUNNERUTILS_EXPORT V _mlir_ciface_getSparseTensorReaderNext##VNAME(    \
-      void *p, StridedMemRefType<index_type, 1> *iref);
+  MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderNext##VNAME( \
+      void *p, StridedMemRefType<index_type, 1> *iref,                         \
+      StridedMemRefType<V, 0> *vref);
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
@@ -266,7 +267,7 @@
 #define IMPL_OUTNEXT(VNAME, V)                                                 \
   MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_outSparseTensorWriterNext##VNAME( \
       void *p, index_type rank, StridedMemRefType<index_type, 1> *iref,        \
-      V value);
+      StridedMemRefType<V, 0> *vref);
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
 #undef IMPL_OUTNEXT
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -618,9 +618,10 @@ }
 #define IMPL_GETNEXT(VNAME, V)                                                 \
-  V _mlir_ciface_getSparseTensorReaderNext##VNAME(                             \
-      void *p, StridedMemRefType<index_type, 1> *iref) {                       \
-    assert(p && iref);                                                         \
+  void _mlir_ciface_getSparseTensorReaderNext##VNAME(                          \
+      void *p, StridedMemRefType<index_type, 1> *iref,                         \
+      StridedMemRefType<V, 0> *vref) {                                         \
+    assert(p && iref && vref);                                                 \
     assert(iref->strides[0] == 1);                                             \
     index_type *indices = iref->data + iref->offset;                           \
     SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
@@ -630,7 +631,8 @@
       uint64_t idx = strtoul(linePtr, &linePtr, 10);                           \
       indices[r] = idx - 1;                                                    \
     }                                                                          \
-    return detail::readCOOValue<V>(&linePtr, stfile->isPattern());             \
+    V *value = vref->data + vref->offset;                                      \
+    *value = detail::readCOOValue<V>(&linePtr, stfile->isPattern());           \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
@@ -667,14 +669,15 @@
 #define IMPL_OUTNEXT(VNAME, V)                                                 \
   void _mlir_ciface_outSparseTensorWriterNext##VNAME(                          \
       void *p, index_type rank, StridedMemRefType<index_type, 1> *iref,        \
-      V value) {                                                               \
-    assert(p && iref);                                                         \
+      StridedMemRefType<V, 0> *vref) {                                         \
+    assert(p && iref && vref);                                                 \
     assert(iref->strides[0] == 1);                                             \
     index_type *indices = iref->data + iref->offset;                           \
     SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);          \
     for (uint64_t r = 0; r < rank; ++r)                                        \
       file << (indices[r] + 1) << " ";                                         \
-    file << value << std::endl;                                                \
+    V *value = vref->data + vref->offset;                                      \
+    file << *value << std::endl;                                               \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
 #undef IMPL_OUTNEXT
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
@@ -19,17 +19,17 @@
   func.func private @getSparseTensorReaderRank(!TensorReader) -> (index)
   func.func private @getSparseTensorReaderNNZ(!TensorReader) -> (index)
   func.func private @getSparseTensorReaderIsSymmetric(!TensorReader) -> (i1)
-  func.func private @getSparseTensorReaderDimSizes(!TensorReader, memref<?xindex>)
-    -> () attributes { llvm.emit_c_interface }
-  func.func private @getSparseTensorReaderNextF32(!TensorReader, memref<?xindex>)
-    -> (f32) attributes { llvm.emit_c_interface }
+  func.func private @getSparseTensorReaderDimSizes(!TensorReader,
+    memref<?xindex>) -> () attributes { llvm.emit_c_interface }
+  func.func private @getSparseTensorReaderNextF32(!TensorReader,
+    memref<?xindex>, memref<f32>) -> () attributes { llvm.emit_c_interface }
   func.func private @createSparseTensorWriter(!Filename) -> (!TensorWriter)
   func.func private @delSparseTensorWriter(!TensorWriter)
   func.func private @outSparseTensorWriterMetaData(!TensorWriter, index, index,
     memref<?xindex>) -> () attributes { llvm.emit_c_interface }
   func.func private @outSparseTensorWriterNextF32(!TensorWriter, index,
-    memref<?xindex>, f32) -> () attributes { llvm.emit_c_interface }
+    memref<?xindex>, memref<f32>) -> () attributes { llvm.emit_c_interface }
 
   func.func @dumpi(%arg0: memref<?xindex>) {
     %c0 = arith.constant 0 : index
@@ -60,9 +60,13 @@
     %x1s = memref.alloc(%nnz) : memref<?xindex>
     %vs = memref.alloc(%nnz) : memref<?xf32>
     %indices = memref.alloc(%rank) : memref<?xindex>
+    %value = memref.alloca() : memref<f32>
     scf.for %i = %c0 to %nnz step %c1 {
-      %v = func.call @getSparseTensorReaderNextF32(%tensor, %indices)
-        : (!TensorReader, memref<?xindex>) -> f32
+      func.call @getSparseTensorReaderNextF32(%tensor, %indices, %value)
+        : (!TensorReader, memref<?xindex>, memref<f32>) -> ()
+      // TODO: can we use memref.subview to avoid the need for the %value
+      // buffer?
+      %v = memref.load %value[] : memref<f32>
       memref.store %v, %vs[%i] : memref<?xf32>
       %i0 = memref.load %indices[%c0] : memref<?xindex>
       memref.store %i0, %x0s[%i] : memref<?xindex>
@@ -129,11 +133,12 @@
     //TODO: handle isSymmetric.
     // Assume rank == 2.
     %indices = memref.alloc(%rank) : memref<?xindex>
+    %value = memref.alloca() : memref<f32>
    scf.for %i = %c0 to %nnz step %c1 {
-      %v = func.call @getSparseTensorReaderNextF32(%tensor0, %indices)
-        : (!TensorReader, memref<?xindex>) -> f32
-      func.call @outSparseTensorWriterNextF32(%tensor1, %rank, %indices, %v)
-        : (!TensorWriter, index, memref<?xindex>, f32) -> ()
+      func.call @getSparseTensorReaderNextF32(%tensor0, %indices, %value)
+        : (!TensorReader, memref<?xindex>, memref<f32>) -> ()
+      func.call @outSparseTensorWriterNextF32(%tensor1, %rank, %indices, %value)
+        : (!TensorWriter, index, memref<?xindex>, memref<f32>) -> ()
     }
 
     // Release the resources.
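
Note on the new calling convention (illustration only, not part of the patch): the reader and writer entry points now pass the element value through a rank-0 memref out-parameter instead of returning it (reader) or taking it by value (writer). Below is a minimal host-side sketch of calling the new F32 reader entry point, assuming a handle obtained from `createSparseTensorReader` and a rank-2 tensor; the helper name `readOneElement` and the buffer names are hypothetical.

```cpp
#include <cstdint>

#include "mlir/ExecutionEngine/CRunnerUtils.h"        // StridedMemRefType
#include "mlir/ExecutionEngine/SparseTensorRuntime.h" // _mlir_ciface_* entry points

// Hypothetical helper: fetch one COO element via the new out-parameter ABI.
static void readOneElement(void *reader) {
  // Rank-1 memref descriptor over a two-entry index buffer (rank-2 tensor).
  index_type indexBuf[2];
  StridedMemRefType<index_type, 1> iref;
  iref.basePtr = iref.data = indexBuf;
  iref.offset = 0;
  iref.sizes[0] = 2;
  iref.strides[0] = 1; // The runtime asserts a unit stride.

  // Rank-0 memref descriptor that receives the value; this replaces the
  // old by-value `f32` result.
  float valueBuf;
  StridedMemRefType<float, 0> vref;
  vref.basePtr = vref.data = &valueBuf;
  vref.offset = 0;

  _mlir_ciface_getSparseTensorReaderNextF32(reader, &iref, &vref);
  // indexBuf[0], indexBuf[1], and valueBuf now hold the next element.
}
```

Routing the value through `vref` keeps every `MLIR_SPARSETENSOR_FOREVERY_V` instantiation behind a uniform memref-based C interface; the rank-0 `memref.alloca() : memref<f32>` added to the test is the MLIR-side counterpart of the `vref` descriptor above.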