diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -95,7 +95,9 @@
     kUndefined = 5
   };

-  explicit SparseTensorReader(const char *filename) : filename(filename) {
+  explicit SparseTensorReader(const char *filename)
+      : filename(filename), hasPendingSymmetricValue_(false),
+        pendingSymmetricValue_(nullptr), nnzCount_(0) {
     assert(filename && "Received nullptr for filename");
   }

@@ -106,7 +108,11 @@
   // This dtor tries to avoid leaking the `file`. (Though it's better
   // to call `closeFile` explicitly when possible, since there are
   // circumstances where dtors are not called reliably.)
-  ~SparseTensorReader() { closeFile(); }
+  ~SparseTensorReader() {
+    closeFile();
+    if (pendingSymmetricValue_)
+      delete pendingSymmetricValue_;
+  }

   /// Opens the file for reading.
   void openFile();
@@ -168,6 +174,35 @@
   /// valid after parsing the header.
   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;

+  /// A wrapper around readCOOElement that additionally buffers the mirrored
+  /// element of a symmetric structure so that the next call to this routine
+  /// returns it. Returns false when there are no more elements to return.
+  template <typename V>
+  bool readCOOElementHandleSymmetric(uint64_t rank, uint64_t *indices,
+                                     V *value) {
+    if (hasPendingSymmetricValue_) {
+      hasPendingSymmetricValue_ = false;
+      indices[0] = pendingSymmetricIndices_[0];
+      indices[1] = pendingSymmetricIndices_[1];
+      *value = *static_cast<V *>(pendingSymmetricValue_);
+      return true;
+    }
+
+    if (nnzCount_++ == getNNZ())
+      return false;
+
+    *value = readCOOElement<V>(rank, indices);
+    if (isSymmetric() && indices[0] != indices[1]) {
+      hasPendingSymmetricValue_ = true;
+      if (!pendingSymmetricValue_)
+        pendingSymmetricValue_ = static_cast<void *>(new V);
+      *static_cast<V *>(pendingSymmetricValue_) = *value;
+      pendingSymmetricIndices_[0] = indices[1];
+      pendingSymmetricIndices_[1] = indices[0];
+    }
+    return true;
+  }
+
   /// Reads a sparse tensor element from the next line in the input file and
   /// returns the value of the element. Stores the coordinates of the element
   /// to the `indices` array.
@@ -205,6 +240,10 @@
   bool isSymmetric_ = false;
   uint64_t idata[512];
   char line[kColWidth];
+  bool hasPendingSymmetricValue_;
+  void *pendingSymmetricValue_;
+  uint64_t pendingSymmetricIndices_[2];
+  uint64_t nnzCount_;
 };

 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -245,7 +245,7 @@
 /// Returns the next element for the sparse tensor being read.
 #define DECL_GETNEXT(VNAME, V)                                                \
-  MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderNext##VNAME(\
+  MLIR_CRUNNERUTILS_EXPORT bool _mlir_ciface_getSparseTensorReaderNext##VNAME(\
       void *p, StridedMemRefType<index_type, 1> *iref,                        \
       StridedMemRefType<V, 0> *vref);
 MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETNEXT)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -898,67 +898,49 @@
   // Implement the NewOp as follows:
   //   %tmp = bufferization.alloc_tensor : an unordered COO with identity
   //                                       storage ordering
-  //   for i = 0 to nnz
-  //     get the next element from the input file
+  //   while getSparseTensorReaderNext
   //     insert the element to %tmp
   //   %t = sparse_tensor.ConvertOp %tmp
   RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
   auto cooBuffer =
       rewriter.create<AllocTensorOp>(loc, cooTp, dynSizesArray).getResult();
-  Value c0 = constantIndex(rewriter, loc, 0);
-  Value c1 = constantIndex(rewriter, loc, 1);
-  Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
-                             {indexTp}, {reader}, EmitCInterface::Off)
-                  .getResult(0);
-  Value symmetric;
-  // We assume only rank 2 tensors may have the isSymmetric flag set.
-  if (rank == 2) {
-    symmetric =
-        createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric",
-                       {rewriter.getI1Type()}, {reader}, EmitCInterface::Off)
-            .getResult(0);
-  } else {
-    symmetric = Value();
-  }
   Type eltTp = dstTp.getElementType();
   Value value = genAllocaScalar(rewriter, loc, eltTp);
-  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
-                                                 ArrayRef<Value>(cooBuffer));
-  rewriter.setInsertionPointToStart(forOp.getBody());
-
+  SmallVector<Type> types{cooBuffer.getType()};
+  scf::WhileOp whileOp =
+      rewriter.create<scf::WhileOp>(loc, types, ValueRange{cooBuffer});
+
+  // The before-region of the WhileOp.
+  Block *before =
+      rewriter.createBlock(&whileOp.getBefore(), {}, types, {loc});
+  rewriter.setInsertionPointToEnd(before);
   SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
                                   primaryTypeFunctionSuffix(eltTp)};
   Value indices = dimSizes; // Reuse the indices memref to store indices.
-  createFuncCall(rewriter, loc, getNextFuncName, {}, {reader, indices, value},
-                 EmitCInterface::On);
+  Value result =
+      createFuncCall(rewriter, loc, getNextFuncName, {rewriter.getI1Type()},
+                     {reader, indices, value}, EmitCInterface::On)
+          .getResult(0);
+  rewriter.create<scf::ConditionOp>(loc, result, before->getArguments());
+
+  // The after-region of the WhileOp.
+  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, {loc});
+  rewriter.setInsertionPointToEnd(after);
+
   SmallVector<Value> indicesArray;
   for (uint64_t i = 0; i < rank; i++) {
     indicesArray.push_back(rewriter.create<memref::LoadOp>(
         loc, indices, constantIndex(rewriter, loc, i)));
   }
   Value v = rewriter.create<memref::LoadOp>(loc, value);
-  Value t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
-                                      indicesArray);
-  if (symmetric) {
-    Value eq = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ne, indicesArray[0], indicesArray[1]);
-    Value cond = rewriter.create<arith::AndIOp>(loc, symmetric, eq);
-    scf::IfOp ifOp =
-        rewriter.create<scf::IfOp>(loc, t.getType(), cond, /*else*/ true);
-    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    rewriter.create<scf::YieldOp>(
-        loc, Value(rewriter.create<InsertOp>(
-                 loc, v, t, ValueRange{indicesArray[1], indicesArray[0]})));
-    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    rewriter.create<scf::YieldOp>(loc, t);
-    t = ifOp.getResult(0);
-    rewriter.setInsertionPointAfter(ifOp);
-  }
+  Value t =
+      rewriter.create<InsertOp>(loc, v, after->getArgument(0), indicesArray);
+  rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));

-  rewriter.setInsertionPointAfter(forOp);
+  rewriter.setInsertionPointAfter(whileOp);
   // Link SSA chain.
-  cooBuffer = forOp.getResult(0);
+  cooBuffer = whileOp.getResult(0);

   // Release the sparse tensor reader.
   createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -587,7 +587,7 @@
 }

 #define IMPL_GETNEXT(VNAME, V)                                                \
-  void _mlir_ciface_getSparseTensorReaderNext##VNAME(                         \
+  bool _mlir_ciface_getSparseTensorReaderNext##VNAME(                         \
       void *p, StridedMemRefType<index_type, 1> *iref,                        \
       StridedMemRefType<V, 0> *vref) {                                        \
     assert(p &&vref);                                                         \
@@ -596,7 +596,7 @@
     SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);        \
     index_type rank = stfile->getRank();                                      \
     V *value = MEMREF_GET_PAYLOAD(vref);                                      \
-    *value = stfile->readCOOElement<V>(rank, indices);                        \
+    return stfile->readCOOElementHandleSymmetric<V>(rank, indices, value);    \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_file_io.mlir
@@ -29,7 +29,7 @@
 func.func private @getSparseTensorReaderDimSizes(!TensorReader,
   memref<?xindex>) -> () attributes { llvm.emit_c_interface }
 func.func private @getSparseTensorReaderNextF32(!TensorReader,
-  memref<?xindex>, memref<f32>) -> () attributes { llvm.emit_c_interface }
+  memref<?xindex>, memref<f32>) -> (i1) attributes { llvm.emit_c_interface }

 func.func private @createSparseTensorWriter(!Filename) -> (!TensorWriter)
 func.func private @delSparseTensorWriter(!TensorWriter)
@@ -68,9 +68,12 @@
   %vs = memref.alloc(%nnz) : memref<?xf32>
   %indices = memref.alloc(%rank) : memref<?xindex>
   %value = memref.alloca() : memref<f32>
-  scf.for %i = %c0 to %nnz step %c1 {
-    func.call @getSparseTensorReaderNextF32(%tensor, %indices, %value)
-      : (!TensorReader, memref<?xindex>, memref<f32>) -> ()
+  %count = scf.while (%i = %c0) : (index) -> (index) {
+    %r = func.call @getSparseTensorReaderNextF32(%tensor, %indices, %value)
+      : (!TensorReader, memref<?xindex>, memref<f32>) -> (i1)
+    scf.condition(%r) %i : index
+  } do {
+  ^bb0(%i : index):
     // TODO: can we use memref.subview to avoid the need for the %value
     // buffer?
     %v = memref.load %value[] : memref<f32>
@@ -79,6 +82,8 @@
     memref.store %i0, %x0s[%i] : memref<?xindex>
     %i1 = memref.load %indices[%c1] : memref<?xindex>
     memref.store %i1, %x1s[%i] : memref<?xindex>
+    %n = arith.addi %i, %c1 : index
+    scf.yield %n : index
   }

   // Release the resource for the indices.
@@ -137,15 +142,17 @@
   call @outSparseTensorWriterMetaData(%tensor1, %rank, %nnz, %dimSizes)
     : (!TensorWriter, index, index, memref<?xindex>) -> ()

-  //TODO: handle isSymmetric.
   // Assume rank == 2.
   %indices = memref.alloc(%rank) : memref<?xindex>
   %value = memref.alloca() : memref<f32>
-  scf.for %i = %c0 to %nnz step %c1 {
-    func.call @getSparseTensorReaderNextF32(%tensor0, %indices, %value)
-      : (!TensorReader, memref<?xindex>, memref<f32>) -> ()
+  scf.while : () -> () {
+    %r = func.call @getSparseTensorReaderNextF32(%tensor0, %indices, %value)
+      : (!TensorReader, memref<?xindex>, memref<f32>) -> (i1)
+    scf.condition(%r)
+  } do {
     func.call @outSparseTensorWriterNextF32(%tensor1, %rank, %indices, %value)
       : (!TensorWriter, index, memref<?xindex>, memref<f32>) -> ()
+    scf.yield
   }

   // Release the resources.
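
Note for reviewers: below is a minimal standalone C++ sketch (not part of the patch) of the streaming contract that readCOOElementHandleSymmetric implements. The MockSymmetricReader class, its next method, and the in-memory entry list are hypothetical stand-ins for SparseTensorReader and its file parsing; only the pending-value buffering mirrors the patch. Each off-diagonal entry of a symmetric input is returned twice, once as read and once mirrored on the immediately following call, so callers loop until the routine returns false instead of iterating a precomputed nnz.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct Element {
  uint64_t i, j;
  double v;
};

// Hypothetical stand-in for SparseTensorReader: file parsing is replaced by
// an in-memory entry list; the pending-value buffering follows the patch.
class MockSymmetricReader {
public:
  explicit MockSymmetricReader(std::vector<Element> entries)
      : entries_(std::move(entries)) {}

  // Same contract as readCOOElementHandleSymmetric<double>: returns false
  // once all elements have been produced; each off-diagonal entry is emitted
  // a second time, with its indices swapped, on the immediately next call.
  bool next(uint64_t *indices, double *value) {
    if (hasPending_) { // flush the mirrored element buffered by the last call
      hasPending_ = false;
      indices[0] = pendingIndices_[0];
      indices[1] = pendingIndices_[1];
      *value = pendingValue_;
      return true;
    }
    if (nnzCount_ == entries_.size())
      return false; // no more file entries and nothing buffered
    const Element &e = entries_[nnzCount_++];
    indices[0] = e.i;
    indices[1] = e.j;
    *value = e.v;
    if (e.i != e.j) { // buffer the mirrored off-diagonal element
      hasPending_ = true;
      pendingIndices_[0] = e.j;
      pendingIndices_[1] = e.i;
      pendingValue_ = e.v;
    }
    return true;
  }

private:
  std::vector<Element> entries_;
  bool hasPending_ = false;
  uint64_t pendingIndices_[2] = {0, 0};
  double pendingValue_ = 0.0;
  uint64_t nnzCount_ = 0;
};

int main() {
  // Lower triangle of a symmetric 3x3 matrix, as stored in an MTX file.
  MockSymmetricReader reader({{0, 0, 1.0}, {1, 0, 2.0}, {2, 1, 3.0}});
  uint64_t ij[2];
  double v;
  // Prints five elements: (0,0) once; (1,0)/(0,1) and (2,1)/(1,2) in pairs.
  while (reader.next(ij, &v))
    std::printf("(%llu, %llu) = %g\n", static_cast<unsigned long long>(ij[0]),
                static_cast<unsigned long long>(ij[1]), v);
  return 0;
}

This contract is also why the rewriter switches from an scf.for over getSparseTensorReaderNNZ to an scf.while: with symmetric expansion, the total number of elements produced is only known once the reader reports exhaustion.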